Unable to use residual offloading with scan and remat #17541
Comments
I've also tried the following implementation, inspired by
As you can see, activations are indeed being offloaded during the forward pass, but during the backward pass they are not loaded back to the devices - it looks like these offloaded activations are immediately dropped on the CPU and the GPU activations are saved instead. That's why host memory is only 0.5GB - it is only reserved for the activations of one layer. I've also noticed that this approach produces wrong loss & grad calculations, but if I comment out all of the "jax.device_put" transfers, everything works as expected again.
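Roughly, the kind of manual transfer I mean looks like this (the block body, names, and sharding setup are placeholders, not the exact code I tried):

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
# Pinned host memory on the same device, used as the offload target.
host_sharding = jax.sharding.SingleDeviceSharding(dev, memory_kind="pinned_host")

def apply_block(params, x):
    return jnp.tanh(x @ params["w"]) + x  # placeholder transformer block

def body(h, layer_params):
    h = apply_block(layer_params, h)
    # Copy the carried activation to pinned host memory and emit it as a
    # per-step output, hoping the backward pass reads these host copies back
    # instead of re-saving the device values.
    h_host = jax.device_put(h, host_sharding)
    return h, h_host
```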
Solved here
Description
Hi guys, I'm very excited about the recent activation offloading mechanism introduced in JAX/XLA:GPU, but I'm unable to make it work with scan.
My setup is the following: I'm training a classic transformer, with the transformer block scanned over the inputs "number of layers" times. I'm also using rematerialization to reduce the memory footprint of the model. I basically wrap the apply_block function with jax.remat using the "nothing_saveable" policy and then scan this block over the inputs to achieve the desired behavior - the only activations saved during the forward pass are the residual stream (embeddings) between scanned blocks.
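Roughly, the setup looks like this (the body of apply_block and the names here are just placeholders, not the real model):

```python
import jax
import jax.numpy as jnp

def apply_block(params, x):
    # Placeholder for one transformer block; the real block isn't shown here.
    return jnp.tanh(x @ params["w"]) + x

# Rematerialize everything inside the block...
block = jax.remat(apply_block, policy=jax.checkpoint_policies.nothing_saveable)

def forward(stacked_params, x):
    # ...then scan it "number of layers" times; the only activations kept for
    # the backward pass are the residual-stream carries between iterations.
    def body(h, layer_params):
        return block(layer_params, h), None
    h, _ = jax.lax.scan(body, x, stacked_params)
    return h
```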
With the recent introduction of the "save_and_offload_only_these_names" policy, I thought that it would be enough to mark the output of the scanned block with jax.ad_checkpoint.checkpoint_name(x, "output") and then specify names_which_can_be_offloaded=["output"], but it didn't work. I've implemented a repro to showcase what is going on:
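The relevant part of that attempt looks roughly like this (the full repro isn't included here; the block body is a placeholder, and only the checkpoint_name tag and the policy arguments match what I describe above):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Offload only the values tagged "output"; save nothing else on device.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["output"],
    offload_src="device",
    offload_dst="pinned_host",
)

def apply_block(params, x):
    h = jnp.tanh(x @ params["w"]) + x      # placeholder block body
    return checkpoint_name(h, "output")    # tag the residual stream

block = jax.remat(apply_block, policy=policy)
```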
First of all, I wanted to ensure that offloading is working in the first place.
With
I'm getting the following results:
Total size device = 20.26562874764204 GB, host = 0.0 GB
Quite a reasonable value.
Then, I wanted to check how much it would cost to save "h" on GPU, so I set
and I'm getting
Total size device = 35.2968789935112 GB, host = 0.0 GB
This is also totally expected, as "h" is f32[32,64,8192,2048] sharded across 8 GPUs, which equals 16GB per GPU.
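For reference, the arithmetic behind that 16GB figure:

```python
# Quick check of the per-GPU size of "h" (plain arithmetic, no JAX needed).
num_elements = 32 * 64 * 8192 * 2048      # f32[32, 64, 8192, 2048]
total_bytes = num_elements * 4            # 4 bytes per float32
per_gpu_gib = total_bytes / 8 / 2**30     # sharded across 8 GPUs
print(per_gpu_gib)                        # -> 16.0
```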
Ok, let's try to offload "h" and see what happens.
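With the policy switched to something along these lines (an assumed configuration for this step; the exact snippet isn't reproduced here):

```python
import jax

# Assumed configuration: allow "h" to be offloaded to host memory,
# save nothing else on device.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["h"],
    offload_src="device",
    offload_dst="pinned_host",
)
```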
Total size device = 19.75000447779894 GB, host = 16.0 GB
- also totally expected: instead of saving 16GB on the GPU, we're offloading activations to the host, so device memory is saved. Iterations also become a lot slower, which is likewise expected. Now that we're sure offloading is indeed working properly, I've tried to offload the "residual" tensor (the output of the scanned block).
Aaaand nothing happens -
Total size device = 20.26562874764204 GB, host = 0.0 GB
, no changes in memory usage, and iteration time is the same as with no offloading at all.

System info (python version, jaxlib version, accelerator, etc.)
JAX issue jax-ml/jax#23869