jax-triton

The jax-triton repository contains integrations between JAX and Triton.

This is not an officially supported Google product.

Installation

$ pip install jax-triton

Make sure you have a CUDA-compatible jaxlib installed. For example, you could run:

$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
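One way to confirm that the installed jaxlib actually sees a GPU is to ask JAX which backend it selected. This is a small sketch using `jax.default_backend()` and `jax.devices()`; the helper name `has_gpu_backend` is hypothetical, and accepting both `"gpu"` and `"cuda"` is a hedge, since the exact platform string reported for CUDA devices may differ between JAX versions.

```python
import jax


def has_gpu_backend() -> bool:
    """Return True if JAX's default backend is a GPU (e.g. CUDA)."""
    # jax.default_backend() returns the platform name of the default
    # backend, such as "cpu", "gpu", or "tpu".
    return jax.default_backend() in ("gpu", "cuda")


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    print("devices:", jax.devices())
    print("GPU available:", has_gpu_backend())
```

If this reports `cpu`, JAX fell back to a CPU-only jaxlib, and jax-triton kernels will not be able to run.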

Development

To develop jax-triton, you can clone the repo with:

$ git clone https://github.com/jax-ml/jax-triton.git

and do an editable install with:

$ cd jax-triton
$ pip install -e .

To run the jax-triton tests, you'll need pytest and absl-py:

$ pip install pytest absl-py
$ pytest tests/