-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy path triton_test.py
104 lines (88 loc) · 3.82 KB
/
triton_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright 2022 The jax_triton Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import triton
import triton.language as tl
import jax
import jax.numpy as jnp
import jax_triton as jt
import numpy as np
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    BLOCK_SIZE: tl.constexpr,  # number of elements handled per program instance
):
  # Elementwise vector add: output[i] = x[i] + y[i]. Each program instance
  # handles one contiguous chunk of BLOCK_SIZE elements.
  #
  # NOTE(review): the bounds check below compares against a hard-coded 8
  # instead of the real input length, which this kernel does not receive as an
  # argument. It is therefore only correct for exactly 8-element inputs:
  # shorter inputs would be read/written past their end, and for longer inputs
  # every index >= 8 is masked off, so those output entries are never written.
  # TODO: pass the element count in and compare `offsets` against it.

  # Several program instances run this kernel on different chunks of the
  # data; find out which one we are.
  pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
  # Each program processes inputs offset from the start of the data. For
  # instance, with a vector of length 256 and BLOCK_SIZE of 64, the programs
  # would access elements [0:64], [64:128], [128:192] and [192:256].
  # `offsets` holds integer element indices; adding it to a base pointer
  # below yields a block of pointers.
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # Guard the memory operations against out-of-bounds lanes (bound is the
  # hard-coded 8 noted at the top of this kernel).
  mask = offsets < 8
  # Load x and y; lanes where `mask` is false are not read.
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  # Write x + y to the output; lanes where `mask` is false are not written.
  tl.store(output_ptr + offsets, output, mask=mask)
@triton.jit
def tanh_kernel(
    x_ptr,  # *Pointer* to input vector
    output_ptr,  # *Pointer* to output vector
    BLOCK_SIZE: tl.constexpr,  # number of elements handled per program instance
):
  # Elementwise hyperbolic tangent: output[i] = tanh(x[i]). Each program
  # instance handles one contiguous chunk of BLOCK_SIZE elements.
  #
  # NOTE(review): as in the add kernel, the bounds check below uses a
  # hard-coded 8 rather than the input length (which is not passed in), so
  # this is only correct for exactly 8-element inputs. TODO: pass the element
  # count in and compare `offsets` against it.

  # Several program instances run this kernel on different chunks of the
  # data; find out which one we are.
  pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
  # Each program processes inputs offset from the start of the data. For
  # instance, with a vector of length 256 and BLOCK_SIZE of 64, the programs
  # would access elements [0:64], [64:128], [128:192] and [192:256].
  # `offsets` holds integer element indices; adding it to a base pointer
  # below yields a block of pointers.
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # Guard the memory operations against out-of-bounds lanes (bound is the
  # hard-coded 8 noted at the top of this kernel).
  mask = offsets < 8
  # Load x; lanes where `mask` is false are not read.
  x = tl.load(x_ptr + offsets, mask=mask)
  # NOTE(review): `tl.libdevice` matches the Triton version this test was
  # written against; newer Triton releases relocate libdevice math functions,
  # so confirm against the pinned Triton version before upgrading.
  output = tl.libdevice.tanh(x)
  # Write tanh(x) to the output; lanes where `mask` is false are not written.
  tl.store(output_ptr + offsets, output, mask=mask)
class TritonTest(absltest.TestCase):
  """Checks Triton kernels launched from JAX via ``jt.triton_call``.

  Each test wraps a kernel in a small JAX-facing function and compares its
  result against a reference computed with JAX or NumPy.
  """

  def test_add_kernel(self):
    """``add_kernel`` should reproduce ``x + y`` on an 8-element float32 input."""

    def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
      def launch_grid(meta):
        # One program instance per BLOCK_SIZE-wide chunk, rounded up.
        return (triton.cdiv(x.size, meta['BLOCK_SIZE']),)

      return jt.triton_call(
          x,
          y,
          kernel=add_kernel,
          out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
          grid=launch_grid,
          BLOCK_SIZE=8,
      )

    lhs = jnp.arange(8, dtype=jnp.float32)
    rhs = jnp.arange(8, dtype=jnp.float32)
    actual = add(lhs, rhs)
    np.testing.assert_allclose(actual, lhs + rhs)

  def test_tanh_kernel(self):
    """``tanh_kernel`` should reproduce ``np.tanh`` on an 8-element float32 input."""

    def tanh(x: jnp.ndarray) -> jnp.ndarray:
      def launch_grid(meta):
        # One program instance per BLOCK_SIZE-wide chunk, rounded up.
        return (triton.cdiv(x.size, meta['BLOCK_SIZE']),)

      return jt.triton_call(
          x,
          kernel=tanh_kernel,
          out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
          grid=launch_grid,
          BLOCK_SIZE=8,
      )

    inputs = jnp.arange(8, dtype=jnp.float32)
    actual = tanh(inputs)
    np.testing.assert_allclose(actual, np.tanh(inputs))
if __name__ == '__main__':
  # When this file is executed directly, hand control to absltest's runner.
  absltest.main()