Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable remat checkpoints to host instead of TPU memory #643

Merged
merged 6 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
remat checkpoints to host
  • Loading branch information
samos123 committed Aug 9, 2024
commit 5e7786a3d2197ea6f0ea52b2b3d8810f6cf9a6d2
18 changes: 15 additions & 3 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3874,6 +3874,7 @@ def build_remat_spec(
],
self_attention: bool = True,
feed_forward: bool = False,
offload: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a bool, should we allow the caller to customize offload_dst directly?

Suggested change
offload: bool = False,
offload_dst: Optional[Literal["pinned_host"]] = None,

This will make the API more extensible and closer to the JAX API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I will take another stab at this PR with focus on staying closer to JAX API and extensibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved, could you review again?

) -> Optional[RematSpec]:
"""Configures how the Transformer or Conformer stack will save the linearization points.

Expand All @@ -3891,6 +3892,7 @@ def build_remat_spec(
stack_cfg: A transformer config.
self_attention: Checkpoint self attention layer activations if true.
feed_forward: Checkpoint feed-forward layer activations if true.
offload: Offload the checkpoints to host memory instead of TPU memory.

Returns:
None (if no rematerialization is needed) or a RematSpec.
Expand All @@ -3905,17 +3907,27 @@ def build_remat_spec(
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)

if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)
if offload:
policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
names_which_can_be_saved=[],
names_which_can_be_offloaded=checkpoints,
offload_src="device",
offload_dst="pinned_host",
)

return RematSpec(
prevent_cse=stack_cfg.klass is StackedTransformerLayer,
# If we are running inside a jax.lax.scan (Repeated/Pipelined transformers
# or Repeated Conformers) we can enable common subexpression elimination optimizations.
policy=config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
),
policy=policy,
)


Expand Down
15 changes: 11 additions & 4 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ def mesh_shape_from_axes(


def update_model_remat_config(
*, stack_cfg: causal_lm.TransformerStackConfig, layer_cfg: TransformerLayer.Config
*,
stack_cfg: causal_lm.TransformerStackConfig,
layer_cfg: TransformerLayer.Config,
offload: bool = False,
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.

Expand All @@ -203,10 +206,12 @@ def update_model_remat_config(

if layer_cfg.self_attention.attention.klass is not FlashAttention:
# Enable remat to reduce memory usage for larger models.
remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg), offload=offload)
else:
# Checkpointing both ffn and attention to give the best performance.
remat_spec = build_remat_spec(stack_cfg, feed_forward=True, self_attention=True)
remat_spec = build_remat_spec(
stack_cfg, feed_forward=True, self_attention=True, offload=offload
)
layer_cfg.set(remat_spec=remat_spec)


Expand All @@ -230,6 +235,7 @@ def model_config(
ffn_structure: str = "prenorm",
atten_structure: str = "prenorm",
atten_logit_cap: Optional[float] = None,
remat_offload: bool = False,
) -> causal_lm.Model.Config:
"""Returns an LM model config based on the given hyperparams.

Expand Down Expand Up @@ -258,6 +264,7 @@ def model_config(
Options: [prenorm, postnorm, hybridnorm].
atten_logit_cap: Cap the absolute values of logits by tanh.
Enabled by setting a positive value.
remat_offload: Offload remat checkpoints to host instead of TPU memory.

Returns:
A causal LM config.
Expand All @@ -276,7 +283,7 @@ def model_config(
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg, offload=remat_offload)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
decoder_cfg = Decoder.default_config().set(
Expand Down
4 changes: 4 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def get_trainer_kwargs(
num_kv_heads=None if version == Version.V1 else 8,
rope_theta=rope_theta,
flash_attention=flash_attention,
remat_offload=True,
),
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
max_sequence_length=max_sequence_length,
Expand Down Expand Up @@ -230,6 +231,7 @@ def model_config(
ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None,
flash_attention: bool = False,
stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None,
remat_offload: bool = False,
) -> causal_lm.Model.Config:
"""Returns an LM model config based on the given hyperparams.

Expand All @@ -247,6 +249,7 @@ def model_config(
flash_attention: Whether to enable flash attention.
stack_cfg: The transformer stack config.
If None, defaults to a RepeatedTransformerLayer.
remat_offload: Offload remat checkpoints to host instead of TPU memory.

Returns:
A causal LM config.
Expand Down Expand Up @@ -283,6 +286,7 @@ def model_config(
emb_cfg=TransformerTextEmbeddings.default_config().set(pos_emb=None),
attention_cfg=flash_attention_config() if flash_attention else atten_cfg,
attention_qkv_linear=atten_qkv_linear,
remat_offload=remat_offload,
)
return cfg

Expand Down