diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 26aceb797..37baf3d8b 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -867,6 +867,7 @@ def forward(
         *,
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
+        kv_state: Optional[Tensor] = None,
         time_step: Optional[Tensor] = None,
     ) -> BaseQKVLinear.Output:
         """Computes attention for the given query, key, value.
@@ -875,6 +876,12 @@
 
         See parent class for full docstring.
         """
+        if kv_state is not None:
+            raise ValueError(
+                "QKVLinear computes key and value projections "
+                "and does not expect external `kv_state`."
+            )
+
         key = query if key is None else key
         value = query if value is None else value
         q_proj = self.q_proj(query)
@@ -1019,6 +1026,7 @@ def forward(
         *,
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
+        kv_state: Optional[KVState] = None,
         time_step: Optional[Tensor] = None,
     ) -> BaseQKVLinear.Output:
         """Computes multi-head query, key, and value for the input query, key, value
@@ -1029,8 +1037,14 @@
         See parent class for full docstring.
 
         Raises:
-            ValueError: If key and value are not both set or both None.
+            ValueError: If key and value are not both set or both None; or if kv_state is not None.
         """
+        if kv_state is not None:
+            raise ValueError(
+                "FusedQKVLinear computes key and value projections "
+                "and does not expect external `kv_state`."
+            )
+
         with child_context("qkv_proj"):
             params = self.qkv_proj.parameters
             if key is None and value is None:
@@ -1111,12 +1125,18 @@ def forward(
         *,
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
+        kv_state: Optional[Tensor] = None,
         time_step: Optional[Tensor] = None,
     ) -> FusedQKVLinear.Output:
         """See FusedQKVLinear for full docstring.
 
         N.B. Only supports cases where key and value are both None.
         """
+        if kv_state is not None:
+            raise ValueError(
+                "FusedGroupedQKVLinear computes key and value projections "
+                "and does not expect external `kv_state`."
+            )
         if key is not None or value is not None:
             raise ValueError("Key and value should be both None.")
         cfg = self.config
@@ -1193,6 +1213,7 @@ def apply_rotary_position_embeddings(
     key: Tensor,
     value: Tensor,
     sinusoidal_pos: Tensor,
+    rotary_key: bool,
     rotary_value: bool,
 ) -> tuple[Tensor, Tensor, Tensor]:
     """This is a jax implementation (a copy) of the RoPE apply_rotary_position_embeddings.
@@ -1205,7 +1226,8 @@
         key: Key embeddings with shape [batch_size, seq_len, num_heads, dim].
         value: Value embeddings with shape [batch_size, seq_len, num_heads, dim].
         sinusoidal_pos: Rotary position embeddings with shape [batch_size, seq_len, 1, dim].
-        rotary_value: Whether to apply rotary position embeddings on value layer.
+        rotary_key: Whether to apply rotary position embeddings on key.
+        rotary_value: Whether to apply rotary position embeddings on value.
 
     Returns:
         A tuple of:
@@ -1226,9 +1248,13 @@
         jnp.stack([-query[..., 1::2], query[..., ::2]], axis=-1), query.shape
     )
     query = query * cos_pos + rotate_half_query * sin_pos
-    # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
-    rotate_half_key = jnp.reshape(jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape)
-    key = key * cos_pos + rotate_half_key * sin_pos
+
+    if rotary_key:
+        # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
+        rotate_half_key = jnp.reshape(
+            jnp.stack([-key[..., 1::2], key[..., ::2]], axis=-1), key.shape
+        )
+        key = key * cos_pos + rotate_half_key * sin_pos
     if rotary_value:
         # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
         rotate_half_value = jnp.reshape(
@@ -1252,6 +1278,7 @@ class Config(BaseQKVLinear.Config):
             RoFormerSinusoidalPositionalEmbedding.default_config()
         )
         input_linear: BaseQKVLinear.Config = QKVLinear.default_config()
+        # Whether to apply RoPE rotations to the value embeddings.
         rotary_value: Required[bool] = REQUIRED
 
     def __init__(self, cfg: QKVLinear.Config, *, parent: Module):
@@ -1283,23 +1310,27 @@ def forward(
         *,
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
+        kv_state: Optional[KVState] = None,
         time_step: Optional[Tensor] = None,
     ) -> BaseQKVLinear.Output:
         cfg = self.config
         # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim].
-        query, key, value = self.i_proj(query, key=key, value=value)
+        query, key, value = self.i_proj(query, key=key, value=value, kv_state=kv_state)
         query_pos = jnp.arange(query.shape[1])[None]  # [batch_size=1, seq_len].
         if time_step is not None:
             query_pos = query_pos + time_step[:, None]  # [batch_size, seq_len].
         sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(query_pos).astype(query.dtype)
         # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim]
         sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2)
+
+        i_proj_computes_kv = kv_state is None
         query, key, value = apply_rotary_position_embeddings(
             sinusoidal_pos=sinusoidal_pos_emb,
             query=query,
             key=key,
             value=value,
-            rotary_value=cfg.rotary_value,
+            rotary_key=i_proj_computes_kv,
+            rotary_value=i_proj_computes_kv and cfg.rotary_value,
         )
         return self.Output(query, key, value)
 
diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index f6a292136..d605242bb 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
           x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
           assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
          ```
+        * Both tensors have the same rank (either 2 or 3), as batch dim is optional.
         * If given non-scalar arguments of different shapes, the result must be the same if we
           first broadcast the arguments against each other to make them have the same shape.
         * Beyond requiring broadcastability, must not impose any constraints on the shapes of its
@@ -473,30 +474,47 @@ class MaskFnAttentionBias(BoolAttentionBias):
     shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
     # The positions in the query sequence that the mask should be computed for.
     # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
-    # `target_positions[batch, num_heads i]` may attend to.
-    # If None, set `target_positions[batch, num_heads, i] = i`.
-    # Shape: [batch].
+    # `target_positions[batch, i]` may attend to.
+    # If None, set `target_positions[batch, i] = i`.
+    # Shape: [batch] or [batch, target_len].
     # This is typically used during decoding to specify the locations in the sequence being
     # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
     # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
+    # The motivation for supporting 2D shapes is to handle use cases where the time step in
+    # transformers is not necessarily contiguous, e.g., speculative decoding or
+    # non-contiguous prompts.
     target_positions: Optional[Tensor] = None
 
     def _bool_value(self) -> Optional[Tensor]:
         """Return a tensor with the boolean values from `self.mask` before they have been converted
         to biases.
 
-        Shape:
-            - If `target_positions` is None: [target_len, source_len]
-            - Else: [batch, target_len, source_len].
+        Shape: [batch, target_len, source_len].
+
+        Raises:
+            NotImplementedError: If `target_positions.ndim` is not in [1, 2].
         """
         target_positions, source_positions = jnp.indices(self.shape, sparse=True)
+        # Shape: [1, target_len, 1], [1, 1, source_len].
+        target_positions, source_positions = target_positions[None], source_positions[None]
         if self.target_positions is not None:
             target_positions = self.target_positions
+            if target_positions.ndim not in [1, 2]:
+                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
             if target_positions.ndim == 1:
+                # Shape: [batch, 1] + [target_len] = [batch, target_len]
                 # pylint: disable-next=unsubscriptable-object
                 target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
-            while target_positions.ndim < 3:
-                target_positions = target_positions[..., None]
+            elif target_positions.ndim == 2:
+                shape_with_batch_dim = (1, *self.shape)
+                # Raise an exception if shapes aren't compatible. We don't use the output.
+                jnp.broadcast_shapes(
+                    (target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
+                )
+            else:
+                raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
+            target_positions = target_positions[..., None]  # Shape: [batch, target_len, 1].
+
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 0932df100..358c62911 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
 import chex
 import jax.numpy as jnp
 import jax.util
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.sharding import PartitionSpec
 
 from axlearn.common import attention_bias, test_utils
@@ -267,6 +267,45 @@ def test_mask_fn_attention_bias(self):
         expected = attention_bias.bool_to_bias(expected)[:, None, :]
         self.assertNestedEqual(bias.value(), expected)
 
+    def test_mask_fn_attention_bias_target_positions_ndim(self):
+        """Tests `MaskFnAttentionBias` when `target_positions.ndim == 2`."""
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=attention_bias.causal_mask,
+            shape=(5, 5),
+            target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
+        )
+        expected = jnp.asarray(
+            [
+                [
+                    attention_bias.causal_mask(*jnp.indices([5, 5])),
+                ],
+                [
+                    attention_bias.causal_mask(*jnp.indices([5, 5]))[::-1, :],
+                ],
+            ],
+            dtype=bool,
+        )
+        self.assertNestedEqual(bias.bool_value(), expected)
+
+    def test_mask_fn_attention_bias_with_target_positions(self):
+        # Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and
+        # source_positions tensors of the same rank.
+        batch, target_len, source_len = 2, 5, 4
+        time_step = jnp.arange(batch)
+
+        def mask_fn(target_positions, source_positions):
+            self.assertEqual(target_positions.shape, (batch, target_len, 1))
+            self.assertEqual(source_positions.shape, (1, 1, source_len))
+            return attention_bias.causal_mask(target_positions, source_positions)
+
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
+        )
+        ref_bias = attention_bias.MaskFnAttentionBias(
+            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
+        )
+        chex.assert_trees_all_close(bias.value(), ref_bias.value())
+
     def test_bool_tensor_attention_bias(self):
         bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
         self.assertNestedEqual(
@@ -278,3 +317,7 @@ def test_astype(self):
         self.assertEqual(bias.value().dtype, jnp.float32)
         bias = bias.astype(jnp.bfloat16)
         self.assertEqual(bias.value().dtype, jnp.bfloat16)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index 5d4aeb623..1e188ecc0 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -737,18 +737,24 @@ def test_alibi_attention_mask(self):
 class RoFormerSinusoidalPositionalEmbeddingTest(TestCase):
     """Tests RoFormerSinusoidalPositionalEmbedding."""
 
-    @parameterized.parameters(
-        (2, 3, 10, 32, True),
-        (2, 3, 8, 32, False),
-        (2, 4, 6, 32, True),
-        (2, 4, 8, 16, False),
-        (2, 5, 8, 48, True),
-        (2, 5, 8, 64, False),
+    @parameterized.product(
+        tensor_dimensions=(
+            (2, 3, 10, 32),
+            (2, 3, 8, 32),
+            (2, 4, 6, 32),
+            (2, 4, 8, 16),
+            (2, 5, 8, 48),
+            (2, 5, 8, 64),
+        ),
+        rotary_key=(True, False),
+        rotary_value=(True, False),
     )
     def test_apply_rotary_position_embeddings(
-        self, batch_size, num_heads, max_len, dim, rotary_value
+        self, tensor_dimensions: tuple[int, int, int, int], rotary_key: bool, rotary_value: bool
     ):
         # Unittest against the apply_rotary_position_embeddings in HF.
+        batch_size, num_heads, max_len, dim = tensor_dimensions
+
         token_ids = np.random.randint(low=1, high=20, size=[batch_size, max_len])
         sinusoidal_pos_layer = hf_roformer.RoFormerSinusoidalPositionalEmbedding(max_len, dim)
         sinusoidal_pos = sinusoidal_pos_layer(as_torch_tensor(token_ids).shape)[None, None, :, :]
@@ -771,11 +777,15 @@
             sinusoidal_pos, as_torch_tensor(query), as_torch_tensor(key)
         )
         ref_v_proj = as_torch_tensor(value)
+        if not rotary_key:
+            ref_k_proj = as_torch_tensor(key)
+
         test_q_proj, test_k_proj, test_v_proj = test_layer(
             sinusoidal_pos=as_tensor(sinusoidal_pos),
             query=query,
             key=key,
             value=value,
+            rotary_key=rotary_key,
             rotary_value=rotary_value,
         )
         np.testing.assert_allclose(test_q_proj, ref_q_proj, atol=5e-7)
@@ -1128,6 +1138,7 @@ def test_against_llama_for_apply_rotary_emb(self):
             key=jnp.asarray(key),
             value=jnp.asarray(value),
             sinusoidal_pos=axlearn_rope,
+            rotary_key=True,
             rotary_value=False,
         )
 
@@ -1382,11 +1393,22 @@ def test_num_kv_heads(
         layer = cfg.instantiate(parent=None)
         self.assertEqual(expected, layer.num_kv_heads)
 
-    def test_qlinear(self):
+    @parameterized.parameters(
+        (QKVLinear.default_config(), QLinear.default_config()),
+        (
+            RoFormerQKVLinear.default_config().set(
+                input_linear=QKVLinear.default_config(), rotary_value=False
+            ),
+            RoFormerQKVLinear.default_config().set(
+                input_linear=QLinear.default_config(), rotary_value=False
+            ),
+        ),
+    )
+    def test_qlinear(self, base_cfg, test_cfg):
         """Tests that QLinear is equivalent to QKVLinear with the same kv_state."""
         with utils.numeric_checks(True):
             model_dim = 12
-            num_heads = 4
+            num_heads = 3
             per_head_dim = model_dim // num_heads
             layer_kwargs = dict(
                 query_dim=model_dim,
@@ -1395,8 +1417,8 @@
                 num_heads=num_heads,
                 per_head_dim=per_head_dim,
             )
-            base_cfg = QKVLinear.default_config().set(**layer_kwargs)
-            test_cfg = QLinear.default_config().set(**layer_kwargs)
+            base_cfg = base_cfg.set(**layer_kwargs)
+            test_cfg = test_cfg.set(**layer_kwargs)
             maybe_set_config(test_cfg, num_kv_heads=num_heads)
             base_layer = base_cfg.set(name="base").instantiate(parent=None)
             test_layer = test_cfg.set(name="test").instantiate(parent=None)
@@ -1404,7 +1426,12 @@
             # Construct base layer state.
             base_state = base_layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
             # Map state to QLinear.
-            test_state = {"q_proj": base_state["q_proj"]}
+            if "q_proj" in base_state:
+                test_state = {"q_proj": base_state["q_proj"]}
+            elif "i_proj" in base_state:
+                test_state = {"i_proj": {"q_proj": base_state["i_proj"]["q_proj"]}}
+            else:
+                raise ValueError("Cannot find expected q_proj state.")
 
             # Construct test inputs.
             batch_size, src_len, tgt_len = 2, 6, 6
diff --git a/axlearn/common/lora.py b/axlearn/common/lora.py
index b968f1548..199cef603 100644
--- a/axlearn/common/lora.py
+++ b/axlearn/common/lora.py
@@ -516,8 +516,15 @@ def forward(
         *,
         key: Optional[Tensor] = None,
         value: Optional[Tensor] = None,
+        kv_state: Optional[Tensor] = None,
         time_step: Optional[Tensor] = None,
     ) -> BaseQKVLinear.Output:
+        if kv_state is not None:
+            raise ValueError(
+                "LoraFusedQKVLinear computes key and value projections "
+                "and does not expect external `kv_state`."
+            )
+
         cfg = self.config
         if key is None and value is None:
             inputs = query
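
Illustrative usage (not part of the patch): a minimal sketch, mirroring the new unit test above, of how the 2-D `target_positions` support in `MaskFnAttentionBias` can be exercised for non-contiguous decode positions. The concrete shapes and position values below are made up for illustration.

    import jax.numpy as jnp

    from axlearn.common import attention_bias

    # Two batch entries, each decoding 5 positions; the second entry decodes its
    # positions in reverse order (a non-contiguous / out-of-order time_step case).
    bias = attention_bias.MaskFnAttentionBias(
        mask=attention_bias.causal_mask,
        shape=(5, 5),
        target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
    )
    # Entry 0 yields the usual causal mask; entry 1 yields the same mask with its
    # rows reversed, matching the expectation in the test above.
    mask = bias.bool_value()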