PyTorch with DirectML enables training and inference of complex machine learning models on a wide range of DirectX 12-compatible hardware. This is done through torch-directml, a plugin for PyTorch.
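As a rough illustration of what training and inference on a DirectML device look like, the sketch below moves a small model and its data to the device returned by torch_directml.device(), runs a few SGD steps, and then does a no-grad forward pass. The tiny linear model, the synthetic data, and the CPU fallback (used only when torch_directml cannot be imported) are illustrative assumptions, not part of torch-directml itself.

```python
import torch
import torch.nn as nn

try:
    import torch_directml  # provided by the torch-directml package
    device = torch_directml.device()  # default DirectML adapter
except ImportError:
    # Assumption for this sketch: fall back to CPU so it still runs without the plugin.
    device = torch.device("cpu")

torch.manual_seed(0)

# Illustrative model: any nn.Module is moved to the device the same way.
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Synthetic regression data (hypothetical), created on CPU and copied to the device.
x = torch.randn(64, 4).to(device)
true_w = torch.tensor([[1.0], [-2.0], [0.5], [3.0]]).to(device)
y = x @ true_w


def train_step():
    """One SGD step; forward, backward, and update all run on `device`."""
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()  # copy the scalar loss back to the host


# Training: a handful of steps is enough for this toy problem.
losses = [train_step() for _ in range(50)]

# Inference: disable gradient tracking and bring predictions back to CPU.
model.eval()
with torch.no_grad():
    pred = model(x[:2]).cpu()

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

In this sketch the only DirectML-specific line is the device selection; everything else is standard PyTorch.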
DirectML is a high-performance, hardware-accelerated DirectX 12 library for machine learning. DirectML provides GPU acceleration for common machine learning tasks across a broad range of supported hardware and drivers, including all DirectX 12-capable GPUs from vendors such as AMD, Intel, NVIDIA, and Qualcomm.
More information about DirectML can be found on the DirectML Overview page on Microsoft Learn.
PyTorch with DirectML is supported on both the latest versions of Windows and the Windows Subsystem for Linux, and is available for download as a PyPI package. For more information about getting started with torch-directml, see our Windows or WSL 2 guidance on Microsoft Learn.
Once a Python environment (version 3.8 to 3.12) is set up, install the latest release of torch-directml by running the following command:
```bash
pip install torch-directml
```
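After installing, a quick sanity check is to create tensors on the DirectML device and read a result back to the host. The sketch below follows that pattern. torch_directml.device() is the entry point used in Microsoft's getting-started guidance, while device_count() and device_name() are helpers we assume the package exposes (drop that line if your installed version lacks them). The function name check_torch_directml and its dictionary return value are hypothetical, chosen only for this example.

```python
def check_torch_directml():
    """Hypothetical helper: list DirectML adapters and run one tiny op on the device."""
    try:
        import torch
        import torch_directml  # installed by `pip install torch-directml`
    except ImportError as exc:
        # Either PyTorch or the DirectML plugin is missing from this environment.
        return {"ok": False, "error": str(exc)}

    # Assumed helpers: enumerate the DirectX 12 adapters the plugin can see.
    adapters = [
        torch_directml.device_name(i) for i in range(torch_directml.device_count())
    ]

    dml = torch_directml.device()  # default DirectML device
    a = torch.tensor([1]).to(dml)
    b = torch.tensor([2]).to(dml)
    total = (a + b).item()  # addition runs on the device; .item() copies it back

    return {"ok": total == 3, "adapters": adapters, "device": str(dml)}


if __name__ == "__main__":
    print(check_torch_directml())
```

If "ok" is False, the returned "error" message usually points at which import failed.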
Try the torch-directml samples below, or explore the cv, transformer, llm, and diffusion folders:
- attention is all you need - the original transformer model
- yolov3 - a real-time object detection model
- squeezenet - a small image classification model
- resnet50 - an image classification model
- maskrcnn - an object detection and instance segmentation model
- llm - a text generation and chatbot app supporting various language models
- whisper - a general-purpose speech recognition model
- Stable Diffusion Turbo & XL Turbo - text-to-image generation models