diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index 235d2d6b5..835022418 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -45,7 +45,19 @@
 import functools
 import math
 from enum import Enum, unique
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 import jax
 from jax import numpy as jnp
@@ -3874,6 +3886,7 @@ def build_remat_spec(
     ],
     self_attention: bool = True,
     feed_forward: bool = False,
+    offload_dst: Optional[Literal["pinned_host"]] = None,
 ) -> Optional[RematSpec]:
     """Configures how the Transformer or Conformer stack will save the linearization points.

@@ -3891,6 +3904,8 @@
         stack_cfg: A transformer config.
         self_attention: Checkpoint self attention layer activations if true.
         feed_forward: Checkpoint feed-forward layer activations if true.
+        offload_dst: Destination of remat checkpoint offloading. Relevant MaxText example:
+            https://github.com/google/maxtext/blob/ebd39aa64d670fa13a313b6f776e01ad9e450321/MaxText/layers/models.py#L230.

     Returns:
         None (if no rematerialization is needed) or a RematSpec.
@@ -3905,17 +3920,27 @@
         checkpoints.extend(
             [f"{attention_name}.{el}" for el in ["q_proj", "k_proj", "v_proj", "context", "o_proj"]]
         )
+
     if feed_forward and hasattr(stack_cfg.layer, "feed_forward"):
         ffn_name = stack_cfg.layer.feed_forward.klass.__name__
         checkpoints.extend([f"{ffn_name}.{el}" for el in ["activation", "linear2"]])

+    policy = config_for_function(jax_remat_policies.save_only_these_names).set(
+        names_which_can_be_saved=checkpoints
+    )
+    if offload_dst:
+        policy = config_for_function(jax_remat_policies.save_and_offload_only_these_names).set(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=checkpoints,
+            offload_src="device",
+            offload_dst=offload_dst,
+        )
+
     return RematSpec(
         prevent_cse=stack_cfg.klass is StackedTransformerLayer,
         # If we are running inside a jax.lax.scan (Repeated/Pipelined transformers
         # or Repeated Conformers) we can enable common subexpression elimination optimizations.
-        policy=config_for_function(jax_remat_policies.save_only_these_names).set(
-            names_which_can_be_saved=checkpoints
-        ),
+        policy=policy,
     )

diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
index 618c45a96..e4bdba8bb 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash.txt
@@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
index d82b9a255..c38fa7f2b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1.txt
@@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
index b46716f88..7cfeea444 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash.txt
@@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
index 2c1326abb..e6ed261bc 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2.txt
@@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
index 9b87097ae..694a716b4 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash.txt
@@ -201,14 +201,16 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'MultiheadAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'MultiheadAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'MultiheadAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'MultiheadAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'MultiheadAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[5]: 'TransformerFeedForwardLayer.activation'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[6]: 'TransformerFeedForwardLayer.linear2'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
index bcb857b08..cc3643aca 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3.txt
@@ -201,12 +201,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
index a807cb7d4..b34be1105 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8-stream.txt
@@ -194,12 +194,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
index f2b260b9f..b1a938a9a 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/gspmd-16B-2x16x8.txt
@@ -194,12 +194,14 @@ model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
 model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
 model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_and_offload_only_these_names'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[0]: 'GroupedQueryAttention.q_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[1]: 'GroupedQueryAttention.k_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[2]: 'GroupedQueryAttention.v_proj'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[3]: 'GroupedQueryAttention.context'
+model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_offloaded[4]: 'GroupedQueryAttention.o_proj'
+model.decoder.transformer.layer.remat_spec['policy'].offload_dst: 'pinned_host'
+model.decoder.transformer.layer.remat_spec['policy'].offload_src: 'device'
 model.decoder.transformer.layer.self_attention.attention.causal: True
 model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 42c7561f8..cf7b93ac1 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -11,7 +11,7 @@
 """

 import math
-from typing import Dict, List, Optional, Protocol, Sequence, Tuple, Union
+from typing import Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Union

 import jax.numpy as jnp
 import tensorflow as tf
@@ -183,7 +183,10 @@ def mesh_shape_from_axes(


 def update_model_remat_config(
-    *, stack_cfg: causal_lm.TransformerStackConfig, layer_cfg: TransformerLayer.Config
+    *,
+    stack_cfg: causal_lm.TransformerStackConfig,
+    layer_cfg: TransformerLayer.Config,
+    offload_dst: Optional[Literal["pinned_host"]] = None,
 ):
     """Recomputes and sets the remat_spec based on provided layer_cfg.

@@ -192,6 +195,7 @@
     Args:
         stack_cfg: The transformer stack config.
         layer_cfg: The transformer layer config.
+        offload_dst: Destination of remat checkpoint offloading.

     Raises:
         NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
@@ -203,10 +207,12 @@

     if layer_cfg.self_attention.attention.klass is not FlashAttention:
         # Enable remat to reduce memory usage for larger models.
-        remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
+        remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg), offload_dst=offload_dst)
     else:
         # Checkpointing both ffn and attention to give the best performance.
-        remat_spec = build_remat_spec(stack_cfg, feed_forward=True, self_attention=True)
+        remat_spec = build_remat_spec(
+            stack_cfg, feed_forward=True, self_attention=True, offload_dst=offload_dst
+        )
     layer_cfg.set(remat_spec=remat_spec)


@@ -230,6 +236,7 @@
     ffn_structure: str = "prenorm",
     atten_structure: str = "prenorm",
     atten_logit_cap: Optional[float] = None,
+    remat_offload_dst: Optional[Literal["pinned_host"]] = None,
 ) -> causal_lm.Model.Config:
     """Returns an LM model config based on the given hyperparams.

@@ -258,6 +265,7 @@
             Options: [prenorm, postnorm, hybridnorm].
         atten_logit_cap: Cap the absolute values of logits by tanh. Enabled by setting a
             positive value.
+        remat_offload_dst: Destination of remat checkpoint offloading.

     Returns:
         A causal LM config.
@@ -276,7 +284,9 @@
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
     if stack_cfg.klass is RepeatedTransformerLayer:
-        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+        update_model_remat_config(
+            stack_cfg=stack_cfg, layer_cfg=layer_cfg, offload_dst=remat_offload_dst
+        )
     # Stack.
     transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
     decoder_cfg = Decoder.default_config().set(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 9c71d4054..ca887db2d 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -13,7 +13,7 @@
 import enum
 import functools
 import itertools
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Literal, Optional, Union

 from axlearn.common import causal_lm, config
 from axlearn.common.attention import (
@@ -188,6 +188,7 @@ def get_trainer_kwargs(
             num_kv_heads=None if version == Version.V1 else 8,
             rope_theta=rope_theta,
             flash_attention=flash_attention,
+            remat_offload_dst="pinned_host",
         ),
         learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
         max_sequence_length=max_sequence_length,
@@ -195,7 +196,9 @@
         max_step=max_step,
         mesh_shape=mesh_shape_from_axes(fsdp=-1),
         mesh_rules=(
-            # tpu-v5e. step time: TBD.
+            # TPU v5e maximum per-device batch size is 1, so at least 4 x v5e-256 are needed.
+            # tpu-v5e-512. step time: 14.0817s (61.11% MFU).
+            # tpu-v5e-1024. step time: 14.3736s (59.87% MFU).
             ("tpu-v5litepod-256", mesh_shape_from_axes(data=-1, fsdp=256)),
             # H100/A100 80G. Maximum per-node batch size = 16, hence need >= 64 nodes.
             # v2 on gpu-p5.48xlarge 8x64, step time: 12.9s.
@@ -230,6 +233,7 @@
     ffn_dim: Optional[Union[int, config.FunctionConfigBase]] = None,
     flash_attention: bool = False,
     stack_cfg: Optional[BaseStackedTransformerLayer.Config] = None,
+    remat_offload_dst: Optional[Literal["pinned_host"]] = None,
 ) -> causal_lm.Model.Config:
     """Returns an LM model config based on the given hyperparams.

@@ -247,6 +251,7 @@
         flash_attention: Whether to enable flash attention.
         stack_cfg: The transformer stack config.
             If None, defaults to a RepeatedTransformerLayer.
+        remat_offload_dst: Destination of remat checkpoint offloading.

     Returns:
         A causal LM config.
@@ -283,6 +288,7 @@
         emb_cfg=TransformerTextEmbeddings.default_config().set(pos_emb=None),
         attention_cfg=flash_attention_config() if flash_attention else atten_cfg,
         attention_qkv_linear=atten_qkv_linear,
+        remat_offload_dst=remat_offload_dst,
     )
     return cfg
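
Note (not part of the diff): the config changes above switch the remat policy from save_only_these_names to save_and_offload_only_these_names with offload_dst='pinned_host', so the named activations are moved to pinned host memory during the forward pass and fetched back for the backward pass instead of being kept on device. The sketch below is a minimal, standalone JAX illustration of that policy, assuming a JAX version that exposes save_and_offload_only_these_names under jax.checkpoint_policies and a backend with pinned host memory support; the layer, tensor name, and shapes are hypothetical and unrelated to AXLearn's build_remat_spec.

import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Offload (rather than save) the tensor tagged "q_proj"; everything else is recomputed.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["q_proj"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=policy)
def layer(x, w):
    # Tag the projection so the policy above can match it by name.
    q = checkpoint_name(x @ w, "q_proj")
    return jnp.tanh(q)

x = jnp.ones((8, 16))
w = jnp.ones((16, 16))
# During the backward pass, "q_proj" is brought back from host memory.
loss, grads = jax.value_and_grad(lambda w: layer(x, w).sum())(w)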