Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable remat checkpoints to host instead of TPU memory #643

Merged
merged 6 commits into from
Aug 14, 2024

Conversation

samos123
Copy link
Contributor

@samos123 samos123 commented Aug 9, 2024

This allowed us to get MFU of fuji v2 70B from 58.50% to 61.83%

@samos123 samos123 marked this pull request as ready for review August 12, 2024 17:53
Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -3874,6 +3874,7 @@ def build_remat_spec(
],
self_attention: bool = True,
feed_forward: bool = False,
offload: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a bool, should we allow the caller to customize offload_dst directly?

Suggested change
offload: bool = False,
offload_dst: Optional[Literal["pinned_host"]] = None,

This will make the API more extensible and closer to the JAX API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I will take another stab at this PR with focus on staying closer to JAX API and extensibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved, could you review again?

@samos123 samos123 requested a review from ruomingp August 12, 2024 19:53
Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -3891,6 +3904,7 @@ def build_remat_spec(
stack_cfg: A transformer config.
self_attention: Checkpoint self attention layer activations if true.
feed_forward: Checkpoint feed-forward layer activations if true.
offload_dst: Destination of remat checkptoing offloading.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a link to the JAX documentation on offset_dst on the potential values?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this change pushed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No ah I misunderstood Ruoming's comment and thought he was fine there being no docs. Let me add the link to maxtext as a comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -188,6 +188,7 @@ def get_trainer_kwargs(
num_kv_heads=None if version == Version.V1 else 8,
rope_theta=rope_theta,
flash_attention=flash_attention,
remat_offload_dst="pinned_host",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment on the observed MFU and step time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -188,6 +188,7 @@ def get_trainer_kwargs(
num_kv_heads=None if version == Version.V1 else 8,
rope_theta=rope_theta,
flash_attention=flash_attention,
remat_offload_dst="pinned_host",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this option limited to 70B? Do we want to apply it to 7B and other models?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only use remat checkpoint offload into host when you could benefit from the extra TPU memory.

If there is plenty TPU memory then remat checkpoint into TPU memory.
If TPU memory is low or you want to squeeze as much per device batch as possible, then use offload_dst=pinned_Host.

So yes, it could make a lot of sense for 7B too on V5e and Trilium since then we can possibly increase per device batch size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have this PR focus on enabling remat offload and 70B. As a follow up I can do the following as part of V5E perf benchmarking:

  • Enable remat offload for 7B and compare performance before and after
  • See if I can increase per device batch size for 7B after enabling remat_offload

Would that work for you?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Thanks.

@samos123 samos123 requested a review from ruomingp August 12, 2024 21:26
@samos123
Copy link
Contributor Author

@markblee could you trigger CI, review and merge if good?

@markblee markblee enabled auto-merge August 14, 2024 16:49
@markblee markblee added this pull request to the merge queue Aug 14, 2024
Merged via the queue into apple:main with commit a7c64ee Aug 14, 2024
4 checks passed
qdavid1 pushed a commit to qdavid1/axlearn that referenced this pull request Dec 11, 2024
* remat checkpoints to host

* update golden configs

* Change offload to offload_dst

* add step time and MFU for v5e fuji v2 70b

* Add maxtext example in code comment

* Update axlearn/common/attention.py

Co-authored-by: Mark Lee <[email protected]>

---------

Co-authored-by: Mark Lee <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants