Support customized mesh rules to support different HWs #696
axlearn/common/utils.py
Outdated
@@ -102,6 +105,26 @@ def sharding(self) -> jax.sharding.Sharding:
NestedTensorSpec = Optional[Union[TensorSpec, dict[str, Any]]]


def offload_dots_saveble(offload_src, offload_dst):
    """Extract and combine the policy from save_and_offload_only_these_names and dots_saveable.
Is this accurate? It seems that dots_saveable includes lax_convolution.conv_general_dilated_p. Also, can you clarify how it combines save_and_offload_only_these_names?
Updated the comments. The reference is actually a bit different from dots_saveable.
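For readers following this thread, here is a minimal sketch of the distinction being discussed. It is not the code in this PR. `dot_general_only_saveable` and `layer` are hypothetical names introduced for illustration, and the sketch only marks values as saveable; it does not perform any host offloading. It uses the public `jax.checkpoint_policies.dots_saveable`, `jax.lax.dot_general_p`, and `jax.lax.conv_general_dilated_p`.

```python
import jax
import jax.numpy as jnp


def dot_general_only_saveable(prim, *_, **__) -> bool:
    """Hypothetical policy: save dot_general outputs, recompute everything else."""
    return prim is jax.lax.dot_general_p


# The built-in policy also accepts convolutions, which is the reviewer's point.
builtin = jax.checkpoint_policies.dots_saveable
conv_saved_by_builtin = bool(builtin(jax.lax.conv_general_dilated_p))
conv_saved_by_sketch = dot_general_only_saveable(jax.lax.conv_general_dilated_p)


def layer(w, x):
    # Toy layer (assumption): one matmul followed by a nonlinearity.
    return jnp.sum(jnp.tanh(x @ w))


# Apply the custom policy through jax.checkpoint and differentiate w.r.t. w.
remat_layer = jax.checkpoint(layer, policy=dot_general_only_saveable)
grads = jax.grad(remat_layer)(jnp.ones((4, 4)), jnp.ones((2, 4)))
```

A policy is just a callable that receives the primitive plus its operands and parameters, so an offloading variant would follow the same shape but return an offload marker instead of a boolean for the matched primitive.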
@@ -0,0 +1,180 @@
# Copyright © 2023 Apple Inc.
Suggested change:
- # Copyright © 2023 Apple Inc.
+ # Copyright © 2024 Apple Inc.
(here and elsewhere)
Done. Thanks for the review, I need one more approval after the nit change.
* cherrypick from internal PR change
* snapshot
* nit comment fixes
* address Mark's comments
* snapshot
* fix the wrong name
* snapshot
* nit type format
* snapshot
* snapshot
* remove white space
* nit format fix
* fix some minor annotation
Add a few system-only mesh config modifiers, including
Also added v5e fuji-7B numbers for the record.