Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable remat checkpoints to host instead of TPU memory #643

Merged
merged 6 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,19 @@
import functools
import math
from enum import Enum, unique
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
)

import jax
from jax import numpy as jnp
Expand Down Expand Up @@ -3874,6 +3886,7 @@ def build_remat_spec(
],
self_attention: bool = True,
feed_forward: bool = False,
offload_dst: Optional[Literal["pinned_host"]] = None,
) -> Optional[RematSpec]:
"""Configures how the Transformer or Conformer stack will save the linearization points.

Expand All @@ -3891,6 +3904,8 @@ def build_remat_spec(
stack_cfg: A transformer config.
self_attention: Checkpoint self attention layer activations if true.
feed_forward: Checkpoint feed-forward layer activations if true.
offload_dst: Destination of remat checkptoing offloading. Relevant Maxtext example:
https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230.

Returns:
None (if no rematerialization is needed) or a RematSpec.
Expand All @@ -3905,17 +3920,27 @@ def build_remat_spec(
checkpoints.extend(
[f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
)

if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
ffn_name = stack_cfg.layer.feed_forward.klass.__name__
checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

policy = config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
)
if offload_dst:
policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
names_which_can_be_saved=[],
names_which_can_be_offloaded=checkpoints,
offload_src="device",
offload_dst=offload_dst,
)

return RematSpec(
prevent_cse=stack_cfg.klass is StackedTransformerLayer,
# If we are running inside a jax.lax.scan (Repeated/Pipelined transformers
# or Repeated Conformers) we can enable common subexpression elimination optimizations.
policy=config_for_function(jax_remat_policies.save_only_these_names).set(
names_which_can_be_saved=checkpoints
),
policy=policy,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: False
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
Expand Down
Loading