[OPS] enable flash_attention_v2 TMA (triton-lang#2544)
runseny authored Oct 26, 2023
1 parent 2323adb commit 4c816c2
Showing 3 changed files with 48 additions and 40 deletions.
2 changes: 1 addition & 1 deletion python/test/regression/test_performance.py
@@ -165,7 +165,7 @@ def test_elementwise(N, dtype_str):
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
}
}
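
For context, a reference table like the one above is normally consumed by the surrounding regression harness: the test measures the current number for each configuration, looks up the stored value, and compares the two within a tolerance. A minimal sketch of that pattern, with hypothetical names since the harness itself is not part of this diff:

    import pytest

    # Hypothetical reference table in the same shape as the dict above:
    # (Z, H, N_CTX, D_HEAD, causal, seq_par, mode, dtype) -> reference number.
    reference = {
        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
    }

    def check_against_reference(config, measured, rel_tol=0.05):
        # Flag both regressions and unexpected speedups so the table gets updated.
        assert measured == pytest.approx(reference[config], rel=rel_tol), \
            f"{config}: measured {measured:.3f} vs reference {reference[config]:.3f}"
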
4 changes: 0 additions & 4 deletions python/test/unit/operators/test_flash_attention.py
@@ -43,8 +43,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
ref_dq, q.grad = q.grad.clone(), None
# # triton implementation
tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
# temporary env var control begin
os.putenv("ENABLE_TMA", "0")
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -55,5 +53,3 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
# temporary env var control end
os.putenv("ENABLE_TMA", enable_tma)
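
The two deleted putenv calls were a temporary workaround that forced TMA off around the backward pass and then restored the previous setting; with the TMA path now enabled for the backward kernels, the test runs forward and backward under the same configuration. For reference, if a flag like ENABLE_TMA ever needs to be scoped around a single call again, a context manager over os.environ is the more predictable pattern (os.putenv does not update os.environ); a minimal, hypothetical sketch:

    import os
    from contextlib import contextmanager

    @contextmanager
    def scoped_env(name, value):
        # Temporarily set an environment variable via os.environ (which also
        # calls putenv under the hood) and restore the previous value on exit.
        old = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old

    # Hypothetical usage mirroring the pattern this commit removes:
    #     with scoped_env("ENABLE_TMA", "0"):
    #         tri_out.backward(dout)
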
82 changes: 47 additions & 35 deletions python/triton/ops/flash_attention.py
@@ -140,7 +140,7 @@ def _bwd_kernel_one_col_block(
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block,
off_h, off_z, off_hz, start_n, num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
@@ -152,13 +152,21 @@ def _bwd_kernel_one_col_block(
else:
lo = 0

Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M, 0))
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo, 0))
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M, 0))
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M, 0))
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
DQ_offset = (off_z * stride_qz + off_h * stride_qh)
K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
if SEQUENCE_PARALLEL:
DQ_offset += stride_dqa.to(tl.int64) * start_n
DQ_offset = DQ_offset // stride_qm

Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))

# initialize row/col offsets
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
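
The new *_offset values replace the per-(batch, head) base-pointer adjustment that _bwd_kernel used to perform before calling this helper (see the deleted Q += / K += / V += lines below): the block pointers now describe the flattened (Z * H * N_CTX, BLOCK_DMODEL) view, and each program advances along dimension 0 by a row offset derived from the batch/head strides. A minimal sketch of that arithmetic in plain Python, assuming a contiguous (Z, H, N_CTX, D) layout and hypothetical sizes:

    # Element strides of a contiguous (Z, H, N_CTX, D) tensor, as passed to the kernel.
    Z, H, N_CTX, D = 4, 48, 1024, 64
    stride_qz, stride_qh, stride_qm = H * N_CTX * D, N_CTX * D, D

    def q_row_offset(off_z, off_h):
        # Element offset of the (off_z, off_h) slice, converted into rows of the
        # flattened (Z * H * N_CTX, D) view by dividing out the row stride.
        return (off_z * stride_qz + off_h * stride_qh) // stride_qm

    assert q_row_offset(0, 0) == 0
    assert q_row_offset(0, 1) == N_CTX           # next head starts N_CTX rows later
    assert q_row_offset(1, 0) == H * N_CTX       # next batch starts H * N_CTX rows later
    assert q_row_offset(2, 3) == (2 * H + 3) * N_CTX
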
@@ -232,6 +240,8 @@ def _bwd_kernel(
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
Z_H_N_CTX,
SQ_Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
@@ -243,69 +253,69 @@ def _bwd_kernel(
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh

if SEQUENCE_PARALLEL:
DQ += stride_dqa.to(tl.int64) * tl.program_id(1)

Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DO_block_ptr = tl.make_block_ptr(
base=DO,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
if SEQUENCE_PARALLEL:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
else:
DQ_block_ptr = tl.make_block_ptr(
base=DQ,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)

DK_block_ptr = tl.make_block_ptr(
base=DK,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DV_block_ptr = tl.make_block_ptr(
base=DV,
shape=(N_CTX, BLOCK_DMODEL),
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
Expand All @@ -326,7 +336,7 @@ def _bwd_kernel(
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
off_h, off_z, off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -346,7 +356,7 @@ def _bwd_kernel(
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
Z, H, N_CTX,
off_hz, start_n, num_block_n,
off_h, off_z, off_hz, start_n, num_block_n,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -429,6 +439,8 @@ def backward(ctx, do):
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
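
The two extra launch arguments are what the new Z_H_N_CTX and SQ_Z_H_N_CTX kernel parameters receive: the row count of the flattened (Z * H * N_CTX, BLOCK_DMODEL) view that the block pointers above are built over, and that count scaled by the number of key/value column blocks, since the sequence-parallel path points each column block at its own dq slice (the stride_dqa.to(tl.int64) * start_n term added to DQ_offset in the helper). A small sketch of the arithmetic, with hypothetical sizes:

    import math

    # Hypothetical shapes; BLOCK mirrors the BLOCK_M / BLOCK_N tile used at launch.
    Z, H, N_CTX, BLOCK = 4, 48, 4096, 128
    seq_len_kv = N_CTX

    Z_H_N_CTX = Z * H * N_CTX                       # rows of the flattened (Z*H*N_CTX, D) view
    num_col_blocks = math.ceil(seq_len_kv / BLOCK)  # = triton.cdiv(seq_len_kv, BLOCK)
    SQ_Z_H_N_CTX = num_col_blocks * Z_H_N_CTX       # one dq slice per K/V column block
                                                    # when SEQUENCE_PARALLEL is set
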
