Skip to content

Commit

Permalink
Fix feedforward remat point (apple#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 authored and qdavid1 committed Dec 11, 2024
1 parent 72dc460 commit 62e8e29
Show file tree
Hide file tree
Showing 63 changed files with 445 additions and 461 deletions.
83 changes: 49 additions & 34 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@
import enum
import functools
import math
import re
from collections.abc import Sequence
from enum import Enum, unique
from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union

import einops
import jax
from jax import numpy as jnp
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
from jax._src.ad_checkpoint import name_p
from jax._src.interpreters import partial_eval as pe
from jax.core import Primitive

from axlearn.common import ops, param_init
from axlearn.common.base_layer import (
Expand Down Expand Up @@ -2939,12 +2942,10 @@ def _linear2(x):

self._add_tensor_stats("inputs", inputs)

remat_pt1 = "activation"
remat_pt2 = "linear2"
if cfg.structure == "prenorm":
x = self.norm(inputs)
x = self._linear1_activation(x)
x = self._remat_name(x, remat_pt1)
x = self.dropout1(x)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
Expand All @@ -2955,7 +2956,6 @@ def _linear2(x):
x += inputs
elif cfg.structure == "postnorm":
x = self._linear1_activation(inputs)
x = self._remat_name(x, remat_pt1)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
x = self.dropout(x)
Expand All @@ -2966,7 +2966,6 @@ def _linear2(x):
elif cfg.structure == "hybridnorm":
x = self.prenorm(inputs)
x = self._linear1_activation(x)
x = self._remat_name(x, remat_pt1)
x = self.dropout1(x)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
Expand All @@ -2979,7 +2978,6 @@ def _linear2(x):
elif cfg.structure == "nonorm":
x = inputs
x = self._linear1_activation(x)
x = self._remat_name(x, remat_pt1)
x = self.dropout1(x)
x = _linear2(x)
x = self._remat_name(x, remat_pt2)
Expand All @@ -2998,7 +2996,8 @@ def _linear1_activation(self, x: Tensor) -> Tensor:
if isinstance(cfg.activation, tuple):
activations = [
self._get_activation(
self.children[f"linear1_{i}"](x), activation_fn_name=activation
self._remat_name(self.children[f"linear1_{i}"](x), f"linear1_{i}"),
activation_fn_name=activation,
)
for i, activation in enumerate(cfg.activation)
]
Expand All @@ -3010,6 +3009,7 @@ def _linear1_activation(self, x: Tensor) -> Tensor:
return outputs
else:
x = self.linear1(x)
x = self._remat_name(x, "linear1_0")
x = self._get_activation(x, activation_fn_name=cfg.activation)
self._add_tensor_stats("linear1_outputs", x)
return x
Expand Down Expand Up @@ -4072,13 +4072,43 @@ def forward(
# TODO(sneha): extend_step


OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]]
_SavePattern = Union[str, re.Pattern, None]


# Adapted from jax source code to support regex. Reference:
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
def _save_and_offload_only_these_names_regex(
*,
names_which_can_be_saved: _SavePattern,
names_which_can_be_offloaded: _SavePattern,
offload_src: str,
offload_dst: str,
) -> OffloadPolicy:
def policy(prim, *_, **params):
if prim is name_p:
if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]):
return pe.Saveable
if names_which_can_be_offloaded and re.fullmatch(
names_which_can_be_offloaded, params["name"]
):
return pe.Offloadable(src=offload_src, dst=offload_dst)
return pe.Recompute # not saveable unless it's in the allow-list

return policy


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
self_attention: bool = True,
feed_forward: bool = False,
offload_dst: Optional[Literal["pinned_host"]] = None,
save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
offload_pattern: _SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
"""Configures how the Transformer or Conformer stack will save the linearization points.
Expand All @@ -4094,10 +4124,10 @@ def build_remat_spec(
Args:
stack_cfg: A transformer config.
self_attention: Checkpoint self attention layer activations if true.
feed_forward: Checkpoint feed-forward layer activations if true.
save_pattern: Activation regex pattern to save in HBM.
offload_pattern: Activation regex pattern to offload to `offload_dst`.
offload_dst: Destination of remat checkptoing offloading. Relevant Maxtext example:
https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230.
https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230.
Returns:
None (if no rematerialization is needed) or a RematSpec.
Expand All @@ -4106,27 +4136,12 @@ def build_remat_spec(
if stack_cfg.klass is PipelinedTransformerLayer:
return None

checkpoints = []
if self_attention:
attention_name = stack_cfg.layer.self_attention.attention.klass.__name__
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)

if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
policy = config_for_function(_save_and_offload_only_these_names_regex).set(
names_which_can_be_saved=save_pattern,
names_which_can_be_offloaded=offload_pattern,
offload_src="device",
offload_dst=offload_dst,
)
if offload_dst:
policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
names_which_can_be_saved=[],
names_which_can_be_offloaded=checkpoints,
offload_src="device",
offload_dst=offload_dst,
)

return RematSpec(
prevent_cse=stack_cfg.klass is StackedTransformerLayer,
Expand Down
106 changes: 99 additions & 7 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.

"""Tests attention layers."""

import contextlib
import copy
import itertools
Expand All @@ -23,6 +22,7 @@
from collections.abc import Sequence
from itertools import combinations
from typing import Any, Callable, Optional, Union
from unittest import mock

import jax
import numpy as np
Expand All @@ -40,6 +40,7 @@

from axlearn.common import attention, test_utils, utils
from axlearn.common.attention import (
FEED_FORWARD_SAVE_PATTERN,
NEG_INF,
BaseStackedTransformerLayer,
BaseTransformerLayer,
Expand All @@ -65,6 +66,7 @@
TransformerFeedForwardLayer,
TransformerLayer,
_next_power_of_two,
_save_and_offload_only_these_names_regex,
apply_attention_logit_biases,
apply_rotary_position_embeddings,
bool_to_bias,
Expand Down Expand Up @@ -3405,6 +3407,53 @@ def test_add_dead_neuron_summary(self, activation_fn: Union[str, list[str]]):
},
)

def test_linear_remat(self):
batch, seq_len, dim = 2, 3, 4
cfg = TransformerFeedForwardLayer.default_config().set(
name="ffn",
input_dim=dim,
hidden_dim=dim * 4,
add_value_rms_norm_summary=[],
tensor_stats=DefaultTensorStats.default_config(),
activation=("nn.relu", "nn.relu"),
)
layer = cfg.instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), shape=[batch, seq_len, dim])

def f(x, layer_params):
y, _ = F(
layer,
inputs=dict(inputs=x),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return y

_, save_name_backward = jax.linearize(
jax.remat(
f,
policy=_save_and_offload_only_these_names_regex(
names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
),
),
x,
layer_params,
)
_, save_dots_backward = jax.linearize(
jax.remat(f, policy=jax_remat_policies.dots_saveable), x, layer_params
)

self.assertEqual(str(save_name_backward).count(" dot_general"), 6)
self.assertEqual(
str(save_name_backward).count(" dot_general"),
str(save_dots_backward).count(" dot_general"),
)


class BaseTransformerTest(TestCase):
def _test_decoder_with_transformer(self, transformer_cfg: BaseTransformerLayer.Config):
Expand Down Expand Up @@ -3794,6 +3843,53 @@ def test_with_golden_value(self):
self.assertEqual(target.shape, layer_outputs.data.shape)
self.assertNestedAllClose(0.609666, np.mean(layer_outputs.data))

def test_build_remat_spec(self):
model_dim, num_heads = 6, 2
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.feed_forward.hidden_dim = model_dim * 4
cfg.vlog = 5

layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

batch_size, tgt_len = 2, 5
rng = np.random.default_rng(seed=123)
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

def f(x, layer_params):
forward_outputs, _ = F(
layer,
inputs=dict(
data=x,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return forward_outputs

# Ignore type errors.
spec: Any = build_remat_spec(mock.MagicMock())

_, default_policy_backward = jax.linearize(
jax.remat(f, policy=spec.policy.instantiate(), prevent_cse=spec.prevent_cse),
jnp.asarray(target),
layer_params,
)
_, full_remat_backward = jax.linearize(
jax.remat(f),
jnp.asarray(target),
layer_params,
)
# Eliminated the remat of qkv_proj, context and o_proj = 5 dots. This assumes
# FlashAttention is not enabled.
self.assertEqual(
str(full_remat_backward).count(" dot_general")
- str(default_policy_backward).count(" dot_general"),
5,
)


class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
Expand Down Expand Up @@ -4757,9 +4853,7 @@ def test_repeated_layer_with_custom_carry(self, repeat_carry, precomputed_kv_sta
output_self_attention_kv_state=True,
)
cfg.stack.repeat.carry = repeat_carry
cfg.stack.layer.remat_spec = build_remat_spec(
cfg.stack, self_attention=True, feed_forward=True
)
cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack)
if precomputed_kv_state:
kv_shape = (batch_size, seq_len, num_heads, head_dim)
kv_state = KVState(
Expand Down Expand Up @@ -4845,9 +4939,7 @@ def test_initialize_parameters_recursively(self, prebuilt_layers: list[str]):
remat_spec=None,
output_self_attention_kv_state=True,
)
cfg.stack.layer.remat_spec = build_remat_spec(
cfg.stack, self_attention=True, feed_forward=True
)
cfg.stack.layer.remat_spec = build_remat_spec(cfg.stack)
layer = cfg.instantiate(parent=None)
param_specs = layer.create_parameter_specs_recursively()
initialized_from_scratch = layer.initialize_parameters_recursively(
Expand Down
14 changes: 4 additions & 10 deletions axlearn/common/conformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,10 @@ def test_repeated_conformer_config(self, layer_order):
self.assertEqual(cfg.layer.self_attention.attention.input_linear.layer.bias, True)

@parameterized.product(
checkpoint_self_attention=(True, False),
checkpoint_feed_forward=(True, False),
test_remat=(True, False),
layer_order=(None, "lconv_before_ff", "lconv_before_mhsa", "mhsa_before_lconv"),
)
def test_repeated_conformer_forward(
self, checkpoint_self_attention, checkpoint_feed_forward, layer_order
):
def test_repeated_conformer_forward(self, test_remat, layer_order):
"""Tests RepeatedConformerLayer."""
dim, num_heads = 6, 2
# Create a conformer layer.
Expand All @@ -189,11 +186,8 @@ def test_repeated_conformer_forward(
)
repeat_cfg.layer.layer_order = layer_order
repeat_cfg.layer.self_attention.attention.num_heads = num_heads
repeat_cfg.layer.remat_spec = build_remat_spec(
repeat_cfg,
self_attention=checkpoint_self_attention,
feed_forward=checkpoint_feed_forward,
)
if test_remat:
repeat_cfg.layer.remat_spec = build_remat_spec(repeat_cfg)
repeat_layer = repeat_cfg.instantiate(parent=None) # type: RepeatedConformerLayer
repeat_state = repeat_layer.initialize_parameters_recursively(jax.random.PRNGKey(100))
# Generate synthetic inputs.
Expand Down
4 changes: 2 additions & 2 deletions axlearn/common/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)
A tensor with shape [batch_size, num_length, input_dim].
"""
cfg = self.config
remat_pt1 = "activation"
remat_pt1 = "linear1_0"
remat_pt2 = "linear2"

if cfg.structure == "prenorm":
Expand All @@ -312,8 +312,8 @@ def forward(self, *, input: Tensor, shift: Tensor, scale: Tensor, gate: Tensor)

x = modulate(x=x, shift=shift, scale=scale)
x = self.linear1(x)
x = get_activation_fn(cfg.activation)(x)
x = self._remat_name(x, remat_pt1)
x = get_activation_fn(cfg.activation)(x)
x = self.dropout1(x)
x = self.linear2(x)
x = self._remat_name(x, remat_pt2)
Expand Down
Loading

0 comments on commit 62e8e29

Please sign in to comment.