Neuron support in Axlearn #566
Conversation
Thanks.
axlearn/experiments/text/gpt/fuji.py
@@ -167,6 +167,10 @@ def get_trainer_kwargs(
            "gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)",
            mesh_shape_from_axes(data=-1, fsdp=8),
        ),
        (
            "neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)",
            mesh_shape_from_axes(data=-1, model=TRN_MODEL_AXIS_SIZE),
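For orientation, here is a minimal sketch (not axlearn code) of what a rule like `data=-1, model=TRN_MODEL_AXIS_SIZE` resolves to, assuming `TRN_MODEL_AXIS_SIZE` is 8 as discussed in this thread and a 64-device slice (two trn1.32xlarge nodes at 32 NeuronCores each):

```python
# Illustrative only: how a -1 ("fill the rest") axis resolves against fixed axes.
# Assumes TRN_MODEL_AXIS_SIZE == 8 and 64 total devices.
import math

def resolve_mesh_shape(total_devices: int, **axes: int) -> dict:
    """Replace a -1 axis with whatever device count the fixed axes leave over."""
    fixed = math.prod(v for v in axes.values() if v != -1)
    return {name: (total_devices // fixed if v == -1 else v) for name, v in axes.items()}

print(resolve_mesh_shape(64, data=-1, model=8))  # {'data': 8, 'model': 8}
print(resolve_mesh_shape(64, data=-1, fsdp=8))   # {'data': 8, 'fsdp': 8}
```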
How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.
Might also be worth listing the step times for different configurations, similar to the other mesh rules.
> How does model=8 compare to fsdp=8? Usually we find fsdp to be more efficient.
I am launching an fsdp=8 job with 8 nodes. The job is blocked due to AWS capacity; I hope to have some data to share by Friday.
The previous response from AWS was that FSDP is slower due to higher communication overhead.
Tensor parallelism (the model axis) is more performant on the trn1 architecture.
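As a rough illustration of the comparison being discussed (a sketch only, assuming a hypothetical 8-device host; this is not axlearn's implementation):

```python
# Illustrative only: the two weight placements being compared.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devs = np.asarray(jax.devices()).reshape(1, 8)  # assumes 8 devices for the sketch

# Tensor parallel ("model"): the weight's output dim is split across the model
# axis, and activations are all-reduced within the model group every layer.
tp_mesh = Mesh(devs, ("data", "model"))
w_tp = NamedSharding(tp_mesh, P(None, "model"))

# FSDP ("fsdp"): the weight is sharded for storage, all-gathered just in time
# for each layer, and gradients are reduce-scattered back.
fsdp_mesh = Mesh(devs, ("data", "fsdp"))
w_fsdp = NamedSharding(fsdp_mesh, P("fsdp", None))

# Which placement wins depends on the interconnect; the claim above is that on
# trn1 the per-layer all-gather traffic of FSDP costs more than the model-axis
# collectives.
```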
@@ -267,12 +269,17 @@ def model_config(
        batch_axis_names=batch_axis_names,
        seq_axis_names="seq",
    )

    device_platform = np.asarray(jax.devices())[0].platform
jax.devices() during config building may be an unexpected dependency on global state -- should we take a platform arg or similar?
We could change it, but I followed the pattern already used here:

axlearn/axlearn/common/utils.py, line 1231 in 89c6f75:
    devices = jax.devices()

Please let me know if the platform flag is necessary; I can add it. Thanks!
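If the explicit argument turned out to be preferable, one possible shape for it (a hypothetical sketch with assumed names, not the PR's actual code):

```python
# Hypothetical sketch of the suggested alternative: accept the platform as an
# explicit argument and only fall back to jax.devices() when it is omitted.
from typing import Optional
import jax

def build_model_config(*, platform: Optional[str] = None) -> dict:
    if platform is None:
        platform = jax.devices()[0].platform  # global-state fallback
    cfg = {"platform": platform}
    # Any Neuron-specific config branches would key off the explicit value here.
    return cfg

print(build_model_config(platform="neuron"))  # explicit, no device query needed
print(build_model_config())                   # infers from the current backend
```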
@apoorvtintin I see this PR has been stale for some time.
Apoorv is on PTO right now. I am OK with you all taking over this PR. Can you add us as reviewers when you finish? Thanks!
Thanks for all the reviews; I addressed most of the comments on the PR.
Is this PR still needed?
Not needed, closing this.
This PR enables the use of Neuron devices in Axlearn for model training.
Example mesh selector: --mesh_selector=neuron-trn1.32xlarge-64
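For reference, a quick sanity check (illustrative; it assumes the selector string is matched against the rule regexes shown in the diff above) that this selector hits the new Neuron rule:

```python
# Illustrative: the mesh selector string matches the new Neuron rule's regex.
import re

rule = r"neuron-(trn1.32xlarge|trn1n.32xlarge)-(32|64|256|512|1024|2048)"
m = re.fullmatch(rule, "neuron-trn1.32xlarge-64")
print(m is not None)  # True
print(m.groups())     # ('trn1.32xlarge', '64')
```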