diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py
index fb99b8fd3d43..ccb146e6b1d1 100644
--- a/python/test/regression/test_performance.py
+++ b/python/test/regression/test_performance.py
@@ -165,7 +165,7 @@ def test_elementwise(N, dtype_str):
         (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
         (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
         (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
-        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
+        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
         (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
     }
 }
diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py
index d6fcba77f58c..fabf454dc175 100644
--- a/python/test/unit/operators/test_flash_attention.py
+++ b/python/test/unit/operators/test_flash_attention.py
@@ -43,8 +43,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
-    # temporary env var control begin
-    os.putenv("ENABLE_TMA", "0")
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
@@ -55,5 +53,3 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
-    # temporary env var control end
-    os.putenv("ENABLE_TMA", enable_tma)
diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py
index b041e82b0440..99d69c4d0170 100644
--- a/python/triton/ops/flash_attention.py
+++ b/python/triton/ops/flash_attention.py
@@ -140,7 +140,7 @@ def _bwd_kernel_one_col_block(
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vn, stride_vk,
     Z, H, N_CTX,
-    off_hz, start_n, num_block,
+    off_h, off_z, off_hz, start_n, num_block,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
@@ -152,13 +152,21 @@ def _bwd_kernel_one_col_block(
     else:
         lo = 0
 
-    Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
-    K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M, 0))
-    V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M, 0))
-    DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
-    DQ_block_ptr = tl.advance(DQ_block_ptr, (lo, 0))
-    DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M, 0))
-    DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M, 0))
+    Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
+    DQ_offset = (off_z * stride_qz + off_h * stride_qh)
+    K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
+    V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
+    if SEQUENCE_PARALLEL:
+        DQ_offset += stride_dqa.to(tl.int64) * start_n
+    DQ_offset = DQ_offset // stride_qm
+
+    Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
+    K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+    DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
+    DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
+    DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
 
     # initialize row/col offsets
     offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -232,6 +240,8 @@ def _bwd_kernel(
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vn, stride_vk,
     Z, H, N_CTX,
+    Z_H_N_CTX,
+    SQ_Z_H_N_CTX,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
@@ -243,21 +253,10 @@ def _bwd_kernel(
     off_hz = tl.program_id(0)
     off_z = off_hz // H
     off_h = off_hz % H
-    # offset pointers for batch/head
-    Q += off_z * stride_qz + off_h * stride_qh
-    K += off_z * stride_kz + off_h * stride_kh
-    V += off_z * stride_vz + off_h * stride_vh
-    DO += off_z * stride_qz + off_h * stride_qh
-    DQ += off_z * stride_qz + off_h * stride_qh
-    DK += off_z * stride_kz + off_h * stride_kh
-    DV += off_z * stride_vz + off_h * stride_vh
-
-    if SEQUENCE_PARALLEL:
-        DQ += stride_dqa.to(tl.int64) * tl.program_id(1)
 
     Q_block_ptr = tl.make_block_ptr(
         base=Q,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_qm, stride_qk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -265,7 +264,7 @@
     )
     K_block_ptr = tl.make_block_ptr(
         base=K,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_kn, stride_kk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -273,7 +272,7 @@
     )
     V_block_ptr = tl.make_block_ptr(
         base=V,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_vn, stride_vk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -281,23 +280,34 @@
     )
     DO_block_ptr = tl.make_block_ptr(
         base=DO,
-        shape=(N_CTX, BLOCK_DMODEL),
-        strides=(stride_qm, stride_qk),
-        offsets=(0, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
-    DQ_block_ptr = tl.make_block_ptr(
-        base=DQ,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_qm, stride_qk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
+    if SEQUENCE_PARALLEL:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+    else:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+
     DK_block_ptr = tl.make_block_ptr(
         base=DK,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_kn, stride_kk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -305,7 +315,7 @@
     )
     DV_block_ptr = tl.make_block_ptr(
         base=DV,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_vn, stride_vk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -326,7 +336,7 @@
                 stride_kz, stride_kh, stride_kn, stride_kk,
                 stride_vz, stride_vh, stride_vn, stride_vk,
                 Z, H, N_CTX,
-                off_hz, start_n, num_block_n,
+                off_h, off_z, off_hz, start_n, num_block_n,
                 BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
                 BLOCK_N=BLOCK_N,
                 SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -346,7 +356,7 @@
             stride_kz, stride_kh, stride_kn, stride_kk,
             stride_vz, stride_vh, stride_vn, stride_vk,
             Z, H, N_CTX,
-            off_hz, start_n, num_block_n,
+            off_h, off_z, off_hz, start_n, num_block_n,
            BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
            BLOCK_N=BLOCK_N,
            SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -429,6 +439,8 @@ def backward(ctx, do):
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
+            q.shape[0] * q.shape[1] * q.shape[2],
+            cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],
             BLOCK_M=BLOCK, BLOCK_N=BLOCK,
             BLOCK_DMODEL=ctx.BLOCK_DMODEL,
             SEQUENCE_PARALLEL=sequence_parallel,
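
Note on the offset arithmetic above (illustrative only, not part of the patch): instead of bumping the raw `Q`/`K`/`V`/`DQ`/... base pointers by `off_z * stride_*z + off_h * stride_*h`, the backward kernel now builds its block pointers over a flattened `(Z_H_N_CTX, BLOCK_DMODEL)` view and folds the per-batch/per-head adjustment into the row offset passed to `tl.advance`. A minimal sketch of that row-index equivalence, assuming a contiguous `(Z, H, N_CTX, D)` tensor; the variable names below are hypothetical and only for illustration:

```python
import torch

Z, H, N_CTX, D = 2, 3, 128, 64
q = torch.randn(Z, H, N_CTX, D)
off_z, off_h = 1, 2

# Strides of a contiguous (Z, H, N_CTX, D) tensor, in elements.
stride_qz, stride_qh, stride_qm, _ = q.stride()

# Row offset used to advance a block pointer over the flattened
# (Z * H * N_CTX, D) view, as in the kernel's Q_offset computation.
q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm

# For a contiguous tensor this is the first row of the (off_z, off_h) slice.
assert q_offset == (off_z * H + off_h) * N_CTX

flat = q.reshape(Z * H * N_CTX, D)
assert torch.equal(flat[q_offset:q_offset + N_CTX], q[off_z, off_h])
```

The sequence-parallel `DQ` case follows the same pattern, except the extra `stride_dqa * start_n` term is added before the division by `stride_qm`.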