Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ sample for ESRGAN #634

Merged
merged 47 commits into from
Aug 22, 2024
Merged
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
9343cae
wip
jstoecker Aug 17, 2024
9661b87
wip
jstoecker Aug 17, 2024
714269e
wip
jstoecker Aug 17, 2024
8ffaf9e
asdf
jstoecker Aug 17, 2024
5933c34
asdf
jstoecker Aug 17, 2024
be7b0ff
asdf
jstoecker Aug 17, 2024
91f3bbd
asdf
jstoecker Aug 17, 2024
4ecef23
asdf
jstoecker Aug 17, 2024
bcb5f94
asdf
jstoecker Aug 18, 2024
8a8587d
asdf
jstoecker Aug 18, 2024
77e014a
asdf
jstoecker Aug 18, 2024
f5d6743
asdf
jstoecker Aug 18, 2024
b3c2bd1
asdf
jstoecker Aug 18, 2024
dd429dd
asdf
jstoecker Aug 18, 2024
4a57549
asdf
jstoecker Aug 18, 2024
bf94fc6
asdf
jstoecker Aug 18, 2024
0eb3a52
asdf
jstoecker Aug 20, 2024
5e5c232
adsf
jstoecker Aug 20, 2024
01420da
asdf
jstoecker Aug 20, 2024
531be51
asdf
jstoecker Aug 20, 2024
c03b983
asdf
jstoecker Aug 20, 2024
9ffda04
asdf
jstoecker Aug 20, 2024
21747fe
asdf
jstoecker Aug 20, 2024
3f2bb76
asdf
jstoecker Aug 20, 2024
22e6231
asdf
jstoecker Aug 20, 2024
32c400b
asdf
jstoecker Aug 20, 2024
8effdc4
asdf
jstoecker Aug 20, 2024
e3cf8b5
asdf
jstoecker Aug 20, 2024
0a4c5fc
asdf
jstoecker Aug 20, 2024
93daa1a
asdf
jstoecker Aug 20, 2024
418bf8c
asdf
jstoecker Aug 20, 2024
c1ba8e5
fix fp16 buffer sizes
jstoecker Aug 20, 2024
93c790b
remove null terminator from std string
jstoecker Aug 21, 2024
6f43ded
fix null terminator
jstoecker Aug 21, 2024
b3f5a98
hold ref to type info
jstoecker Aug 21, 2024
0eee4c1
use io bindings
jstoecker Aug 21, 2024
53ebdaa
shape tweaks
jstoecker Aug 21, 2024
9aecfc4
remove help cl opt
jstoecker Aug 21, 2024
0360860
download model
jstoecker Aug 21, 2024
cdd7c7d
3p notices
jstoecker Aug 21, 2024
12e6b3d
cmake note
jstoecker Aug 21, 2024
7475b0d
pch and separate translation unit for helpers
jstoecker Aug 21, 2024
9b6122f
cleanup
jstoecker Aug 21, 2024
b2bc7fa
cleanup
jstoecker Aug 21, 2024
261bbc4
cleanup
jstoecker Aug 21, 2024
f73e609
readme update
jstoecker Aug 21, 2024
e877e6d
copyright
jstoecker Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix fp16 buffer sizes
  • Loading branch information
jstoecker committed Aug 20, 2024
commit c1ba8e530326cbc8bf1bebe59f3ee789e1a0c46a
27 changes: 15 additions & 12 deletions Samples/DirectML_ESRGAN/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,20 @@ void RunModel(
throw std::invalid_argument("Model must have exactly one input and one output");
}

auto inputInfo = ortSession.GetInputTypeInfo(0);
auto inputTensorInfo = inputInfo.GetTensorTypeAndShapeInfo();
auto inputTensorShape = inputTensorInfo.GetShape();
auto inputTensorInfo = ortSession.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo();
auto inputShape = inputTensorInfo.GetShape();
auto inputDataType = inputTensorInfo.GetElementType();
if (inputDataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && inputDataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
{
throw std::invalid_argument("Model input must be of type float32 or float16");
}

const uint32_t inputChannels = inputTensorShape[1];
const uint32_t inputHeight = inputTensorShape[2];
const uint32_t inputWidth = inputTensorShape[3];
const uint32_t inputChannels = inputShape[1];
const uint32_t inputHeight = inputShape[2];
const uint32_t inputWidth = inputShape[3];
const uint32_t inputElementSize = inputDataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ? sizeof(float) : sizeof(uint16_t);

auto outputInfo = ortSession.GetOutputTypeInfo(0);
auto outputTensorInfo = outputInfo.GetTensorTypeAndShapeInfo();
auto outputTensorInfo = ortSession.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo();
auto outputTensorShape = outputTensorInfo.GetShape();
auto outputDataType = outputTensorInfo.GetElementType();
if (outputDataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT && outputDataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
Expand All @@ -82,9 +81,10 @@ void RunModel(
const uint32_t outputChannels = outputTensorShape[1];
const uint32_t outputHeight = outputTensorShape[2];
const uint32_t outputWidth = outputTensorShape[3];
const uint32_t outputElementSize = outputDataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ? sizeof(float) : sizeof(uint16_t);

// Load image and transform it into an NCHW tensor with the correct shape and data type.
std::vector<std::byte> inputBuffer(inputChannels * inputHeight * inputWidth * sizeof(float));
std::vector<std::byte> inputBuffer(inputChannels * inputHeight * inputWidth * inputElementSize);
FillNCHWBufferFromImageFilename(imagePath.wstring(), inputBuffer, inputHeight, inputWidth, inputDataType, ChannelOrder::RGB);

std::cout << "Saving cropped/scaled image to input.png" << std::endl;
Expand All @@ -96,8 +96,8 @@ void RunModel(
memoryInfo,
inputBuffer.data(),
inputBuffer.size(),
inputTensorShape.data(),
inputTensorShape.size(),
inputShape.data(),
inputShape.size(),
inputDataType
);

Expand All @@ -107,7 +107,10 @@ void RunModel(
std::vector<const char*> outputNames = { "output_0" };
auto outputs = ortSession.Run(runOpts, inputNames.data(), &inputTensor, 1, outputNames.data(), 1);

std::span<const std::byte> outputBuffer(reinterpret_cast<const std::byte*>(outputs[0].GetTensorData<float>()), outputChannels * outputHeight * outputWidth * sizeof(float));
std::span<const std::byte> outputBuffer(
reinterpret_cast<const std::byte*>(outputs[0].GetTensorRawData()),
outputChannels * outputHeight * outputWidth * outputElementSize
);

std::cout << "Saving inference results to output.png" << std::endl;
SaveNCHWBufferToImageFilename(
Expand Down