From c8d0a7d16e79b0f8a9e084c7ce74dc887e71fe63 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 1 Jul 2024 19:14:46 +0000 Subject: [PATCH 1/4] General Neuron Support --- axlearn/common/utils.py | 4 ++++ axlearn/experiments/text/gpt/fuji.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 2e446945b..212c68a84 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -1193,6 +1193,10 @@ def create_device_mesh( logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.") return build_standard_mesh(mesh_shape, devices=devices) + # Neuron also only uses standard mesh + if device_platform == "neuron": + return build_standard_mesh(mesh_shape, devices=devices) + # We only break the first device axis (the least communication intensive) across granules. assert ( ici_mesh_shape[0] % num_granules == 0 diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 1e21c1eaa..dbe126fc4 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -83,6 +83,7 @@ class Version(enum.Enum): }, } +TRN_MODEL_AXIS_SIZE=8 def get_trainer_kwargs( model_size: str, @@ -103,7 +104,6 @@ def get_trainer_kwargs( num_kv_heads = 8 rope_theta = ROPE_THETA[version] - # dict() is more readable here. 
# pylint: disable=use-dict-literal if model_size == "test": @@ -167,6 +167,10 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), + ( + "neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)", + mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), + ), ), ) elif model_size == "70B": From 6669a41090f2882ee3364bc9c0d5489e7adfd6a6 Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 1 Jul 2024 19:42:10 +0000 Subject: [PATCH 2/4] add 'data' axis to fsdp axis --- axlearn/experiments/text/gpt/common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index db0c0be9e..2841eb03e 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -267,12 +267,17 @@ def model_config( batch_axis_names=batch_axis_names, seq_axis_names="seq", ) + + device_platform = np.asarray(jax.devices())[0].platform + # neuron uses Zero 3 + fsdp_axis_names = ("expert", "fsdp", "seq") if device_platform != 'neuron' else ("data", "expert", "fsdp", "seq") + cfg.dtype = jnp.float32 # Shard some FFN and attention weights over multiple axes. 
set_double_shard_weights_config( cfg.decoder.transformer.layer, batch_axis_names=batch_axis_names, - fsdp_axis_names=("expert", "fsdp", "seq"), + fsdp_axis_names=fsdp_axis_names, tp_axis_names="model", seq_axis_names=("seq",), ) From 7baa4d96271cbeb75bf8ce668acac0efcd66451d Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Mon, 1 Jul 2024 20:07:31 +0000 Subject: [PATCH 3/4] fix import --- axlearn/experiments/text/gpt/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 2841eb03e..060791604 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -11,8 +11,10 @@ """ import math +import numpy as np from typing import Dict, List, Optional, Sequence, Tuple, Union +import jax import jax.numpy as jnp import tensorflow as tf from jax.sharding import PartitionSpec From 49f9efa2aa149c6355b2474328e13292b2f6e1ed Mon Sep 17 00:00:00 2001 From: Apoorv Gupta Date: Wed, 24 Jul 2024 23:08:43 +0000 Subject: [PATCH 4/4] Address PR comments --- axlearn/common/utils.py | 4 ++-- axlearn/experiments/text/gpt/common.py | 6 +++--- axlearn/experiments/text/gpt/fuji.py | 5 ++--- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 212c68a84..afdf01463 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -1176,7 +1176,7 @@ def create_device_mesh( # Check if the devices are part of a multi-granule configuration. 
# device_platform = devices[0].platform - attr = "process_index" if device_platform != "tpu" else "slice_index" + attr = "process_index" if device_platform == "gpu" else "slice_index" is_multi_granule_env = hasattr(devices[0], attr) if not all(el.platform == device_platform for el in devices): raise NotImplementedError(f"Not all devices had platform: {device_platform}.") @@ -1193,7 +1193,7 @@ def create_device_mesh( logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.") return build_standard_mesh(mesh_shape, devices=devices) - # Neuron also only uses standard mesh + # Neuron also only uses standard mesh if device_platform == "neuron": return build_standard_mesh(mesh_shape, devices=devices) diff --git a/axlearn/experiments/text/gpt/common.py index 060791604..b5190b97c 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -11,11 +11,11 @@ """ import math -import numpy as np from typing import Dict, List, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp +import numpy as np import tensorflow as tf from jax.sharding import PartitionSpec @@ -271,8 +271,8 @@ def model_config( ) device_platform = np.asarray(jax.devices())[0].platform - # neuron uses Zero 3 - fsdp_axis_names = ("expert", "fsdp", "seq") if device_platform != 'neuron' else ("data", "expert", "fsdp", "seq") + # Trainium will have FSDP support soon, for now use Zero 3. + fsdp_axis_names = ("expert", "fsdp", "seq") if device_platform != "neuron" else ("data",) cfg.dtype = jnp.float32 # Shard some FFN and attention weights over multiple axes. 
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index dbe126fc4..03ac5ddd7 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -83,7 +83,6 @@ class Version(enum.Enum): }, } -TRN_MODEL_AXIS_SIZE=8 def get_trainer_kwargs( model_size: str, @@ -167,9 +166,9 @@ def get_trainer_kwargs( "gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)", mesh_shape_from_axes(data=-1, fsdp=8), ), - ( + ( "neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)", - mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE), + mesh_shape_from_axes(data=-1, model=8), ), ), )