Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neuron support in Axlearn #566

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Address PR comments
  • Loading branch information
apoorvtintin committed Jul 24, 2024
commit 49f9efa2aa149c6355b2474328e13292b2f6e1ed
4 changes: 2 additions & 2 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def create_device_mesh(
# Check if the devices are part of a multi-granule configuration.
# <https://github.com/google/jax/blob/b81b79c1b0d2ec/jax/experimental/mesh_utils.py#L313>
device_platform = devices[0].platform
attr = "process_index" if device_platform != "tpu" else "slice_index"
attr = "process_index" if device_platform == "gpu" else "slice_index"
is_multi_granule_env = hasattr(devices[0], attr)
if not all(el.platform == device_platform for el in devices):
raise NotImplementedError(f"Not all devices had platform: {device_platform}.")
Expand All @@ -1193,7 +1193,7 @@ def create_device_mesh(
logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
return build_standard_mesh(mesh_shape, devices=devices)

# Neuron also only uses standard mesh
# Neuron also only uses standard mesh
if device_platform == "neuron":
return build_standard_mesh(mesh_shape, devices=devices)

Expand Down
6 changes: 3 additions & 3 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
"""

import math
import numpy as np
from typing import Dict, List, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.sharding import PartitionSpec

Expand Down Expand Up @@ -271,8 +271,8 @@ def model_config(
)

device_platform = np.asarray(jax.devices())[0].platform
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax.devices() during config building may be an unexpected dependency on global state -- should we take a platform arg or similar?

Copy link
Contributor Author

@apoorvtintin apoorvtintin Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could change it, but I followed the pattern already used here

devices = jax.devices()

Please let me know if the platform flag is necessary, I can add it. Thanks!

# neuron uses Zero 3
fsdp_axis_names = ("expert", "fsdp", "seq") if device_platform != 'neuron' else ("data", "expert", "fsdp", "seq")
# Trainium will have FSDP support soon, for now use Zero 3.
fsdp_axis_names = ("expert", "fsdp", "seq") if device_platform != "neuron" else ("data")

cfg.dtype = jnp.float32
# Shard some FFN and attention weights over multiple axes.
Expand Down
5 changes: 2 additions & 3 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class Version(enum.Enum):
},
}

TRN_MODEL_AXIS_SIZE=8

def get_trainer_kwargs(
model_size: str,
Expand Down Expand Up @@ -167,9 +166,9 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
(
"neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE),
mesh_shape_from_axes(data=-1, model=8),
),
),
)
Expand Down