Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C++ sample for ESRGAN #634

Merged
merged 47 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
9343cae
wip
jstoecker Aug 17, 2024
9661b87
wip
jstoecker Aug 17, 2024
714269e
wip
jstoecker Aug 17, 2024
8ffaf9e
asdf
jstoecker Aug 17, 2024
5933c34
asdf
jstoecker Aug 17, 2024
be7b0ff
asdf
jstoecker Aug 17, 2024
91f3bbd
asdf
jstoecker Aug 17, 2024
4ecef23
asdf
jstoecker Aug 17, 2024
bcb5f94
asdf
jstoecker Aug 18, 2024
8a8587d
asdf
jstoecker Aug 18, 2024
77e014a
asdf
jstoecker Aug 18, 2024
f5d6743
asdf
jstoecker Aug 18, 2024
b3c2bd1
asdf
jstoecker Aug 18, 2024
dd429dd
asdf
jstoecker Aug 18, 2024
4a57549
asdf
jstoecker Aug 18, 2024
bf94fc6
asdf
jstoecker Aug 18, 2024
0eb3a52
asdf
jstoecker Aug 20, 2024
5e5c232
adsf
jstoecker Aug 20, 2024
01420da
asdf
jstoecker Aug 20, 2024
531be51
asdf
jstoecker Aug 20, 2024
c03b983
asdf
jstoecker Aug 20, 2024
9ffda04
asdf
jstoecker Aug 20, 2024
21747fe
asdf
jstoecker Aug 20, 2024
3f2bb76
asdf
jstoecker Aug 20, 2024
22e6231
asdf
jstoecker Aug 20, 2024
32c400b
asdf
jstoecker Aug 20, 2024
8effdc4
asdf
jstoecker Aug 20, 2024
e3cf8b5
asdf
jstoecker Aug 20, 2024
0a4c5fc
asdf
jstoecker Aug 20, 2024
93daa1a
asdf
jstoecker Aug 20, 2024
418bf8c
asdf
jstoecker Aug 20, 2024
c1ba8e5
fix fp16 buffer sizes
jstoecker Aug 20, 2024
93c790b
remove null terminator from std string
jstoecker Aug 21, 2024
6f43ded
fix null terminator
jstoecker Aug 21, 2024
b3f5a98
hold ref to type info
jstoecker Aug 21, 2024
0eee4c1
use io bindings
jstoecker Aug 21, 2024
53ebdaa
shape tweaks
jstoecker Aug 21, 2024
9aecfc4
remove help cl opt
jstoecker Aug 21, 2024
0360860
download model
jstoecker Aug 21, 2024
cdd7c7d
3p notices
jstoecker Aug 21, 2024
12e6b3d
cmake note
jstoecker Aug 21, 2024
7475b0d
pch and separate translation unit for helpers
jstoecker Aug 21, 2024
9b6122f
cleanup
jstoecker Aug 21, 2024
b2bc7fa
cleanup
jstoecker Aug 21, 2024
261bbc4
cleanup
jstoecker Aug 21, 2024
f73e609
readme update
jstoecker Aug 21, 2024
e877e6d
copyright
jstoecker Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
adsf
  • Loading branch information
jstoecker committed Aug 20, 2024
commit 5e5c232c24f46172740c607c2715b869e6bc612d
17 changes: 16 additions & 1 deletion Samples/DirectMLCV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ add_library(dxheaders INTERFACE)
target_include_directories(dxheaders INTERFACE ${dxheaders_SOURCE_DIR}/include/directx)
target_link_libraries(dxheaders INTERFACE Microsoft::DirectX-Guids)

# -----------------------------------------------------------------------------
# cxxopts - for parsing command line arguments
# -----------------------------------------------------------------------------

FetchContent_Declare(
cxxopts
GIT_REPOSITORY https://github.com/jarro2783/cxxopts
GIT_TAG v3.2.1
)

set(CXXOPTS_BUILD_EXAMPLES OFF CACHE INTERNAL "Set to ON to build examples")
set(CXXOPTS_BUILD_TESTS OFF CACHE INTERNAL "Set to ON to build tests")
set(CXXOPTS_ENABLE_INSTALL OFF CACHE INTERNAL "Generate the install target")
FetchContent_MakeAvailable(cxxopts)

# -----------------------------------------------------------------------------
# directml
# -----------------------------------------------------------------------------
Expand All @@ -88,7 +103,7 @@ target_link_libraries(dml INTERFACE "${dml_bin_dir}/directml.lib")
# -----------------------------------------------------------------------------

add_executable(directml_cv main.cpp)
target_link_libraries(directml_cv PRIVATE wil ort dml d3d12 dxcore dxheaders)
target_link_libraries(directml_cv PRIVATE wil ort dml d3d12 dxcore dxheaders cxxopts)
target_compile_features(directml_cv PRIVATE cxx_std_20)
target_include_directories(directml_cv PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src")

Expand Down
16 changes: 16 additions & 0 deletions Samples/DirectMLCV/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@


# Build

ARM64:
```
cmake --preset win-arm64
cmake --build --preset win-arm64-release
```


# Run

```
> directml_cv.exe --device npu
```
105 changes: 105 additions & 0 deletions Samples/DirectMLCV/dx_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <dxcore.h>
#include <wrl/client.h>
#include <wil/result.h>

/// Enumerates the DXCore adapters usable for ML work, prints every one of them to stdout, and
/// returns the first adapter (in sorted order) whose driver description contains the filter.
///
/// @param adapterNameFilter Substring to look for in the adapter's driver description. The match is
///                          case-sensitive. An empty filter (the default) matches every adapter, so
///                          the first adapter after sorting is selected.
/// @return The selected adapter; never null.
/// @throws std::runtime_error if no enumerated adapter matches the filter (including the case where
///         no adapters were enumerated at all). Any failing HRESULT is raised by THROW_IF_FAILED.
///
/// Declared inline because this function is defined in a header; without it, including the header
/// from more than one translation unit produces duplicate-symbol link errors.
inline Microsoft::WRL::ComPtr<IDXCoreAdapter> SelectAdapter(std::string_view adapterNameFilter = "")
{
    using Microsoft::WRL::ComPtr;

    ComPtr<IDXCoreAdapterFactory> adapterFactory;
    THROW_IF_FAILED(DXCoreCreateAdapterFactory(IID_PPV_ARGS(adapterFactory.GetAddressOf())));

    // First try getting all GENERIC_ML devices, which is the broadest set of adapters
    // and includes both GPUs and NPUs; however, running this sample on an older build of
    // Windows may not have drivers that report GENERIC_ML.
    ComPtr<IDXCoreAdapterList> adapterList;
    THROW_IF_FAILED(adapterFactory->CreateAdapterList(
        1,
        &DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML,
        adapterList.GetAddressOf()
    ));

    // Fall back to CORE_COMPUTE if GENERIC_ML devices are not available. This is a more restricted
    // set of adapters and may filter out some NPUs.
    if (adapterList->GetAdapterCount() == 0)
    {
        // adapterList still owns the (empty) GENERIC_ML list here. ReleaseAndGetAddressOf drops that
        // reference before the pointer is overwritten; plain GetAddressOf would leak it.
        THROW_IF_FAILED(adapterFactory->CreateAdapterList(
            1,
            &DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
            adapterList.ReleaseAndGetAddressOf()
        ));
    }

    // Sort the adapters by preference, with hardware and high-performance adapters first.
    DXCoreAdapterPreference preferences[] =
    {
        DXCoreAdapterPreference::Hardware,
        DXCoreAdapterPreference::HighPerformance
    };

    THROW_IF_FAILED(adapterList->Sort(_countof(preferences), preferences));

    ComPtr<IDXCoreAdapter> selectedAdapter;
    const uint32_t adapterCount = adapterList->GetAdapterCount();
    for (uint32_t i = 0; i < adapterCount; i++)
    {
        ComPtr<IDXCoreAdapter> adapter;
        THROW_IF_FAILED(adapterList->GetAdapter(i, adapter.ReleaseAndGetAddressOf()));

        size_t descriptionSize = 0;
        THROW_IF_FAILED(adapter->GetPropertySize(
            DXCoreAdapterProperty::DriverDescription,
            &descriptionSize
        ));

        std::string adapterDescription(descriptionSize, '\0');
        THROW_IF_FAILED(adapter->GetProperty(
            DXCoreAdapterProperty::DriverDescription,
            descriptionSize,
            adapterDescription.data()
        ));

        // The property is a NUL-terminated char array and its reported size counts the terminator,
        // so the buffer above ends with an embedded '\0'. Trim it so the NUL is neither printed
        // nor considered part of the description.
        if (const auto terminatorPos = adapterDescription.find('\0'); terminatorPos != std::string::npos)
        {
            adapterDescription.resize(terminatorPos);
        }

        // Use the first adapter matching the name filter; later matches are only listed.
        const bool isSelected =
            !selectedAdapter && adapterDescription.find(adapterNameFilter) != std::string::npos;
        if (isSelected)
        {
            selectedAdapter = adapter;
        }

        std::cout << "Adapter[" << i << "]: " << adapterDescription
                  << (isSelected ? " (SELECTED)" : "") << std::endl;
    }

    if (!selectedAdapter)
    {
        throw std::runtime_error("No suitable adapters found");
    }

    return selectedAdapter;
}

/// Creates the DirectML device and the D3D12 command queue used to run inference.
///
/// The adapter comes from SelectAdapter() called with its default (empty) name filter. A D3D12
/// device is created on that adapter at feature level 12_0, a DirectML device is created on top of
/// it, and a compute-type command queue is created on the same D3D12 device.
///
/// @return Tuple of { DirectML device, D3D12 compute command queue }.
/// @throws Whatever SelectAdapter() throws; any failing HRESULT is raised by THROW_IF_FAILED.
///
/// Declared inline because this function is defined in a header; without it, including the header
/// from more than one translation unit produces duplicate-symbol link errors.
///
/// NOTE(review): D3D_FEATURE_LEVEL_12_0 may be too strict for compute-only adapters (e.g. NPUs)
/// that SelectAdapter() can return - confirm against the devices this sample targets.
inline std::tuple<Microsoft::WRL::ComPtr<IDMLDevice>, Microsoft::WRL::ComPtr<ID3D12CommandQueue>> CreateDmlDeviceAndCommandQueue()
{
    using Microsoft::WRL::ComPtr;

    ComPtr<IDXCoreAdapter> adapter = SelectAdapter();

    ComPtr<ID3D12Device> d3d12Device;
    THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_12_0, IID_PPV_ARGS(&d3d12Device)));

    ComPtr<IDMLDevice> dmlDevice;
    THROW_IF_FAILED(DMLCreateDevice(d3d12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice)));

    // A compute queue (rather than a direct/graphics queue) is requested for the ML workload.
    D3D12_COMMAND_QUEUE_DESC queueDesc =
    {
        .Type = D3D12_COMMAND_LIST_TYPE_COMPUTE,
        .Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL,
        .Flags = D3D12_COMMAND_QUEUE_FLAG_NONE,
        .NodeMask = 0
    };

    ComPtr<ID3D12CommandQueue> commandQueue;
    THROW_IF_FAILED(d3d12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&commandQueue)));

    // The local d3d12Device ComPtr is not returned; only the DML device and queue are handed back.
    return { dmlDevice, commandQueue };
}
133 changes: 20 additions & 113 deletions Samples/DirectMLCV/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,132 +22,38 @@
#include "onnxruntime_cxx_api.h"
#include "dml_provider_factory.h"
#include "image_helpers.h"
#include "dx_helpers.h"

using Microsoft::WRL::ComPtr;

/// Enumerates the DXCore adapters usable for ML work, prints every one of them to stdout, and
/// returns the first adapter (in sorted order) whose driver description contains the filter.
///
/// @param adapterNameFilter Substring to look for in the adapter's driver description. The match is
///                          case-sensitive. An empty filter (the default) matches every adapter, so
///                          the first adapter after sorting is selected.
/// @return The selected adapter; never null.
/// @throws std::runtime_error if no enumerated adapter matches the filter (including the case where
///         no adapters were enumerated at all). Any failing HRESULT is raised by THROW_IF_FAILED.
ComPtr<IDXCoreAdapter> SelectAdapter(std::string_view adapterNameFilter = "")
{
    ComPtr<IDXCoreAdapterFactory> adapterFactory;
    THROW_IF_FAILED(DXCoreCreateAdapterFactory(IID_PPV_ARGS(adapterFactory.GetAddressOf())));

    // First try getting all GENERIC_ML devices, which is the broadest set of adapters
    // and includes both GPUs and NPUs; however, running this sample on an older build of
    // Windows may not have drivers that report GENERIC_ML.
    ComPtr<IDXCoreAdapterList> adapterList;
    THROW_IF_FAILED(adapterFactory->CreateAdapterList(
        1,
        &DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML,
        adapterList.GetAddressOf()
    ));

    // Fall back to CORE_COMPUTE if GENERIC_ML devices are not available. This is a more restricted
    // set of adapters and may filter out some NPUs.
    if (adapterList->GetAdapterCount() == 0)
    {
        // adapterList still owns the (empty) GENERIC_ML list here. ReleaseAndGetAddressOf drops that
        // reference before the pointer is overwritten; plain GetAddressOf would leak it.
        THROW_IF_FAILED(adapterFactory->CreateAdapterList(
            1,
            &DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
            adapterList.ReleaseAndGetAddressOf()
        ));
    }

    // Sort the adapters by preference, with hardware and high-performance adapters first.
    DXCoreAdapterPreference preferences[] =
    {
        DXCoreAdapterPreference::Hardware,
        DXCoreAdapterPreference::HighPerformance
    };

    THROW_IF_FAILED(adapterList->Sort(_countof(preferences), preferences));

    ComPtr<IDXCoreAdapter> selectedAdapter;
    const uint32_t adapterCount = adapterList->GetAdapterCount();
    for (uint32_t i = 0; i < adapterCount; i++)
    {
        ComPtr<IDXCoreAdapter> adapter;
        THROW_IF_FAILED(adapterList->GetAdapter(i, adapter.ReleaseAndGetAddressOf()));

        size_t descriptionSize = 0;
        THROW_IF_FAILED(adapter->GetPropertySize(
            DXCoreAdapterProperty::DriverDescription,
            &descriptionSize
        ));

        std::string adapterDescription(descriptionSize, '\0');
        THROW_IF_FAILED(adapter->GetProperty(
            DXCoreAdapterProperty::DriverDescription,
            descriptionSize,
            adapterDescription.data()
        ));

        // The property is a NUL-terminated char array and its reported size counts the terminator,
        // so the buffer above ends with an embedded '\0'. Trim it so the NUL is neither printed
        // nor considered part of the description.
        if (const auto terminatorPos = adapterDescription.find('\0'); terminatorPos != std::string::npos)
        {
            adapterDescription.resize(terminatorPos);
        }

        // Use the first adapter matching the name filter; later matches are only listed.
        const bool isSelected =
            !selectedAdapter && adapterDescription.find(adapterNameFilter) != std::string::npos;
        if (isSelected)
        {
            selectedAdapter = adapter;
        }

        std::cout << "Adapter[" << i << "]: " << adapterDescription
                  << (isSelected ? " (SELECTED)" : "") << std::endl;
    }

    if (!selectedAdapter)
    {
        throw std::runtime_error("No suitable adapters found");
    }

    return selectedAdapter;
}

std::tuple<ComPtr<IDMLDevice>, ComPtr<ID3D12CommandQueue>> CreateDmlDeviceAndCommandQueue()
int main(int argc, char** argv)
{
ComPtr<IDXCoreAdapter> adapter = SelectAdapter();

ComPtr<ID3D12Device> d3d12Device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_12_0, IID_PPV_ARGS(&d3d12Device)));
using Microsoft::WRL::ComPtr;

ComPtr<IDMLDevice> dmlDevice;
THROW_IF_FAILED(DMLCreateDevice(d3d12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice)));

D3D12_COMMAND_QUEUE_DESC queueDesc =
{
.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE,
.Priority = D3D12_COMMAND_QUEUE_PRIORITY_NORMAL,
.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE,
.NodeMask = 0
};

ComPtr<ID3D12CommandQueue> commandQueue;
THROW_IF_FAILED(d3d12Device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&commandQueue)));
// Functions in image_helpers.h use WIC APIs, which require CoInitialize.
THROW_IF_FAILED(CoInitializeEx(nullptr, COINIT_MULTITHREADED));

return { dmlDevice, commandQueue };
}
// See dx_helpers.h for logic to select a DXCore adapter, create DML device, and create D3D command queue.
auto [dmlDevice, commandQueue] = CreateDmlDeviceAndCommandQueue();

Ort::Session CreateOnnxRuntimeSession(Ort::Env& env, IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, std::wstring_view modelPath)
{
const OrtApi& ortApi = Ort::GetApi();

// DML execution provider prefers these session options.
Ort::SessionOptions sessionOptions;
sessionOptions.DisableMemPattern();
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);

// By passing in an explicitly created DML device & queue, the DML execution provider sends work
// to the desired device. If not used, the DML execution provider will create its own device & queue.
const OrtApi& ortApi = Ort::GetApi();
const OrtDmlApi* ortDmlApi = nullptr;
Ort::ThrowOnError(ortApi.GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ortDmlApi)));
Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(sessionOptions, dmlDevice, commandQueue));

return Ort::Session(env, modelPath.data(), sessionOptions);
}

int main(int argc, char** argv)
{
THROW_IF_FAILED(CoInitializeEx(nullptr, COINIT_MULTITHREADED));

auto [dmlDevice, commandQueue] = CreateDmlDeviceAndCommandQueue();

const OrtApi& ortApi = Ort::GetApi();
Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(
sessionOptions,
dmlDevice.Get(),
commandQueue.Get()
));

// Load ONNX model into a session.
Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "DirectML_CV");
auto ortSession = CreateOnnxRuntimeSession(env, dmlDevice.Get(), commandQueue.Get(), LR"(esrgan.onnx)");
Ort::Session ortSession(env, L"esrgan.onnx", sessionOptions);

if (ortSession.GetInputCount() != 1 && ortSession.GetOutputCount() != 1)
{
Expand Down Expand Up @@ -195,10 +101,10 @@ int main(int argc, char** argv)
const uint32_t outputHeight = outputTensorShape[2];
const uint32_t outputWidth = outputTensorShape[3];

// Run the session to get inference results.
Ort::RunOptions runOpts;
std::vector<const char*> inputNames = { "image" };
std::vector<const char*> outputNames = { "output_0" };

auto outputs = ortSession.Run(runOpts, inputNames.data(), &inputTensor, 1, outputNames.data(), 1);

std::span<const std::byte> outputBuffer(reinterpret_cast<const std::byte*>(outputs[0].GetTensorData<float>()), outputChannels * outputHeight * outputWidth * sizeof(float));
Expand All @@ -212,6 +118,7 @@ int main(int argc, char** argv)
ChannelOrder::RGB
);

// Functions in image_helpers.h use WIC APIs, which require CoUninitialize.
CoUninitialize();

return 0;
Expand Down