MaskFnAttentionBias._bool_value passes same-rank position tensors to mask_fn. #888

Merged
merged 1 commit on Dec 17, 2024
5 changes: 3 additions & 2 deletions axlearn/common/attention_bias.py
@@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
```
* Both tensors have the same rank (either 2 or 3), as batch dim is optional.
* If given non-scalar arguments of different shapes, the result must be the same if we
first broadcast the arguments against each other to make them have the same shape.
* Beyond requiring broadcastability, must not impose any constraints on the shapes of its
@@ -494,14 +495,14 @@ def _bool_value(self) -> Optional[Tensor]:
            NotImplementedError. If `target_positions.ndim not in [1,2]`.
        """
        target_positions, source_positions = jnp.indices(self.shape, sparse=True)
        # Shape: [batch, target_len, source_len].
        # Shape: [1, target_len, 1], [1, 1, source_len].
        target_positions, source_positions = target_positions[None], source_positions[None]
        if self.target_positions is not None:
            target_positions = self.target_positions
            if target_positions.ndim not in [1, 2]:
                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
            if target_positions.ndim == 1:
                # Shape: [batch, target_len].
                # Shape: [batch, 1] + [target_len] = [batch, target_len]
                # pylint: disable-next=unsubscriptable-object
                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
            elif target_positions.ndim == 2:
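To make the shape bookkeeping in this hunk concrete, here is a minimal, self-contained JAX sketch (not the axlearn implementation itself; the sizes and the `time_step` variable are illustrative, mirroring the test below). It shows how sparse `jnp.indices` plus a per-example time step yields same-rank, rank-3 position tensors that broadcast to a `[batch, target_len, source_len]` mask.

```python
import jax.numpy as jnp

# Illustrative sizes (hypothetical, matching the new test case).
batch, target_len, source_len = 2, 5, 4
time_step = jnp.arange(batch)  # Per-example decode position, shape [batch].

# sparse=True returns index grids of shape [target_len, 1] and [1, source_len].
target_positions, source_positions = jnp.indices((target_len, source_len), sparse=True)
# Add the optional batch dim: [1, target_len, 1] and [1, 1, source_len].
target_positions, source_positions = target_positions[None], source_positions[None]

# With an explicit per-example time step, offset the target positions per batch entry:
# [batch, 1] + [target_len] = [batch, target_len], then append a trailing axis so the
# tensor has the same rank (3) as source_positions.
target_positions = (time_step[:, None] + jnp.arange(target_len))[..., None]

# A broadcast-only mask_fn (here, a causal check) then needs no rank juggling.
causal = target_positions >= source_positions
assert causal.shape == (batch, target_len, source_len)
```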
25 changes: 24 additions & 1 deletion axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
import chex
import jax.numpy as jnp
import jax.util
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.sharding import PartitionSpec

from axlearn.common import attention_bias, test_utils
@@ -287,6 +287,25 @@ def test_mask_fn_attention_bias_target_positions_ndim(self):
        )
        self.assertNestedEqual(bias.bool_value(), expected)

    def test_mask_fn_attention_bias_with_target_positions(self):
        # Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and
        # source_positions tensors of the same rank.
        batch, target_len, source_len = 2, 5, 4
        time_step = jnp.arange(batch)

        def mask_fn(target_positions, source_positions):
            self.assertEqual(target_positions.shape, (batch, target_len, 1))
            self.assertEqual(source_positions.shape, (1, 1, source_len))
            return attention_bias.causal_mask(target_positions, source_positions)

        bias = attention_bias.MaskFnAttentionBias(
            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
        )
        ref_bias = attention_bias.MaskFnAttentionBias(
            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
        )
        chex.assert_trees_all_close(bias.value(), ref_bias.value())

    def test_bool_tensor_attention_bias(self):
        bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
        self.assertNestedEqual(
@@ -298,3 +317,7 @@ def test_astype(self):
        self.assertEqual(bias.value().dtype, jnp.float32)
        bias = bias.astype(jnp.bfloat16)
        self.assertEqual(bias.value().dtype, jnp.bfloat16)


if __name__ == "__main__":
    absltest.main()
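As a usage note, the sketch below shows a user-defined mask_fn written to the contract documented in the first hunk: purely elementwise and broadcast-friendly, so it works with or without the optional batch dim. `window_mask` is a hypothetical helper for illustration only, not part of this PR or the axlearn API.

```python
import jax.numpy as jnp
from jax import Array


def window_mask(target_positions: Array, source_positions: Array) -> Array:
    """Hypothetical mask_fn: attend causally to the 3 most recent source positions."""
    diff = target_positions - source_positions  # Relies on broadcasting only.
    return (diff >= 0) & (diff < 3)


# Works with the rank-3 convention exercised by the new test...
t3 = jnp.arange(5).reshape(1, 5, 1)
s3 = jnp.arange(4).reshape(1, 1, 4)
assert window_mask(t3, s3).shape == (1, 5, 4)

# ...and with the rank-2 convention, since the batch dim is optional.
t2 = jnp.arange(5).reshape(5, 1)
s2 = jnp.arange(4).reshape(1, 4)
assert window_mask(t2, s2).shape == (5, 4)
```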