Commit 925e0fe (parent 94c81cb)
Showing 3 changed files with 345 additions and 1 deletion.
@@ -0,0 +1,179 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

# TODO(apoorvtintin) remove pytype disable when dependencies are public.
# pytype: disable=import-error
# Import needed to enable JAX cache on Neuron.
import jax_neuronx  # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# pytype: enable=import-error

Tensor = jax.Array
# Logical NeuronCore (LNC) size: 2 on NC_v3d (trn2) devices, 1 otherwise.
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1

@partial(custom_vjp, nondiff_argnums=(4, 5, 6))
def flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor],
    causal: bool = False,
    softmax_scale: float = 1.0,
    dropout_rate: float = 0.0,
):
    """Wraps _mha_forward for custom vjp.

    Args:
        query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
        key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
        value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
        bias: Optional logit biases of shape [1, 1, target_length, source_length].
        causal: Whether to apply a causal mask.
        softmax_scale: Optional scale to apply to the softmax logits. Defaults to 1.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).

    Returns:
        The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
    """
    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate)
    return out

def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
    """Computes attention outputs following FlashAttention.

    See also `_mha_backward` for the backward pass.

    Args:
        query: Input query.
        key: Input key.
        value: Input value.
        bias: Input bias.
        causal: Whether to apply a causal mask.
        softmax_scale: Softmax scale to use in the kernel.
        dropout_rate: Dropout rate to use in the kernel.

    Returns:
        The attention output and the residuals needed by `_mha_backward`.
    """
    # Get the batch size, sequence lengths, number of heads, and hidden dimension.
    batch_size, _, num_heads, _ = query.shape

    # Transpose the query, key, and value tensors into the layouts expected by the kernel.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len].
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len].
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model].

    # Fixed RNG seed passed to the kernel (used for dropout).
    seed = jnp.array([1])

    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4, got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads dims are both 1, "
            f"got batch {bias.shape[0]} and num_heads {bias.shape[1]}"
        )
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    # Transpose the output back to the original shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model].

    # Return the output along with the residuals needed for the backward pass.
    return attn_output, (lse, attn_output, q, k, v, bias)

def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
    """Computes gradients of `flash_attention` w.r.t. query, key, and value (no bias gradient)."""
    lse, o, q, k, v, bias = res
    batch_size, num_heads, _, _ = q.shape

    # Transpose the output and its gradient into the layouts expected by the kernel.
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose v tensor.
    v = jnp.transpose(v, axes=(0, 1, 3, 2))
    # Fixed RNG seed passed to the kernel (used for dropout).
    seed = jnp.array([1])

    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4, got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads dims are both 1, "
            f"got batch {bias.shape[0]} and num_heads {bias.shape[1]}"
        )
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original shape.
    d_query = d_query.transpose(0, 3, 1, 2)
    d_key = d_key.transpose(0, 3, 1, 2)
    d_value = d_value.transpose(0, 3, 1, 2)

    # No gradient is computed for bias.
    return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
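For reference, a minimal usage sketch of the kernel above; the shapes and dtype are illustrative, and it assumes a Neuron backend (trn1/trn2) with the jax-neuronx and neuronxcc dependencies installed:

# Illustrative usage sketch (not part of the commit); assumes a Neuron backend.
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

batch, seq_len, num_heads, per_head_dim = 2, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# bias may be None or a [1, 1, target_length, source_length] tensor.
out = flash_attention(
    q, k, v, None, causal=True, softmax_scale=per_head_dim**-0.5, dropout_rate=0.0
)
assert out.shape == (batch, seq_len, num_heads, per_head_dim)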
143 changes: 143 additions & 0 deletions
axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,143 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
    pytestmark = pytest.skip(
        reason="Incompatible hardware: AWS Neuron-only test.", allow_module_level=True
    )

@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "4d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On demand import only if test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "4d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None

    o = flash_attention(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    o_ref = mha_reference(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    # Note: bfloat16 outputs are not numerically compared here.
    if input_dtype == jnp.float16:
        chex.assert_trees_all_close(o, o_ref, atol=0.07)
    elif input_dtype == jnp.float32:
        chex.assert_trees_all_close(o, o_ref, atol=0.03)

@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On demand import only if test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    # Note: the "2d" bias case still constructs a [1, 1, seq_len, seq_len] tensor,
    # matching the bias layout expected by the kernel.
    if attention_bias_type == "2d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
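Beyond the reference comparison above, here is a rough sketch of taking gradients directly through the custom VJP; the bias is left as None, and the shapes, dtype, and stand-in loss are illustrative (again assuming a Neuron backend):

# Illustrative sketch (not part of the commit): differentiate a scalar loss through the kernel.
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

def loss_fn(q, k, v):
    # Sum of attention outputs as a stand-in scalar loss.
    return flash_attention(q, k, v, None, causal=True, softmax_scale=64**-0.5).sum()

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 2048, 4, 64)  # [batch_size, seq_len, num_heads, per_head_dim].
q, k, v = (jax.random.normal(key, shape, dtype=jnp.bfloat16) for key in (k1, k2, k3))

# `_mha_backward` is invoked via `flash_attention.defvjp` when differentiating.
grad_fn = jax.jit(jax.grad(loss_fn, argnums=(0, 1, 2)))
dq, dk, dv = grad_fn(q, k, v)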