Refactor GPU lowering into one file

PiperOrigin-RevId: 510244602
sharadmv authored and The jax_triton Authors committed Feb 16, 2023
1 parent 603bc10 commit c42a887
Showing 7 changed files with 212 additions and 158 deletions.
6 changes: 3 additions & 3 deletions jax_triton/__init__.py
@@ -13,9 +13,9 @@
# limitations under the License.

"""Library for JAX-Triton integrations."""
from jax_triton.triton_lib import cdiv
from jax_triton.triton_lib import next_power_of_2
from jax_triton.triton_lib import strides_from_shape
from jax_triton.utils import cdiv
from jax_triton.utils import next_power_of_2
from jax_triton.utils import strides_from_shape
from jax_triton.triton_lib import triton_call
from jax_triton.triton_lib import triton_kernel_call_lib
from jax_triton.version import __version__
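For reference, a minimal usage sketch of the three re-exported helpers; the exact return values shown are assumptions based on their conventional semantics, not output verified against this commit:

import jax_triton as jt

# Ceiling division: how many size-128 blocks cover 1000 elements.
assert jt.cdiv(1000, 128) == 8
# Round a block dimension up to the next power of two.
assert jt.next_power_of_2(100) == 128
# Row-major element strides for a shape, innermost stride last.
assert jt.strides_from_shape((4, 8)) == (8, 1)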
9 changes: 8 additions & 1 deletion jax_triton/pallas/__init__.py
@@ -15,7 +15,7 @@
"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import clear_caches
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
from jax_triton.pallas.primitives import atomic_and
from jax_triton.pallas.primitives import atomic_cas
@@ -33,3 +33,10 @@
from jax_triton.pallas.primitives import program_id
from jax_triton.pallas.primitives import store
from jax_triton.pallas.primitives import swap
from jax_triton.utils import cdiv

try:
  from jax_triton.pallas import triton_ir_lowering
  del triton_ir_lowering
except (ImportError, ModuleNotFoundError):
  pass
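To illustrate the surface this module now exports (pallas_call, program_id, load, store, plus cdiv from jax_triton.utils), here is a minimal vector-add sketch; it assumes a dslice helper is exported alongside load/store, and is a sketch rather than an example verified against this commit:

import functools

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def add_kernel(x_ref, y_ref, o_ref, *, block_size):
  # Each program instance handles one contiguous block of the inputs.
  pid = pl.program_id(axis=0)
  idx = (pl.dslice(pid * block_size, block_size),)  # `dslice` assumed available
  pl.store(o_ref, idx, pl.load(x_ref, idx) + pl.load(y_ref, idx))

def add(x, y, *, block_size=256):
  # Assumes x.size is a multiple of block_size (no masking in this sketch).
  return pl.pallas_call(
      functools.partial(add_kernel, block_size=block_size),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=pl.cdiv(x.size, block_size),  # cdiv re-exported in this commit
  )(x, y)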
107 changes: 1 addition & 106 deletions jax_triton/pallas/pallas_call.py
@@ -26,12 +26,9 @@
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.lib import xla_client as xc
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src import state
from jax._src.util import (
@@ -41,11 +38,7 @@
import jax.numpy as jnp
import numpy as np

from triton._C.libtriton import triton as tc

from jax_triton import triton_kernel_call_lib
from jax_triton.triton_lib import avals_to_layouts, normalize_grid
from jax_triton.pallas import lowering
from jax_triton.utils import avals_to_layouts, normalize_grid
from jax_triton.pallas import core as pallas_core

map, unsafe_map = safe_map, map
@@ -287,98 +280,6 @@ def _pallas_call_batching_rule(args, dims, *,
  return out, (0,) * len(out)
batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule

class TritonCompilationResult(NamedTuple):
  name: str
  asm: Dict[str, str]
  shared_mem: int
  lowering_result: lowering.TritonLoweringResult

@weakref_lru_cache
def _compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec,
                   name: str, num_warps: int, num_stages: int
                   ) -> TritonCompilationResult:
  lowering_result = lowering.lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
  backend = tc.runtime.backend.CUDA
  device = 0
  name, asm, shared_mem = tc.code_gen.compile_ttir(backend, lowering_result.module, device,
                                                   num_warps, num_stages, {}, 0)
  return TritonCompilationResult(name, asm, shared_mem, lowering_result)


def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
                         jaxpr: jax_core.Jaxpr,
                         name: str,
                         in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                         out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                         which_linear: Tuple[bool, ...],
                         interpret: bool,
                         debug: bool,
                         input_output_aliases: Tuple[Tuple[int, int], ...],
                         grid_spec: GridSpec,
                         **compiler_params: Any):
  if interpret:
    return mlir.lower_fun(_pallas_call_impl, multiple_results=True)(
        ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret, debug=debug,
        input_output_aliases=input_output_aliases,
        grid_spec=grid_spec, **compiler_params)
  num_warps = compiler_params.get("num_warps", 4)
  num_stages = compiler_params.get("num_stages", 3)
  compilation_result = _compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
                                      grid_spec, name, num_warps, num_stages)
  name = compilation_result.name
  asm = compilation_result.asm
  shared_mem = compilation_result.shared_mem
  if debug:
    print(jaxpr)
    print(grid_spec)
  lowering_result = compilation_result.lowering_result
  if debug:
    lowering_result.module.print()
  out_type = ir.TupleType.get_tuple([
      ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
      for out_shape in ctx.avals_out])
  i32_type = ir.IntegerType.get_signless(32)

  kernel = triton_kernel_call_lib.TritonKernel(
      asm["cubin"], name, num_warps, shared_mem
  )

  grid = normalize_grid(compilation_result.lowering_result.grid, metaparams={})
  # All arguments are buffers.
  all_args = [None] * (len(in_shapes) + len(out_shapes))
  zeroed_outputs = {}  # TODO(cjfj): Expose through user API.
  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel, grid[0], grid[1], grid[2], all_args, zeroed_outputs
  )

  ctx.module_context.add_keepalive(kernel_call)
  output_operand_aliases = ir.ArrayAttr.get([
      mhlo.OutputOperandAlias.get(
          output_tuple_indices=[output],
          operand_index=input,
          operand_tuple_indices=[])
      for input, output in input_output_aliases
  ])
  out = mhlo.CustomCallOp(
      [out_type],
      in_nodes,
      call_target_name=ir.StringAttr.get("triton_kernel_call"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(kernel_call.descriptor),
      api_version=ir.IntegerAttr.get(i32_type, 1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=avals_to_layouts(ctx.avals_in),
      result_layouts=avals_to_layouts(ctx.avals_out),
      output_operand_aliases=output_operand_aliases,
  )
  results = [mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
             for i in range(len(out_shapes))]
  return results
mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")

@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
primitive_name: Optional[str] = None):
@@ -388,10 +289,6 @@ def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
  jaxpr = for_loop._hoist_consts_to_refs(jaxpr)
  return jaxpr, consts, out_tree()

def clear_caches():
  _initial_style_open_jaxpr.cache_clear()
  _compile_jaxpr.cache_clear()

def _preprocess_grid(grid: Optional[Union[Grid, int]]) -> Grid:
  if grid is None:
    return ()
@@ -430,8 +327,6 @@ def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
                interpret: bool = False,
                name: Optional[str] = None,
                **compiler_params: Any):
  xc.register_custom_call_target(
      "triton_kernel_call", triton_kernel_call_lib.get_custom_call(), platform="CUDA")
  if grid is None:
    if in_specs is not None:
      raise ValueError("Cannot specify `in_specs` with a `None` grid.")
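Both caches above (the tracing cache and, before this refactor, the kernel-compilation cache) hang off jax._src.util.weakref_lru_cache, which keys the cache on a weak reference to its first argument, so entries are evicted when the jaxpr is garbage-collected rather than being pinned alive by the cache. A toy sketch of the pattern, assuming only that the first argument is weak-referenceable and that expensive_lowering is a hypothetical helper:

from jax._src.util import weakref_lru_cache

@weakref_lru_cache
def lower_once(jaxpr, num_warps: int):
  # Body runs once per live (jaxpr, num_warps) pair; later calls hit the cache.
  return expensive_lowering(jaxpr, num_warps)  # hypothetical helper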
jax_triton/pallas/triton_ir_lowering.py
@@ -16,7 +16,7 @@
import dataclasses
import functools

from typing import Any, Optional, Tuple, Sequence
from typing import Any, Dict, Optional, NamedTuple, Tuple, Sequence

import jax
from jax import api_util
@@ -28,30 +28,40 @@
from jax._src.lax.control_flow import for_loop
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import mlir
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import state
from jax.lib import xla_client as xc
from jax._src.state import primitives as sp
from jax._src.state import discharge
from jax._src.state import ShapedArrayRef
from jax_triton.triton_lib import get_triton_type
from jax._src.util import weakref_lru_cache
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import jax.numpy as jnp
import numpy as np
import triton
import triton.language as tl
import numpy as np
from triton.language import ir as tl_ir
import triton._C.libtriton.triton as _triton

import jax_triton as jt
from jax_triton import triton_kernel_call_lib
from jax_triton import utils as triton_utils
from jax_triton.pallas import primitives
from jax_triton.pallas import core as pallas_core
from jax_triton.pallas import pallas_call_p
from jax_triton.triton_lib import get_triton_type

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
partial = functools.partial

GridSpec = pallas_core.GridSpec
Grid = Tuple[int, ...]
BlockSpec = pallas_core.BlockSpec
BlockMapping = pallas_core.BlockMapping
GridSpec = pallas_core.GridSpec


# # General lowering logic

@@ -88,6 +98,13 @@ class TritonLoweringResult:
  builder: tl_ir.builder
  grid: Tuple[int, ...]

@dataclasses.dataclass
class TritonCompilationResult:
  name: str
  asm: Dict[str, str]
  shared_mem: int
  lowering_result: TritonLoweringResult

def _eval_index_map(ctx: TritonModuleContext, idx, block_mapping: Optional[BlockMapping]):
  if block_mapping is None:
    return None
@@ -401,7 +418,7 @@ def _offset_ptr(ptr, block_info: Optional[BlockInfo], idx: primitives.NDIndexer,
  full_shape = block_info.full_shape.shape
  num_mapped_dims = sum(b is pallas_core.mapped for b in block_info.block_shape)
  block_shape = block_info.block_shape
  strides = triton_utils.strides_from_shape(full_shape)
  indexer_shape = idx.get_indexer_shape()
  indices = idx.indices
  other_shape = indexer_shape[len(idx.int_indexer_shape):]
@@ -770,3 +787,104 @@ def _while_lowering_rule(ctx: TritonLoweringRuleContext, *args, cond_nconsts,
    post_args.append(phi_arg)
  return post_args
triton_lowering_rules[lax.while_p] = _while_lowering_rule

@weakref_lru_cache
def compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec,
                  name: str, num_warps: int, num_stages: int
                  ) -> TritonCompilationResult:
  lowering_result = lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
  backend = _triton.runtime.backend.CUDA
  device = 0
  name, asm, shared_mem = _triton.code_gen.compile_ttir(
      backend, lowering_result.module, device, num_warps, num_stages, {}, 0)
  return TritonCompilationResult(name, asm, shared_mem, lowering_result)


def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
                         jaxpr: jax_core.Jaxpr,
                         name: str,
                         in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                         out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
                         which_linear: Tuple[bool, ...],
                         interpret: bool,
                         debug: bool,
                         input_output_aliases: Tuple[Tuple[int, int], ...],
                         grid_spec: GridSpec,
                         **compiler_params: Any):
  if interpret:
    return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
        ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
        in_shapes=in_shapes,
        which_linear=which_linear,
        interpret=interpret, debug=debug,
        input_output_aliases=input_output_aliases,
        grid_spec=grid_spec, **compiler_params)
  num_warps = compiler_params.get("num_warps", 4)
  num_stages = compiler_params.get("num_stages", 3)
  compilation_result = compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
                                     grid_spec, name, num_warps, num_stages)
  name = compilation_result.name
  asm = compilation_result.asm
  shared_mem = compilation_result.shared_mem
  if debug:
    print(jaxpr)
    print(grid_spec)
  lowering_result = compilation_result.lowering_result
  if debug:
    lowering_result.module.print()
  out_type = ir.TupleType.get_tuple([
      ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
      for out_shape in ctx.avals_out])
  i32_type = ir.IntegerType.get_signless(32)

  kernel = triton_kernel_call_lib.TritonKernel(
      asm["cubin"], name, num_warps, shared_mem
  )

  grid = triton_utils.normalize_grid(
      compilation_result.lowering_result.grid, metaparams={})
  # All arguments are buffers.
  all_args = [None] * (len(in_shapes) + len(out_shapes))
  zeroed_outputs = {}  # TODO(cjfj): Expose through user API.
  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel, grid[0], grid[1], grid[2], all_args, zeroed_outputs
  )

  ctx.module_context.add_keepalive(kernel_call)
  output_operand_aliases = ir.ArrayAttr.get([
      mhlo.OutputOperandAlias.get(
          output_tuple_indices=[output],
          operand_index=input,
          operand_tuple_indices=[])
      for input, output in input_output_aliases
  ])
  out = mhlo.CustomCallOp(
      [out_type],
      in_nodes,
      call_target_name=ir.StringAttr.get("triton_kernel_call"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(kernel_call.descriptor),
      api_version=ir.IntegerAttr.get(i32_type, 1),
      called_computations=ir.ArrayAttr.get([]),
      operand_layouts=triton_utils.avals_to_layouts(ctx.avals_in),
      result_layouts=triton_utils.avals_to_layouts(ctx.avals_out),
      output_operand_aliases=output_operand_aliases,
  )
  results = [mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
             for i in range(len(out_shapes))]
  return results
mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")

xc.register_custom_call_target(
"triton_kernel_call", triton_kernel_call_lib.get_custom_call(), platform="CUDA")