[Regression] Gradient explodes after upgrading to JAX 0.4.33 from 0.4.30 #17922
Comments
JAX 0.4.34 has just been released.
Here are the dumped HLOs of the compiled training step for hidden_dim=20704 (gradients do not explode) and hidden_dim=20720 (gradients explode).
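The two sizes above (20704 passes, 20720 explodes) differ by only 16. If such a boundary has to be located from scratch, for example while shrinking the model toward a smaller reproducer, a binary search keeps the number of expensive training runs logarithmic. A minimal sketch, assuming a hypothetical `fails(hidden_dim)` probe that runs one training step and reports whether the gradients blew up, a step of 16, and a failure that is monotone in `hidden_dim`. That last assumption may not hold, since rematerialization decisions depend on memory pressure.

```python
from typing import Callable


def smallest_failing_dim(
    lo: int, hi: int, fails: Callable[[int], bool], step: int = 16
) -> int:
    """Binary-search the smallest multiple of `step` in [lo, hi] for which `fails` is True.

    Assumes `fails` (a hypothetical user-supplied probe) is monotone:
    False below some threshold, True at and above it, and True at `hi`.
    """
    lo_i = -(-lo // step)  # first candidate index: ceil(lo / step)
    hi_i = hi // step      # last candidate index: floor(hi / step)
    while lo_i < hi_i:
        mid = (lo_i + hi_i) // 2
        if fails(mid * step):
            hi_i = mid      # threshold is at or below mid
        else:
            lo_i = mid + 1  # threshold is above mid
    return lo_i * step
```

With a probe that fails for every size at or above 20720, `smallest_failing_dim(4096, 32768, probe)` returns 20720 after at most 11 probes instead of scanning all candidate sizes.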
Can you try setting the environment variable XLA_FLAGS=--xla_gpu_enable_dynamic_slice_fusion=false? It seems a recent change to dynamic slice fusion was potentially buggy; the author of the patch mentioned that they had run into errors caused by it.
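`XLA_FLAGS` often already carries other options, such as `--xla_dump_to=...`, XLA's option for writing HLO dumps. A small helper that merges flags instead of overwriting the variable avoids silently dropping one. This is a sketch with a hypothetical helper name. The one firm requirement is that the variable be set before JAX initializes its XLA backend, since XLA typically reads `XLA_FLAGS` only once per process.

```python
import os
from typing import MutableMapping


def set_xla_flags(*flags: str, env: MutableMapping[str, str] = os.environ) -> str:
    """Merge `flags` into env["XLA_FLAGS"] and return the resulting string.

    Hypothetical helper. A flag given here replaces any existing setting of the
    same flag name, so e.g. `...=false` overrides an earlier `...=true`.
    Call it before `import jax` so the XLA backend sees the final value.
    """
    current = env.get("XLA_FLAGS", "").split()
    for flag in flags:
        name = flag.split("=", 1)[0]
        # Drop earlier settings of the same flag so only the new value remains.
        current = [f for f in current if f.split("=", 1)[0] != name]
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]
```

For example, calling `set_xla_flags("--xla_gpu_enable_dynamic_slice_fusion=false")` at the very top of the training script, before `import jax`, applies the suggested experiment while keeping any dump options already in place.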
We've looked into the HLOs with @jaro-sevcik and noticed that the only difference between the exploding and non-exploding variants is the additional copies introduced by the rematerialization pass. Disabling it with … @akuegel, I'll also try the flag you've mentioned; give me a couple of minutes.
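A diff like the one described here can be approximated mechanically by comparing opcode counts between the two HLO text dumps. The sketch below is a crude regex over HLO text, not a real HLO parser: it assumes one instruction per line in the `name = shape opcode(...)` form and skips tuple-shaped results. In JAX, post-optimization HLO text for a jitted function can also be obtained with `jax.jit(f).lower(*args).compile().as_text()`.

```python
import re
from collections import Counter

# Matches an HLO instruction definition such as
#   ROOT %copy.3 = f32[8,16]{0,1} copy(f32[8,16]{1,0} %p0)
# and captures its opcode ("copy"). Tuple-shaped results contain spaces
# in the shape and are skipped by this simple pattern.
_INSTR = re.compile(r"^\s*(?:ROOT\s+)?%?[\w.\-]+\s*=\s*\S+\s+([\w\-]+)\(")


def opcode_histogram(hlo_text: str) -> Counter:
    """Count how often each opcode appears in an HLO text dump."""
    counts: Counter = Counter()
    for line in hlo_text.splitlines():
        m = _INSTR.match(line)
        if m:
            counts[m.group(1)] += 1
    return counts


def opcode_diff(good_hlo: str, bad_hlo: str) -> dict:
    """Return {opcode: bad_count - good_count} for every opcode whose count differs."""
    good, bad = opcode_histogram(good_hlo), opcode_histogram(bad_hlo)
    return {
        op: bad[op] - good[op]
        for op in sorted(set(good) | set(bad))
        if bad[op] != good[op]
    }
```

Run on the two dumps, a result such as `{"copy": 2}` would point at the same kind of difference reported in this comment; any other nonzero entries would suggest the variants differ in more than copies.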
Nope, setting --xla_gpu_enable_dynamic_slice_fusion=false doesn't help; the gradients are still exploding.
Imported from GitHub PR #18152 Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass. This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested. This should fix #17922 Copybara import of the project: -- 49daad1 by Jaroslav Sevcik: Avoid fusion-wrapping copies Merging this change closes #18152 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1 PiperOrigin-RevId: 684709374
Imported from GitHub PR openxla/xla#18152 Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass. This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested. This should fix openxla/xla#17922 Copybara import of the project: -- 49daad1836186fd7abe2ad089aa8783f1125f605 by Jaroslav Sevcik: Avoid fusion-wrapping copies Merging this change closes #18152 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1836186fd7abe2ad089aa8783f1125f605 PiperOrigin-RevId: 684709374
Imported from GitHub PR #18152 Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass. This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested. This should fix #17922 Copybara import of the project: -- 49daad1 by Jaroslav Sevcik: Avoid fusion-wrapping copies Merging this change closes #18152 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1 PiperOrigin-RevId: 684729231
Imported from GitHub PR openxla/xla#18152 Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass. This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested. This should fix openxla/xla#17922 Copybara import of the project: -- 49daad1836186fd7abe2ad089aa8783f1125f605 by Jaroslav Sevcik: Avoid fusion-wrapping copies Merging this change closes #18152 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1836186fd7abe2ad089aa8783f1125f605 PiperOrigin-RevId: 684729231
@jaro-sevcik has created #18152 to fix this.
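The PR description states that wrapping copies in fusions broke the rematerialization pass's detection of copies inserted by copy-insertion, and that the fix emits copies directly instead of wrapping them. The toy model below only illustrates that failure mode: `Instr`, `copies_seen_by_opcode` and `wrap_copies_in_fusion` are hypothetical names unrelated to XLA's C++ classes, and the sketch assumes the detector recognizes a copy by its own opcode.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Instr:
    """Toy stand-in for an HLO instruction (hypothetical, not XLA's HloInstruction)."""
    name: str
    opcode: str
    body: List["Instr"] = field(default_factory=list)  # fused instructions, if any


def copies_seen_by_opcode(instrs: List[Instr]) -> List[str]:
    """Toy detector: only instructions whose own opcode is "copy" are recognized."""
    return [i.name for i in instrs if i.opcode == "copy"]


def wrap_copies_in_fusion(instrs: List[Instr]) -> List[Instr]:
    """Toy pass: replace each bare copy with a one-instruction fusion around it."""
    return [
        Instr(f"wrapped_{i.name}", "fusion", [i]) if i.opcode == "copy" else i
        for i in instrs
    ]


module = [Instr("p0", "parameter"), Instr("copy.1", "copy"), Instr("add.2", "add")]
print(copies_seen_by_opcode(module))                         # ['copy.1']
print(copies_seen_by_opcode(wrap_copies_in_fusion(module)))  # []
```

After wrapping, the copy still exists inside the fusion body, but a check that looks only at top-level opcodes no longer sees it. Leaving copies unwrapped, as the PR does, keeps that check valid without changing the check itself.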
Imported from GitHub PR openxla/xla#18152 Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass. This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested. This should fix openxla/xla#17922 Copybara import of the project: -- 49daad1836186fd7abe2ad089aa8783f1125f605 by Jaroslav Sevcik: Avoid fusion-wrapping copies Merging this change closes #18152 PiperOrigin-RevId: 686055013
Description
I'm training a Llama-3.1-like transformer architecture in a Hybrid Sharded Data Parallel + Context Parallel setup on 32 GPUs.
Upgrading to JAX 0.4.33 broke training of a 70B model: the loss becomes NaN after a single training step.
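When a run goes from healthy to NaN after a compiler upgrade, one way to pin down where it starts is to run the same seed and data on both versions and compare per-step metrics. A minimal sketch, assuming the losses (or gradient norms) from each run have already been collected into lists; the helper name is hypothetical and the 1% tolerance is an arbitrary placeholder, since some numerical drift between XLA versions is expected. Separately, JAX's `jax.config.update("jax_debug_nans", True)` option makes JAX raise an error when a computation produces a NaN and try to identify the operation responsible, though it can be slow on a model of this size.

```python
import math
from typing import Optional, Sequence


def first_divergence(
    ref: Sequence[float], new: Sequence[float], rel_tol: float = 1e-2
) -> Optional[int]:
    """Return the first step where `new` is non-finite or differs from `ref` by more than rel_tol.

    `ref` would come from the known-good version (e.g. JAX 0.4.30) and `new`
    from the suspect one (e.g. 0.4.33). Returns None if no divergence is found.
    """
    for step, (a, b) in enumerate(zip(ref, new)):
        # A NaN/inf in the new run, or a relative gap above rel_tol, marks divergence.
        if not math.isfinite(b) or not math.isclose(a, b, rel_tol=rel_tol):
            return step
    return None
```

Comparing gradient norms rather than losses usually flags the problem one step earlier, because an exploding gradient shows up before the parameter update that turns the loss into NaN.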
Evidence I've collected so far:
On JAX 0.4.30 and 0.4.29, I've trained tens of such models with different hyperparameters and datasets and have never seen a NaN.
So far, I haven't been able to reproduce this behavior on smaller models, but I'm working on it.
XLA dumps attached:
xla_dump_0_4_30.tar.gz
xla_dump_0_4_33.tar.gz
System info (python version, jaxlib version, accelerator, etc.)
JAX issue