
Neuron support in Axlearn #566

Closed
apoorvtintin wants to merge 4 commits

Conversation

apoorvtintin (Contributor):

This PR enables the use of Neuron devices in Axlearn for model training.

  • Chooses the correct mesh for TRN devices for Fuji 7B via the mesh selector flag --mesh_selector=neuron-trn1.32xlarge-64
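For reference, a minimal sketch (not the actual AXLearn code; the rule table and helper names are illustrative, and model=8 assumes TRN_MODEL_AXIS_SIZE is 8) of how a mesh-rule table like the one added in axlearn/common/utils.py can be matched against the --mesh_selector value:

import re

# Illustrative (regex, mesh-spec) pairs mirroring the hunk reviewed below;
# mesh specs are shown as plain dicts rather than AXLearn mesh shapes.
_MESH_RULES = [
    ("gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)", dict(data=-1, fsdp=8)),
    ("neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)", dict(data=-1, model=8)),
]

def select_mesh(mesh_selector: str) -> dict:
    # Return the mesh spec of the first rule whose regex matches the selector.
    for pattern, mesh in _MESH_RULES:
        if re.fullmatch(pattern, mesh_selector):
            return mesh
    raise ValueError(f"no mesh rule matches {mesh_selector!r}")

print(select_mesh("neuron-trn1.32xlarge-64"))  # {'data': -1, 'model': 8}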

ruomingp (Contributor) left a comment:

Thanks.

axlearn/common/utils.py (outdated; resolved)
@@ -167,6 +167,10 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
mesh_shape_from_axes(data=-1, fsdp=8),
),
(
"neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE),
Contributor:

How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

Contributor:

Might also be worth listing the step times for different configurations, similar to the other mesh rules.

> How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.

I am launching an fsdp=8 job with 8 nodes. The job is blocked due to AWS capacity; I hope to have some data to share by Friday.

The previous response from AWS was that FSDP is slower due to higher communication overhead.

Contributor:

Tensor parallelism (the model axis) is more performant on the trn1 architecture.
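For context, a toy sketch (illustrative code only, assuming a 64-device job as in the neuron-trn1.32xlarge-64 selector and TRN_MODEL_AXIS_SIZE=8) of what the two layouts under discussion resolve to once the data=-1 axis is filled in:

import math

def resolve_mesh(num_devices: int, **axes: int) -> dict:
    # Replace the single -1 axis with whatever remains after the fixed axes.
    fixed = math.prod(v for v in axes.values() if v != -1)
    assert num_devices % fixed == 0, "fixed axes must divide the device count"
    return {k: (num_devices // fixed if v == -1 else v) for k, v in axes.items()}

print(resolve_mesh(64, data=-1, model=8))  # tensor parallel: {'data': 8, 'model': 8}
print(resolve_mesh(64, data=-1, fsdp=8))   # FSDP:            {'data': 8, 'fsdp': 8}

Both layouts end up with 8-way data parallelism; the difference is in how the other 8-way axis is used. The model axis splits individual layers across devices (tensor parallelism), while the fsdp axis shards parameters and regathers them per layer, which is the communication overhead referred to above.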

axlearn/experiments/text/gpt/fuji.py (outdated; resolved)
axlearn/experiments/text/gpt/common.py (outdated; resolved)
@@ -267,12 +269,17 @@ def model_config(
batch_axis_names=batch_axis_names,
seq_axis_names="seq",
)

device_platform = np.asarray(jax.devices())[0].platform
Contributor:

jax.devices() during config building may be an unexpected dependency on global state -- should we take a platform arg or similar?

apoorvtintin (Contributor, Author) commented on Jul 24, 2024:

We could change it, but I followed the pattern already used here:

devices = jax.devices()

Please let me know if the platform flag is necessary; I can add it. Thanks!
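For illustration, a minimal sketch of the suggested platform argument (helper name and signature are hypothetical, not AXLearn APIs), keeping the current jax.devices()-based behavior as a fallback:

from typing import Optional

import jax

def resolve_platform(device_platform: Optional[str] = None) -> str:
    # Prefer an explicitly passed platform so config building stays free of
    # global state; otherwise fall back to querying the first visible device,
    # mirroring device_platform = np.asarray(jax.devices())[0].platform above.
    if device_platform is not None:
        return device_platform
    return jax.devices()[0].platform

print(resolve_platform("neuron"))  # "neuron": deterministic, e.g. in config tests
print(resolve_platform())          # whatever backend JAX initializes, e.g. "cpu"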

kelvin-zou (Contributor):

@apoorvtintin I see this PR has been stale for some time. If there is no objection, I'd like to have @Ruixuan, who is working on Trn from our end, port your change and continue iterating on it.

patrick-toulme (Contributor):

> @apoorvtintin I see this PR has been stale for some time. If there is no objection, I'd like to have @Ruixuan, who is working on Trn from our end, port your change and continue iterating on it.

Apoorv is on PTO right now. I am OK with you all taking over this PR. Can you add us as reviewers when you finish? Thanks!

apoorvtintin (Contributor, Author):

Thanks for all the reviews; I have addressed most of the comments on the PR.

kelvin-zou (Contributor):

Is this PR still needed?

apoorvtintin (Contributor, Author):

Not needed; closing this.
