Flash Attention for Neuron #939

Open · wants to merge 1 commit into base: main
Flash Attention for Neuron
apoorvtintin committed Feb 7, 2025
commit 2c9a285ac219abec8d1b57457f46948af3f3ea80
192 changes: 192 additions & 0 deletions axlearn/common/flash_attention/neuron_attention.py
@@ -0,0 +1,192 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention Kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp

# TODO(apoorvtintin): remove pytype disable when dependencies are public.
# pytype: disable=import-error
# Import needed to enable JAX cache on Neuron.
import jax_neuronx # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

# pytype: enable=import-error

Tensor = jax.Array
# Number of NeuronCores to shard attention heads across: 2 on NC_v3d devices, 1 otherwise.
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1


@partial(custom_vjp, nondiff_argnums=(5, 6, 7))
def flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor] = None,
    prng_key: Optional[Tensor] = None,
    causal: bool = False,
Review thread on the `causal` argument:

Reviewer: Can we support segment ID? Or a more general masking fn (with optimized handling) is even better.
Reviewer: If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.
Author: Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.
Reviewer: Sure. In this regard I may ask for more: let's do a general mask then, since we want things beyond causal.
(A sketch of expressing segment IDs as an additive bias appears after this file's diff.)
    softmax_scale: float = 1.0,
    dropout_rate: float = 0.0,
):
    """Wraps _mha_forward for custom vjp.

    Args:
        query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
        key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
        value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
        bias: Optional logit biases of shape [1, 1, target_length, source_length].
        prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
        causal: Whether to apply causal mask.
        softmax_scale: Optional scale to apply to softmax. Defaults to 1.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).

    Returns:
        The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
    """
    out, _ = _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate)
    return out
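As a reading aid, here is a minimal usage sketch of the wrapper above. It is not part of the PR; the shapes and dtypes mirror the unit tests in this change, and it assumes a Neuron host (the module imports jax_neuronx and neuronxcc at import time).

```python
import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

# Illustrative shapes: [batch_size, seq_len, num_heads, per_head_dim].
batch, seq_len, heads, head_dim = 2, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch, seq_len, heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch, seq_len, heads, head_dim), dtype=jnp.bfloat16)

# Causal self-attention with the usual 1/sqrt(per_head_dim) scaling.
# dropout_rate stays 0.0 because the kernel does not support dropout yet
# (see the TODO in _mha_forward).
out = flash_attention(
    q,
    k,
    v,
    bias=None,
    causal=True,
    softmax_scale=head_dim**-0.5,
    dropout_rate=0.0,
)
assert out.shape == (batch, seq_len, heads, head_dim)
```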


def _mha_forward(query, key, value, bias, prng_key, causal, softmax_scale, dropout_rate):
"""Computes attention outputs following FlashAttention.

See also `_mha_backward` for the backward pass.

Args:
query: Input query.
key: Input key.
value: Input value.
bias: Input bias.
prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0.
causal: Input segment_ids.
softmax_scale: Softmax scale to use in the kernel.
dropout_rate: Dropout rate to use in the kernel.
"""
    # Get the batch size, sequence lengths, number of heads, and hidden dimension.
    batch_size, _, num_heads, _ = query.shape

    # Transpose the query, key, and value tensors.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len].
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len].
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model].

    if dropout_rate > 0:
        assert dropout_rate < 1
        assert prng_key is not None
    else:
        # Dummy unused key.
        prng_key = jax.random.key(0)

    # TODO(apoorvtintin): Pass rbg key to kernel directly when kernel is ready to accept it.
    # Currently the NKI kernel supports a single 32-bit key; temporarily override this until
    # support for 128-bit keys is added. Until then dropout is not supported.
    prng_key = jnp.array([1])

    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
    if num_heads > 0 and num_heads % lnc == 0:
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads are both 1, "
            f"batch is {bias.shape[0]} and num_heads is {bias.shape[1]}"
        )
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            prng_key,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            prng_key,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    # Transpose the output back to the original shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model].

    return attn_output, (lse, attn_output, q, k, v, bias, prng_key)


def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
    """Computes gradients of flash_attention with respect to query, key, and value.

    See also `_mha_forward` for the forward pass and the residuals packed into `res`.
    """
    lse, o, q, k, v, bias, prng_key = res
    batch_size, num_heads, _, _ = q.shape

    # Transpose the input tensors.
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose v tensor.
    v = jnp.transpose(v, axes=(0, 1, 3, 2))

    # Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
    if num_heads > 0 and num_heads % lnc == 0:
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
        assert bias.shape[0] == 1 and bias.shape[1] == 1, (
            f"Bias is only supported when batch and num_heads are both 1, "
            f"batch is {bias.shape[0]} and num_heads is {bias.shape[1]}"
        )
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            prng_key,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            prng_key,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original shape.
    d_query = d_query.transpose(0, 3, 1, 2)  # [batch_size, q_seq_len, num_heads, d_model]
    d_key = d_key.transpose(0, 3, 1, 2)  # [batch_size, kv_seq_len, num_heads, d_model]
    d_value = d_value.transpose(0, 3, 1, 2)  # [batch_size, kv_seq_len, num_heads, d_model]

    # No gradients for bias and prng_key.
    return d_query, d_key, d_value, None, None


flash_attention.defvjp(_mha_forward, _mha_backward)
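On the segment-ID question raised in the review thread above: until the kernel accepts segment IDs or a general mask natively, one possible interim route (not in this PR, only a sketch) is to encode segment IDs as the additive `bias` argument. The helper `segment_ids_to_bias` and its `segment_ids` input are hypothetical. Note that the kernel asserts the bias's leading batch and num_heads dims are both 1, so as written this only fits the current kernel when batch_size == 1, or once that restriction is relaxed.

```python
import jax.numpy as jnp

NEG_INF = -1e9  # Large negative logit standing in for -inf in the additive bias.


def segment_ids_to_bias(segment_ids: jnp.ndarray) -> jnp.ndarray:
    """Builds an additive attention bias from integer segment IDs.

    segment_ids: [batch_size, seq_len]; two positions may attend to each other
    only if their segment IDs are equal.
    Returns: bias of shape [batch_size, 1, seq_len, seq_len] containing 0.0
    where attention is allowed and NEG_INF where it is masked.
    """
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]  # [b, q, k]
    bias = jnp.where(same_segment, 0.0, NEG_INF)
    return bias[:, None, :, :].astype(jnp.float32)  # Insert singleton num_heads axis.
```

A general additive mask, as the reviewers suggest, subsumes this case: any boolean mask can be lowered to a 0 / NEG_INF bias in the same way.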
141 changes: 141 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,141 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
    pytestmark = pytest.skip(
        reason="Incompatible hardware, AWS Neuron only test.", allow_module_level=True
    )


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
(1, 2048, 1, 64),
(2, 2048, 2, 64),
(1, 2048, 1, 128),
(2, 2048, 2, 128),
(1, 2048, 8, 128),
(2, 2048, 8, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On demand import only if test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = per_head_dim**-0.5
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "2d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None

    o = flash_attention(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    o_ref = mha_reference(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    if input_dtype == jnp.float16:
        chex.assert_trees_all_close(o, o_ref, atol=0.07)
    elif input_dtype == jnp.float32:
        chex.assert_trees_all_close(o, o_ref, atol=0.03)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 1, 2048, 64),
(2, 2, 2048, 64),
(1, 1, 2048, 128),
(2, 2, 2048, 128),
(1, 8, 2048, 128),
(2, 8, 2048, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "2d"])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    # On demand import only if test is needed.
    # pylint: disable=import-outside-toplevel
    from axlearn.common.flash_attention.neuron_attention import flash_attention

    softmax_scale = per_head_dim**-0.5
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "2d":
        bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    def ref_fn(q, k, v, bias):
        return mha_reference(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
28 changes: 27 additions & 1 deletion axlearn/common/flash_attention/utils.py
@@ -105,7 +105,7 @@ def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor:


def flash_attention_implementation(
    backend: Literal["cpu", "tpu", "gpu", "xla"],
    backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
    *,
    softmax_scale: float,
    block_size: int = 128,
@@ -275,6 +275,32 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
interpret=(backend == "cpu"),
)

        elif backend == "neuron":
            # pylint: disable=import-outside-toplevel
            from axlearn.common.flash_attention.neuron_attention import (
                flash_attention as neuron_flash_attention,
            )

            key = _repeat_kv_heads(query.shape[2], key)
            value = _repeat_kv_heads(query.shape[2], value)

            # other_biases includes SegmentIdAttentionBias among other biases.
            causal, other_biases = split(bias, CausalAttentionBias)

            # TODO(apoorvtintin): Remove this once dropout support in kernel is ready.
            if dropout_rate > 0:
                raise NotImplementedError("Backend Neuron does not have dropout support yet")

            return neuron_flash_attention(
                query,
                key,
                value,
                bias=other_biases.value(),
                causal=causal.has_value(),
                softmax_scale=softmax_scale,
                dropout_rate=dropout_rate,
            )

        elif backend in ("cpu", "xla"):
            key = _repeat_kv_heads(query.shape[2], key)
            value = _repeat_kv_heads(query.shape[2], value)