[Regression] Gradient explodes after upgrading to JAX 0.4.33 from 0.4.30 #17922

qGentry opened this issue Oct 4, 2024 · 11 comments · Fixed by tensorflow/tensorflow#77665 · May be fixed by #18192 or tensorflow/tensorflow#77654
Closed

Comments


qGentry commented Oct 4, 2024

Description

I'm training a LLaMA-3.1-like transformer architecture in a hybrid sharded data parallel + context parallel setup on 32 GPUs.
Upgrading to JAX 0.4.33 has broken training of the 70B model - the loss becomes NaN after a single training step.
Evidence I've collected so far:

  • The loss on the first step is exactly the same on 0.4.33 and 0.4.30 (screenshot attached).
  • The gradients of the unembedding layer and the last layer norm are also exactly the same on the first step (screenshot attached).
  • The gradient already explodes for the last transformer layer's MLP hidden->output weight matrix, which I believe is the first weight matrix after token_unembedding and the last layer norm in the backward pass (screenshot attached; a per-parameter NaN-check sketch follows after this list).
  • On JAX 0.4.30 and 0.4.29 I've trained tens of such models with different hyperparameters and datasets and have never seen a NaN.

  • So far, I haven't been able to reproduce this behavior on smaller models, but I'm working on it.

  • XLA dumps attached
    xla_dump_0_4_30.tar.gz
    xla_dump_0_4_33.tar.gz
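
For reference, here is a minimal sketch of a per-parameter gradient check that can be used to localize NaNs and exploding norms. This is not the actual training code; loss_fn, params, and batch are placeholders.

import jax
import jax.numpy as jnp

def grad_report(grads):
    # grads: any pytree of gradient arrays, e.g. the output of jax.grad(loss_fn)(params, batch)
    leaves, _ = jax.tree_util.tree_flatten_with_path(grads)
    for path, g in leaves:
        g32 = g.astype(jnp.float32)
        has_nan = bool(jnp.isnan(g32).any())         # any NaN in this parameter's gradient?
        norm = float(jnp.sqrt(jnp.sum(g32 ** 2)))    # global L2 norm of the gradient leaf
        print(f"{jax.tree_util.keystr(path)}: nan={has_nan} norm={norm:.3e}")

# Usage (placeholders):
# grads = jax.grad(loss_fn)(params, batch)
# grad_report(grads)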

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.24.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='computeinstance-e00xy41pgq1s49hjc5', release='5.15.0-118-generic', version='#128-Ubuntu SMP Fri Jul 5 09:28:59 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Oct  4 10:07:59 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8D:00.0 Off |                    0 |
| N/A   28C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:91:00.0 Off |                    0 |
| N/A   27C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:95:00.0 Off |                    0 |
| N/A   30C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:99:00.0 Off |                    0 |
| N/A   27C    P0            112W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   28C    P0            109W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:AF:00.0 Off |                    0 |
| N/A   26C    P0            109W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:B3:00.0 Off |                    0 |
| N/A   29C    P0            112W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:B7:00.0 Off |                    0 |
| N/A   27C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

Labels: JAX issue


qGentry commented Oct 4, 2024

One important update - I've tracked the problem down to the use of scan:
when I use scan on JAX 0.4.33, gradients explode; when I don't use scan (and compilation time increases by roughly a factor of 50), gradients are fine. Using scan on JAX 0.4.30 causes no problems - gradients are also fine. (Screenshot attached; a toy sketch of the two variants follows below.)
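
To make the comparison concrete, here is a toy sketch of the two variants - not the actual model code; block is a stand-in residual layer. It contrasts scanning over stacked layer weights with jax.lax.scan against unrolling a Python loop over per-layer weights; the gradients should agree.

import jax
import jax.numpy as jnp

def block(h, w):
    # One toy "layer": a residual matmul with a nonlinearity.
    return h + jnp.tanh(h @ w)

def forward_scan(stacked_w, h):
    # stacked_w has shape (num_layers, dim, dim); scan carries the activations through the layers.
    def step(carry, w):
        return block(carry, w), None
    out, _ = jax.lax.scan(step, h, stacked_w)
    return jnp.mean(out ** 2)

def forward_unrolled(layer_ws, h):
    # layer_ws is a Python list of (dim, dim) weights; each layer is traced separately,
    # which is what makes compilation so much slower without scan.
    for w in layer_ws:
        h = block(h, w)
    return jnp.mean(h ** 2)

key = jax.random.PRNGKey(0)
num_layers, dim = 4, 64
stacked_w = 0.02 * jax.random.normal(key, (num_layers, dim, dim))
h0 = jax.random.normal(key, (8, dim))

g_scan = jax.jit(jax.grad(forward_scan))(stacked_w, h0)
g_unrolled = jax.jit(jax.grad(forward_unrolled))(list(stacked_w), h0)
print(jnp.abs(g_scan - jnp.stack(g_unrolled)).max())  # should be ~0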


qGentry commented Oct 4, 2024

Meanwhile, a single-node, 8-GPU 8B model with a very similar structure (hybrid sharded data parallelism plus data and context parallelism) does not reproduce the problem - even with scan, gradients are almost identical between 0.4.30 and 0.4.33. (Screenshot attached.)


qGentry commented Oct 4, 2024

Another observation - I ran an 8B model with exactly the same configuration, sharding, dataset, etc. as the 70B, from scratch (freshly initialized, without restoring from a checkpoint).
The 8B on 0.4.33 matches the 8B on 0.4.30 perfectly, while the 70B on 0.4.33 explodes and the 70B on 0.4.30 works properly. (Screenshot attached.)

Here is the entire configuration diff between the two (screenshot attached).


qGentry commented Oct 4, 2024

I've also tried to iteratively transform the 8B into the 70B to see at which point it starts to explode. Here are my results:
8B -> ok
8B + 80 layers -> ok
8B + 80 layers + 64 attention heads -> ok
8B + 80 layers + 64 attention heads + 8192 dim -> ok
8B + 80 layers + 64 attention heads + 8192 dim + hidden_dim 28672 -> explosion (at this point the model is basically equivalent to the 70B)

(Screenshot attached.)


qGentry commented Oct 4, 2024

JAX 0.4.34 has just been released.
I've tested it - unfortunately, the results are exactly the same as with JAX 0.4.33: gradient explosion.


qGentry commented Oct 4, 2024

Given that the 70B with hidden_dim ~28k explodes while the one with ~14k doesn't, I've bisected the exact value of hidden_dim (rounded to the nearest 16) at which the explosions start.
Here are my results - the dichotomy is very clear:

Gradients do not explode if the MLP's hidden_dim <= 20704.
Gradients explode if the MLP's hidden_dim >= 20720.

It looks like once the weights reach a certain size, XLA/CUDA switches to an alternative algorithm, reorders something, etc., which leads to incorrect computations.

(Screenshots attached.)


qGentry commented Oct 5, 2024

Here are the dumped HLOs of the compiled training step for hidden_dim=20704 (gradients not exploding) and hidden_dim=20720 (gradients exploding):
compiled_train_fn_20704.txt
compiled_train_fn_20720.txt
And also for hidden_dim=20688 (not exploding):
compiled_train_fn_20688.txt
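
For anyone who wants to generate similar dumps: XLA's dump flags can be passed via XLA_FLAGS before JAX initializes its backend. A minimal sketch - the dump directory and train_step/params/batch are placeholders:

import os
# Must run before the first `import jax` so that XLA picks up the flags.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_dump_to=/tmp/xla_dump"   # placeholder output directory
    + " --xla_dump_hlo_as_text"
)
import jax

# In recent JAX versions the compiled HLO of a single jitted function can also be inspected directly:
# print(jax.jit(train_step).lower(params, batch).compile().as_text())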


akuegel commented Oct 7, 2024

Can you try setting the environment variable XLA_FLAGS=--xla_gpu_enable_dynamic_slice_fusion=false? It seems a recent change to dynamic slice fusion may have been buggy; the author of the patch mentioned that they ran into errors caused by it.


qGentry commented Oct 7, 2024

We've looked into the HLOs with @jaro-sevcik and noticed that the only diff between the exploding and non-exploding variants is additional copies introduced by the rematerialization pass. Disabling that pass with --xla_disable_hlo_passes=rematerialization seems to solve the issue; a sketch of the workaround is below.

@akuegel I'll also try the flag you mentioned; give me a couple of minutes.
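
Concretely, a sketch of applying this workaround - the flag just needs to be visible to XLA before the backend initializes, so append it to XLA_FLAGS in the shell or in Python before importing jax:

import os
# Append to any existing XLA_FLAGS; must run before the first `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_disable_hlo_passes=rematerialization"
)
import jax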


qGentry commented Oct 7, 2024

Nope, setting --xla_gpu_enable_dynamic_slice_fusion=false doesn't help; gradients are still exploding.

copybara-service bot pushed a commit that referenced this issue Oct 11, 2024
Imported from GitHub PR #18152

Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass.

This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested.

This should fix #17922
Copybara import of the project:

--
49daad1 by Jaroslav Sevcik <[email protected]>:

Avoid fusion-wrapping copies

Merging this change closes #18152

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1
PiperOrigin-RevId: 684709374

akuegel commented Oct 11, 2024

@jaro-sevcik has created #18152 to fix this.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 15, 2024
Imported from GitHub PR openxla/xla#18152

Fusion wrapping copies breaks the logic for detecting copies from copy-insertion in rematerialization pass.

This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested.

This should fix openxla/xla#17922
Copybara import of the project:

--
49daad1836186fd7abe2ad089aa8783f1125f605 by Jaroslav Sevcik <[email protected]>:

Avoid fusion-wrapping copies

Merging this change closes #18152

PiperOrigin-RevId: 686055013