Skip to content

Commit

Permalink
Move triton_kernel_call_lib to jaxlib
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 534934592
  • Loading branch information
sharadmv authored and The jax_triton Authors committed May 24, 2023
1 parent 97cb006 commit a9b499d
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 678 deletions.
24 changes: 0 additions & 24 deletions BUILD.bazel

This file was deleted.

61 changes: 0 additions & 61 deletions WORKSPACE

This file was deleted.

5 changes: 3 additions & 2 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from jax_triton.utils import next_power_of_2
from jax_triton.utils import strides_from_shape
from jax_triton.triton_lib import triton_call
from jax_triton.triton_lib import triton_kernel_call_lib
from jax_triton.version import __version__
from jax_triton.version import __version_info__
from jax_triton import pallas
from jax._src.lib import gpu_triton

get_compute_capability = triton_kernel_call_lib.get_compute_capability
get_compute_capability = gpu_triton.get_compute_capability

# trailer
del gpu_triton
2 changes: 1 addition & 1 deletion jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax._src import state
from jax._src import util
from jax._src.lax.control_flow import for_loop
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.state import AbstractRef
Expand All @@ -41,7 +42,6 @@
from jax.interpreters import partial_eval as pe
from jax.lib import xla_client as xc
import jax.numpy as jnp
from jax_triton import triton_kernel_call_lib
from jax_triton import utils as triton_utils
from jax_triton.pallas import core as pallas_core
from jax_triton.pallas import pallas_call_p
Expand Down
8 changes: 2 additions & 6 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,27 @@
# limitations under the License.

"""Module for calling Triton kernels from JAX."""
import collections
import functools
import math
import os
import types
import weakref

from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
import weakref

from absl import logging
import jax
from jax import core
import jaxlib
from jax import tree_util
from jax._src import core
from jax._src import state
from jax._src import util
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import jax.dlpack
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.lib import xla_client as xc
import jax.numpy as jnp
from jax_triton import triton_kernel_call_lib
from jax_triton import utils
import numpy as np

Expand Down
Loading

0 comments on commit a9b499d

Please sign in to comment.