Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ sample for ESRGAN #634

Merged
merged 47 commits into from
Aug 22, 2024
Merged
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
9343cae
wip
jstoecker Aug 17, 2024
9661b87
wip
jstoecker Aug 17, 2024
714269e
wip
jstoecker Aug 17, 2024
8ffaf9e
asdf
jstoecker Aug 17, 2024
5933c34
asdf
jstoecker Aug 17, 2024
be7b0ff
asdf
jstoecker Aug 17, 2024
91f3bbd
asdf
jstoecker Aug 17, 2024
4ecef23
asdf
jstoecker Aug 17, 2024
bcb5f94
asdf
jstoecker Aug 18, 2024
8a8587d
asdf
jstoecker Aug 18, 2024
77e014a
asdf
jstoecker Aug 18, 2024
f5d6743
asdf
jstoecker Aug 18, 2024
b3c2bd1
asdf
jstoecker Aug 18, 2024
dd429dd
asdf
jstoecker Aug 18, 2024
4a57549
asdf
jstoecker Aug 18, 2024
bf94fc6
asdf
jstoecker Aug 18, 2024
0eb3a52
asdf
jstoecker Aug 20, 2024
5e5c232
adsf
jstoecker Aug 20, 2024
01420da
asdf
jstoecker Aug 20, 2024
531be51
asdf
jstoecker Aug 20, 2024
c03b983
asdf
jstoecker Aug 20, 2024
9ffda04
asdf
jstoecker Aug 20, 2024
21747fe
asdf
jstoecker Aug 20, 2024
3f2bb76
asdf
jstoecker Aug 20, 2024
22e6231
asdf
jstoecker Aug 20, 2024
32c400b
asdf
jstoecker Aug 20, 2024
8effdc4
asdf
jstoecker Aug 20, 2024
e3cf8b5
asdf
jstoecker Aug 20, 2024
0a4c5fc
asdf
jstoecker Aug 20, 2024
93daa1a
asdf
jstoecker Aug 20, 2024
418bf8c
asdf
jstoecker Aug 20, 2024
c1ba8e5
fix fp16 buffer sizes
jstoecker Aug 20, 2024
93c790b
remove null terminator from std string
jstoecker Aug 21, 2024
6f43ded
fix null terminator
jstoecker Aug 21, 2024
b3f5a98
hold ref to type info
jstoecker Aug 21, 2024
0eee4c1
use io bindings
jstoecker Aug 21, 2024
53ebdaa
shape tweaks
jstoecker Aug 21, 2024
9aecfc4
remove help cl opt
jstoecker Aug 21, 2024
0360860
download model
jstoecker Aug 21, 2024
cdd7c7d
3p notices
jstoecker Aug 21, 2024
12e6b3d
cmake note
jstoecker Aug 21, 2024
7475b0d
pch and separate translation unit for helpers
jstoecker Aug 21, 2024
9b6122f
cleanup
jstoecker Aug 21, 2024
b2bc7fa
cleanup
jstoecker Aug 21, 2024
261bbc4
cleanup
jstoecker Aug 21, 2024
f73e609
readme update
jstoecker Aug 21, 2024
e877e6d
copyright
jstoecker Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
asdf
  • Loading branch information
jstoecker committed Aug 20, 2024
commit 531be51cf48eb4f6127acf22b28ae65da1368ccc
92 changes: 58 additions & 34 deletions Samples/DirectMLCV/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <dxcore.h>
#include <optional>
#include <iostream>
#include <filesystem>
#include <span>
#include <string>
#include "cxxopts.hpp"
Expand All @@ -25,37 +26,12 @@
#include "image_helpers.h"
#include "dx_helpers.h"

int main(int argc, char** argv)
void RunModel(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* d3dQueue,
const std::filesystem::path& modelPath,
const std::filesystem::path& imagePath)
{
using Microsoft::WRL::ComPtr;

// Functions in image_helpers.h use WIC APIs, which require CoInitialize.
THROW_IF_FAILED(CoInitializeEx(nullptr, COINIT_MULTITHREADED));

// Parse command-line options.
cxxopts::Options commandLineParams("directml_cv", "DirectML Computer Vision Sample");
commandLineParams.add_options()
("h,help", "Print usage")
("m,model", "Path to ONNX model file", cxxopts::value<std::string>()->default_value("esrgan.onnx"))
("i,image", "Path to input image file", cxxopts::value<std::string>()->default_value("zebra.jpg"))
("a,adapter", "Adapter name substring filter", cxxopts::value<std::string>()->default_value(""))
;

auto commandLineArgs = commandLineParams.parse(argc, argv);

// See dx_helpers.h for logic to select a DXCore adapter, create DML device, and create D3D command queue.
ComPtr<IDMLDevice> dmlDevice;
ComPtr<ID3D12CommandQueue> commandQueue;
try
{
std::tie(dmlDevice, commandQueue) = CreateDmlDeviceAndCommandQueue(commandLineArgs["adapter"].as<std::string>());
}
catch (const std::exception& e)
{
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}

// DML execution provider prefers these session options.
Ort::SessionOptions sessionOptions;
sessionOptions.DisableMemPattern();
Expand All @@ -68,13 +44,13 @@ int main(int argc, char** argv)
Ort::ThrowOnError(ortApi.GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ortDmlApi)));
Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(
sessionOptions,
dmlDevice.Get(),
commandQueue.Get()
dmlDevice,
d3dQueue
));

// Load ONNX model into a session.
Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "DirectML_CV");
Ort::Session ortSession(env, L"esrgan.onnx", sessionOptions);
Ort::Session ortSession(env, modelPath.wstring().c_str(), sessionOptions);

if (ortSession.GetInputCount() != 1 && ortSession.GetOutputCount() != 1)
{
Expand All @@ -95,7 +71,7 @@ int main(int argc, char** argv)
const uint32_t inputHeight = inputTensorShape[2];
const uint32_t inputWidth = inputTensorShape[3];
std::vector<std::byte> inputBuffer(inputChannels * inputHeight * inputWidth * sizeof(float));
FillNCHWBufferFromImageFilename(L"zebra.jpg", inputBuffer, inputHeight, inputWidth, DataType::Float32, ChannelOrder::RGB);
FillNCHWBufferFromImageFilename(imagePath.wstring(), inputBuffer, inputHeight, inputWidth, DataType::Float32, ChannelOrder::RGB);
SaveNCHWBufferToImageFilename(L"input.png", inputBuffer, inputHeight, inputWidth, DataType::Float32, ChannelOrder::RGB);

// For simplicity, this sample binds input/output buffers in system memory instead of DirectX resources.
Expand Down Expand Up @@ -138,9 +114,57 @@ int main(int argc, char** argv)
DataType::Float32,
ChannelOrder::RGB
);
}

int main(int argc, char** argv)
{
using Microsoft::WRL::ComPtr;

// Functions in image_helpers.h use WIC APIs, which require CoInitialize.
THROW_IF_FAILED(CoInitializeEx(nullptr, COINIT_MULTITHREADED));

// Parse command-line options.
cxxopts::Options commandLineParams("directml_cv", "DirectML Computer Vision Sample");
commandLineParams.add_options()
("h,help", "Print usage")
("m,model", "Path to ONNX model file", cxxopts::value<std::string>()->default_value("esrgan.onnx"))
("i,image", "Path to input image file", cxxopts::value<std::string>()->default_value("zebra.jpg"))
("a,adapter", "Adapter name substring filter", cxxopts::value<std::string>()->default_value(""));

auto commandLineArgs = commandLineParams.parse(argc, argv);

// See dx_helpers.h for logic to select a DXCore adapter, create DML device, and create D3D command queue.
ComPtr<IDMLDevice> dmlDevice;
ComPtr<ID3D12CommandQueue> commandQueue;
try
{
std::tie(dmlDevice, commandQueue) = CreateDmlDeviceAndCommandQueue(commandLineArgs["adapter"].as<std::string>());
}
catch (const std::exception& e)
{
std::cerr << "Error creating device: " << e.what() << std::endl;
return 1;
}

try
{
RunModel(
dmlDevice.Get(),
commandQueue.Get(),
commandLineArgs["model"].as<std::string>(),
commandLineArgs["image"].as<std::string>()
);
}
catch (const std::exception& e)
{
std::cerr << "Error running model: " << e.what() << std::endl;
return 1;
}

// Functions in image_helpers.h use WIC APIs, which require CoUninitialize.
CoUninitialize();

std::cout <<"done\n";

return 0;
}