Flash Attention for Neuron
apoorvtintin committed Feb 7, 2025
1 parent 94c81cb commit cefc3c0
Showing 3 changed files with 358 additions and 1 deletion.
192 changes: 192 additions & 0 deletions axlearn/common/flash_attention/neuron_attention.py
@@ -0,0 +1,192 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention Kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

# TODO(apoorvtintin) remove pytype disable when dependencies are public.
# pytype: disable=import-error
# Import needed to enable JAX cache on Neuron.
import jax_neuronx # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# pytype: enable=import-error

Tensor = jax.Array
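# Number of NeuronCores the kernel grid is spread across per device: trn2 devices report
# device_kind "NC_v3d" and use 2, while trn1 uses 1.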
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1


@partial(custom_vjp, nondiff_argnums=(5, 6, 7))
def flash_attention(
query: Tensor,
key: Tensor,
value: Tensor,
bias: Optional[Tensor] = None,
prng_key: Optional[Tensor] = None,
causal: bool = False,
softmax_scale: float = 1.0,
dropout_rate: float = 0.0,
):
"""Wraps _mha_forward for custom vjp.
Args:
query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
bias: Optional logit biases of shape [1, 1, target_length, source_length].
prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
causal: Whether to apply causal mask.
softmax_scale: Optional scale to apply to softmax. Defaults to 1.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).
Returns:
The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
"""
out, _ = _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate)
return out


def _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate):
"""Computes attention outputs following FlashAttention.
See also `_mha_backward` for the backward pass.
Args:
query: Input query.
key: Input key.
value: Input value.
bias: Input bias.
prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
        causal: Whether to apply causal mask.
softmax_scale: Softmax scale to use in the kernel.
dropout_rate: Dropout rate to use in the kernel.
"""
# Get the batch size, sequence lengths, number of heads, and hidden dimension.
batch_size, _, num_heads, _ = query.shape

# Transpose the query, key, and value tensors.
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len].
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len].
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model].

if dropout_rate > 0:
assert dropout_rate < 1
assert prng_key is not None
else:
# Dummy unused key.
prng_key = jax.random.key(0)

    # TODO(apoorvtintin): pass the RBG key to the kernel directly once the kernel accepts it.
    # Currently the NKI kernel supports only a single 32-bit key, so temporarily override the
    # key until support for 128-bit keys is added. Until then, dropout is not supported.
    prng_key = jnp.array([1])

    # Call the NKI kernel; shard the grid across NeuronCores when num_heads is divisible by
    # lnc, otherwise fall back to an unsharded grid.
if num_heads > 0 and num_heads % lnc == 0:
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4, but got {bias.ndim}."
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            "Bias is only supported when its batch and num_heads dims are both 1, "
            f"but got batch {bias.shape[0]} and num_heads {bias.shape[1]}."
        )
attn_output, lse = flash_fwd[grid](
q,
k,
v,
prng_key,
bias,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=dropout_rate,
)
else:
attn_output, lse = flash_fwd[grid](
q,
k,
v,
prng_key,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=dropout_rate,
)
# Transpose the output back to the original shape.
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model].

return attn_output, (lse, attn_output, q, k, v, bias, prng_key)


def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
lse, o, q, k, v, bias, prng_key = res
batch_size, num_heads, _, _ = q.shape

# Transpose the input tensors.
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor.
v = jnp.transpose(v, axes=(0, 1, 3, 2))

    # Call the NKI kernel; shard the grid across NeuronCores when num_heads is divisible by
    # lnc, otherwise fall back to an unsharded grid.
if num_heads > 0 and num_heads % lnc == 0:
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention expects bias.ndim == 4, but got {bias.ndim}."
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            "Bias is only supported when its batch and num_heads dims are both 1, "
            f"but got batch {bias.shape[0]} and num_heads {bias.shape[1]}."
        )
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
prng_key,
bias,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=dropout_rate,
softmax_scale=softmax_scale,
)
else:
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
prng_key,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=dropout_rate,
softmax_scale=softmax_scale,
)

# Transpose the gradients back to the original shape.
d_query = d_query.transpose(0, 3, 1, 2) # [batch_size, q_seq_len, num_heads, d_model]
d_key = d_key.transpose(0, 3, 1, 2) # [batch_size, kv_seq_len, num_heads, d_model]
d_value = d_value.transpose(0, 3, 1, 2) # [batch_size, kv_seq_len, num_heads, d_model]

return d_query, d_key, d_value, None, None


flash_attention.defvjp(_mha_forward, _mha_backward)
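
For reference, a minimal usage sketch of the wrapper above (illustrative only, not part of the commit; shapes follow the docstring and mirror the test file, and it assumes a trn1/trn2 device with jax-neuronx and neuronxcc installed):

# Illustrative usage sketch; bias is passed positionally as in the tests, and
# causal/softmax_scale/dropout_rate are the non-differentiable custom_vjp arguments.
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

batch_size, seq_len, num_heads, per_head_dim = 2, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

out = flash_attention(
    q, k, v, None, causal=True, softmax_scale=per_head_dim**-0.5, dropout_rate=0.0
)
assert out.shape == (batch_size, seq_len, num_heads, per_head_dim)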
143 changes: 143 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,143 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
pytestmark = pytest.skip(
reason="Incompatible hardware, AWS Neuron only test.", allow_module_level=True
)


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
(1, 2048, 1, 64),
(2, 2048, 2, 64),
(1, 2048, 1, 128),
(2, 2048, 2, 128),
(1, 2048, 8, 128),
(2, 2048, 8, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
causal: bool,
input_dtype: jnp.dtype,
attention_bias_type: str,
):
    # On-demand import so the module can be collected without Neuron-only dependencies.
# pylint: disable=import-outside-toplevel
from axlearn.common.flash_attention.neuron_attention import flash_attention

softmax_scale = per_head_dim**-0.5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

if attention_bias_type == "2d":
bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
else:
bias = None

o = flash_attention(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
)
o_ref = mha_reference(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
)
if input_dtype == jnp.float16:
chex.assert_trees_all_close(o, o_ref, atol=0.07)
elif input_dtype == jnp.float32:
chex.assert_trees_all_close(o, o_ref, atol=0.03)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 1, 2048, 64),
(2, 2, 2048, 64),
(1, 1, 2048, 128),
(2, 2, 2048, 128),
(1, 8, 2048, 128),
(2, 8, 2048, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
def test_bwd_against_ref(
batch_size: int,
num_heads: int,
seq_len: int,
per_head_dim: int,
causal: bool,
input_dtype: jnp.dtype,
attention_bias_type: str,
):
    # On-demand import so the module can be collected without Neuron-only dependencies.
# pylint: disable=import-outside-toplevel
from axlearn.common.flash_attention.neuron_attention import flash_attention

softmax_scale = per_head_dim**-0.5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

if attention_bias_type == "2d":
bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
else:
bias = None
segment_ids = None

def fn(q, k, v, bias):
return flash_attention(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
).sum()

def ref_fn(q, k, v, bias, segment_ids):
return mha_reference(
q,
k,
v,
bias,
segment_ids,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
).sum()

jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
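
These tests skip themselves at module level on non-Neuron backends; on a trn1/trn2 host they can be run with an ordinary pytest invocation, for example programmatically (hypothetical invocation, path taken from the file header above):

# Equivalent to running `pytest` on the test file from the repository root.
import pytest

pytest.main(["-v", "axlearn/common/flash_attention/neuron_attention_test.py"])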
24 changes: 23 additions & 1 deletion axlearn/common/flash_attention/utils.py
@@ -105,7 +105,7 @@ def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor:


def flash_attention_implementation(
backend: Literal["cpu", "tpu", "gpu", "xla"],
backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
*,
softmax_scale: float,
block_size: int = 128,
@@ -275,6 +275,28 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
interpret=(backend == "cpu"),
)

elif backend == "neuron":
# pylint: disable=import-outside-toplevel
from axlearn.common.flash_attention.neuron_attention import (
flash_attention as neuron_flash_attention,
)

key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)

# other_biases includes SegmentIdAttentionBias among other biases.
causal, other_biases = split(bias, CausalAttentionBias)

return neuron_flash_attention(
query,
key,
value,
bias=other_biases.value(),
causal=causal.has_value(),
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
)

elif backend in ("cpu", "xla"):
key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)
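
For context, a hedged sketch of how a caller might request this backend through flash_attention_implementation; only the parameters visible in this hunk are shown, and any further arguments in the full file are assumed to have defaults:

# Illustrative sketch (not part of this commit). Per the surrounding code, the call returns a
# callable attention implementation that dispatches to neuron_flash_attention when
# backend == "neuron".
from axlearn.common.flash_attention.utils import flash_attention_implementation

neuron_mha = flash_attention_implementation(
    backend="neuron",
    softmax_scale=128**-0.5,  # e.g. per_head_dim ** -0.5
    block_size=128,
)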
