Commit cefc3c0 (1 parent: 94c81cb)
Showing 3 changed files with 358 additions and 1 deletion.
@@ -0,0 +1,192 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

# TODO(apoorvtintin): remove pytype disable when dependencies are public.
# pytype: disable=import-error
# Import needed to enable the JAX cache on Neuron.
import jax_neuronx  # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# pytype: enable=import-error

Tensor = jax.Array
# Number of cores to shard the kernel across: 2 on trn2 (NC_v3d), 1 on trn1.
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1


@partial(custom_vjp, nondiff_argnums=(5, 6, 7))
def flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor] = None,
    prng_key: Optional[Tensor] = None,
    causal: bool = False,
    softmax_scale: float = 1.0,
    dropout_rate: float = 0.0,
):
    """Wraps _mha_forward for custom vjp.

    Args:
        query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
        key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
        value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
        bias: Optional logit biases of shape [1, 1, target_length, source_length].
        prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
        causal: Whether to apply a causal mask.
        softmax_scale: Optional scale to apply to the softmax. Defaults to 1.0.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).

    Returns:
        The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
    """
    out, _ = _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate)
    return out


def _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate):
    """Computes attention outputs following FlashAttention.

    See also `_mha_backward` for the backward pass.

    Args:
        query: Input query.
        key: Input key.
        value: Input value.
        bias: Input bias.
        prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
        causal: Whether to apply a causal mask.
        softmax_scale: Softmax scale to use in the kernel.
        dropout_rate: Dropout rate to use in the kernel.
    """
    # Get the batch size and number of heads.
    batch_size, _, num_heads, _ = query.shape

    # Transpose the query, key, and value tensors into the kernel layout.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len].
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len].
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model].

    if dropout_rate > 0:
        assert dropout_rate < 1
        assert prng_key is not None
    else:
        # Dummy unused key.
        prng_key = jax.random.key(0)

    # TODO(apoorvtintin): Pass the RBG key to the kernel directly once the kernel accepts it.
    # Currently the NKI kernel supports a single 32-bit key; temporarily override the key until
    # support for 128-bit keys is added. Until then, dropout is not supported.
    prng_key = jnp.array([1])

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if num_heads > 0 and num_heads % lnc == 0:
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4 but got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads are both 1, "
            f"batch is {bias.shape[0]} and num_heads is {bias.shape[1]}"
        )
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            prng_key,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            prng_key,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    # Transpose the output back to the original shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model].

    return attn_output, (lse, attn_output, q, k, v, bias, prng_key)


def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
    """Computes gradients for `flash_attention` from the residuals saved by `_mha_forward`."""
    lse, o, q, k, v, bias, prng_key = res
    batch_size, num_heads, _, _ = q.shape

    # Transpose the forward output and its gradient into the kernel layout.
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose the v tensor.
    v = jnp.transpose(v, axes=(0, 1, 3, 2))

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if num_heads > 0 and num_heads % lnc == 0:
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4 but got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads are both 1, "
            f"batch is {bias.shape[0]} and num_heads is {bias.shape[1]}"
        )
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            prng_key,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            prng_key,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original shape.
    d_query = d_query.transpose(0, 3, 1, 2)  # [batch_size, q_seq_len, num_heads, d_model]
    d_key = d_key.transpose(0, 3, 1, 2)  # [batch_size, kv_seq_len, num_heads, d_model]
    d_value = d_value.transpose(0, 3, 1, 2)  # [batch_size, kv_seq_len, num_heads, d_model]

    return d_query, d_key, d_value, None, None


flash_attention.defvjp(_mha_forward, _mha_backward)
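
For context, a minimal usage sketch of the wrapper above (not part of the commit); it assumes a trn1/trn2 host with the Neuron toolchain and axlearn installed, and the shapes and dtypes are illustrative only.

import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

# Illustrative shapes: [batch_size, seq_len, num_heads, per_head_dim].
batch_size, seq_len, num_heads, per_head_dim = 2, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# Forward pass; causal, softmax_scale, and dropout_rate are the non-differentiable
# arguments declared in nondiff_argnums.
out = flash_attention(
    q, k, v, None, causal=True, softmax_scale=per_head_dim**-0.5, dropout_rate=0.0
)
print(out.shape)  # (2, 2048, 8, 128)

# Gradients flow through _mha_backward via the registered custom vjp.
def loss_fn(q, k, v):
    return flash_attention(
        q, k, v, None, causal=True, softmax_scale=per_head_dim**-0.5, dropout_rate=0.0
    ).sum()

dq, dk, dv = jax.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)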
axlearn/common/flash_attention/neuron_attention_test.py (143 additions, 0 deletions)
@@ -0,0 +1,143 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
    pytestmark = pytest.skip(
        reason="Incompatible hardware, AWS Neuron only test.", allow_module_level=True
    )


@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On-demand import, only if the test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = per_head_dim**-0.5
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "2d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None

    o = flash_attention(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    o_ref = mha_reference(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    if input_dtype == jnp.float16:
        chex.assert_trees_all_close(o, o_ref, atol=0.07)
    elif input_dtype == jnp.float32:
        chex.assert_trees_all_close(o, o_ref, atol=0.03)


@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On-demand import, only if the test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = per_head_dim**-0.5
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "2d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
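
The module skips itself on non-Neuron backends, so it is safe to collect anywhere. On a trn1/trn2 instance, one way to run just this file is the short driver below; the path assumes the repository root as the working directory, and it is equivalent to invoking pytest on the command line.

# Hypothetical invocation from the repository root.
import pytest

raise SystemExit(pytest.main(["-q", "axlearn/common/flash_attention/neuron_attention_test.py"]))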