Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unable to use residual offloading with scan and remat #17541

Closed
qGentry opened this issue Sep 24, 2024 · 2 comments
Closed

Unable to use residual offloading with scan and remat #17541

qGentry opened this issue Sep 24, 2024 · 2 comments

Comments

@qGentry
Copy link

qGentry commented Sep 24, 2024

Description

Hi guys, I'm very excited about the recent activation offloading mechanism introduced in JAX/XLA:GPU, but I'm unable to make it work with scan.
My setup is the following: I'm training a classic transformer whose transformer block is scanned over the inputs "number of layers" times. I'm also using rematerialization to reduce the model's memory footprint. I wrap the apply_block function with jax.remat using the "nothing_saveable" policy and then scan this block over the inputs to get the desired behavior: the only activations saved during the forward pass are the residual stream (embeddings) between the scanned blocks.

With the recent introduction of the "save_and_offload_only_these_names" policy, I thought it would be enough to mark the output of the scanned block with jax.ad_checkpoint.checkpoint_name(x, "output") and then specify names_which_can_be_offloaded=["output"], but it didn't work.

I've implemented a repro to show what is going on:

import flax.linen as nn
import jax
import jax.ad_checkpoint
import jax.numpy as jnp
import numpy as np
from flax.linen.linear import default_kernel_init

# Problem size for the repro.
EMB_DIM = 2048  # width of the residual stream (MLP input/output)
HID_DIM = 2048  # width of the MLP hidden layer
BS = 64  # global batch size
SEQ_LEN = 8192  # tokens per example
N_LAYERS = 32  # number of MLP blocks applied by nn.scan


# Remat policy under test. Edit the two name lists to choose which
# checkpoint_name-tagged activations are kept on device ("saved") or moved to
# pinned host memory ("offloaded"); everything else is recomputed.
CHECKPOINT_POLICY = jax.checkpoint_policies.save_and_offload_only_these_names(
    offload_src="device",
    offload_dst="pinned_host",
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],
)


# Device mesh: all visible devices laid out as 4 ("data") x 2 ("model").
# reshape(4, 2) means this requires exactly 8 devices.
mesh = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(4, 2),
    ("data", "model"),
)

# Inputs: leading (batch) axis split over "data", remaining axes replicated.
input_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data", None)
)
# Targets: one axis, split over "data".
target_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data")
)

# Logical axis name -> mesh axis, used by flax's logical partitioning helpers.
rules = (
    ("batch", "data"),
    ("embedding", None),  # embedding dimension is replicated
    ("hidden", "model"),  # MLP hidden dimension is tensor-parallel
    ("q_sequence", "model"),  # sequence parallelism for the residual stream
)


class MLP(nn.Module):
    """Residual two-layer ReLU feed-forward block (the layer that gets scanned).

    Computes ``x + W2 @ relu(W1 @ x)`` with bias-free Dense layers.
    - The hidden activation is tagged ``"hidden"`` and the block output is
      tagged ``"residual"`` via ``checkpoint_name``, so remat policies can
      refer to them.
    - The output is constrained to logical axes
      ``("batch", "q_sequence", None)`` (sequence parallelism).
    """

    @nn.compact
    def __call__(self, x):
        def sharded_init(logical_axes):
            # Default kernel initializer annotated with logical axis names.
            return nn.with_logical_partitioning(default_kernel_init, logical_axes)

        skip = x

        # Up-projection: embedding -> hidden.
        hidden = nn.Dense(
            HID_DIM,
            kernel_init=sharded_init(("embedding", "hidden")),
            use_bias=False,
        )(x)
        hidden = nn.relu(jax.ad_checkpoint.checkpoint_name(hidden, "hidden"))

        # Down-projection: hidden -> embedding, then the residual connection.
        out = nn.Dense(
            EMB_DIM,
            kernel_init=sharded_init(("hidden", "embedding")),
            use_bias=False,
        )(hidden)
        out = skip + out

        # Sequence parallelism on the residual stream.
        out = nn.with_logical_constraint(out, ("batch", "q_sequence", None))
        return jax.ad_checkpoint.checkpoint_name(out, "residual")


class Output(nn.Module):
    """Scalar prediction head: per-token linear projection, averaged over axis 1.

    In this script ``x`` is the residual stream produced by the scanned MLP
    stack, of shape ``(BS, SEQ_LEN, EMB_DIM)``. The result has shape ``(BS,)``,
    matching the targets.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            features=1,
            # The kernel's input dimension is the embedding (residual-stream)
            # width, not the MLP hidden width. Annotating it "hidden" would
            # shard it over the "model" mesh axis, while the incoming
            # activation's last dim is unsharded.
            kernel_init=nn.with_logical_partitioning(
                default_kernel_init,
                ("embedding", None),
            ),
            use_bias=False,
        )(x)[..., 0]  # drop the size-1 feature axis
        # Average the per-token predictions over the sequence axis.
        return jnp.mean(x, axis=1)


class Model(nn.Module):
    """N_LAYERS rematerialized MLP blocks applied with nn.scan, followed by Output.

    - The scan carry is the residual stream.
    - Each step applies one MLP wrapped in ``nn.remat`` using CHECKPOINT_POLICY.
    - Per-layer parameters are stacked along axis 0, with the partition name
      "layers".

    NOTE(review): per the issue discussion, tagging the block output
    "residual" inside the remat'd body did not change memory use. The
    per-step carry appears to be saved by scan's own differentiation rather
    than by this policy. See jax-ml/jax#23869 and confirm.
    """

    @nn.compact
    def __call__(self, x):
        def apply_module(layer, carry, _):
            # One scan step: the layer output is the new carry; there is no
            # per-step scan output.
            return layer(carry), None

        # prevent_cse=False: CSE prevention is presumably unnecessary inside
        # scan (see jax.checkpoint docs).
        rematted_step = nn.remat(
            apply_module, policy=CHECKPOINT_POLICY, prevent_cse=False
        )
        layer_stack = nn.scan(
            rematted_step,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=N_LAYERS,
            metadata_params={nn.PARTITION_NAME: "layers"},
        )
        hidden_states, _ = layer_stack(MLP(), x, None)
        return Output()(hidden_states)


def loss_fn(preds, target):
    """Return the mean squared error between *preds* and *target* over all elements."""
    error = preds - target
    return jnp.mean(jnp.square(error))


def calc_loss(params, inputs, target):
    """Run Model on *inputs* with *params* and return the MSE against *target*."""
    return loss_fn(Model().apply(params, inputs), target)


def train_step(params, inputs, target):
    """Do one plain SGD step (learning rate 1e-8).

    Returns ``(updated_params, loss)``, where ``loss`` is the loss evaluated at
    the incoming parameters.
    """
    grad_fn = jax.value_and_grad(calc_loss)
    loss, grads = grad_fn(params, inputs, target)
    updated = jax.tree_util.tree_map(lambda w, dw: w - 1e-8 * dw, params, grads)
    return updated, loss


def unbox_logically_partioned(tree, apply_constraint: bool = True):
    """Replace every ``nn.LogicallyPartitioned`` box in *tree* with its unboxed value.

    Boxes are treated as leaves, so they are unboxed whole. All other leaves
    are returned unchanged. *apply_constraint* is forwarded to ``unbox``.
    """

    def is_box(node):
        return isinstance(node, nn.LogicallyPartitioned)

    def unbox_leaf(leaf):
        if not is_box(leaf):
            return leaf
        return leaf.unbox(apply_constraint=apply_constraint)

    return jax.tree_util.tree_map(unbox_leaf, tree, is_leaf=is_box)


def get_gpu_memory_usage() -> dict[str, float]:
    """Return each local GPU's peak memory use as a percentage of its limit.

    Keys are ``"GPU0"``, ``"GPU1"``, ... in the order of
    ``jax.local_devices(backend="gpu")``. Each value is
    ``peak_bytes_in_use / bytes_limit * 100``, so it is a high-water mark,
    not the current usage.

    Returns an empty dict when the default backend is not GPU. A device is
    omitted, rather than raising, when any of these holds:
    - its ``memory_stats()`` is ``None``;
    - the stats lack ``peak_bytes_in_use``;
    - the stats have a missing or zero ``bytes_limit``.
    """
    if jax.default_backend() != "gpu":
        return {}
    usage: dict[str, float] = {}
    # Enumerate GPU devices directly instead of indexing the default-backend
    # device list with a GPU count, so index and device always agree.
    for idx, device in enumerate(jax.local_devices(backend="gpu")):
        stats = device.memory_stats()
        if (
            not stats
            or "peak_bytes_in_use" not in stats
            or not stats.get("bytes_limit")
        ):
            continue
        usage[f"GPU{idx}"] = stats["peak_bytes_in_use"] / stats["bytes_limit"] * 100
    return usage


# Repro driver, run under the device mesh and the logical-axis rules:
# 1. Initialize params and AOT-compile the sharded train step.
# 2. Print the saved residuals and dump the compiled HLO to compiled.txt.
# 3. Report device/host temp memory.
# 4. Run 10 steps, printing the loss and per-GPU peak memory.
with mesh, nn.logical_axis_rules(rules):
    # Placeholder arrays used for init and lowering.
    fake_inputs = jax.device_put(jnp.empty((BS, SEQ_LEN, EMB_DIM)), input_sharding)
    fake_target = jax.device_put(jnp.empty((BS,)), target_sharding)

    params = Model().init(jax.random.PRNGKey(0), fake_inputs)
    params = unbox_logically_partioned(params)

    # Keep each parameter's current sharding on the way in and out.
    param_shardings = jax.tree_util.tree_map(lambda x: x.sharding, params)
    train_step_fn = (
        jax.jit(
            train_step,
            in_shardings=(param_shardings, input_sharding, target_sharding),
            out_shardings=(
                param_shardings,
                jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()),
            ),
            donate_argnums=(0,),  # reuse the old params' buffers for the update
        )
        .lower(params, fake_inputs, fake_target)
        .compile()
    )
    jax.ad_checkpoint.print_saved_residuals(
        train_step, params, fake_inputs, fake_target
    )

    with open("compiled.txt", "w", encoding="utf-8") as f:
        f.write(train_step_fn.as_text())

    memory_analysis = train_step_fn.memory_analysis()
    print(
        f"Total size device = {memory_analysis.temp_size_in_bytes / 1024 / 1024 / 1024} GB, "  # noqa E501
        f"host = {memory_analysis.host_temp_size_in_bytes / 1024 / 1024 / 1024} GB"
    )

    for i in range(10):
        # Split a fresh key per step so inputs and targets never share a key
        # (JAX PRNG keys must not be reused).
        input_key, target_key = jax.random.split(jax.random.PRNGKey(i))

        inputs = jax.random.normal(input_key, (BS, SEQ_LEN, EMB_DIM))
        inputs = jax.device_put(inputs, input_sharding)

        target = jax.random.normal(target_key, (BS,))
        target = jax.device_put(target, target_sharding)

        params, loss = train_step_fn(params, inputs, target)
        print(loss)
        print(get_gpu_memory_usage())

First of all, I wanted to ensure that offloading is working in the first place.
With

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=[],

I'm getting following results:
Total size device = 20.26562874764204 GB, host = 0.0 GB
That is a reasonable value.

Then I wanted to check how much it would cost to save "h" on the GPU, so I set

    names_which_can_be_saved=["hidden"],
    names_which_can_be_offloaded=[],

and getting
Total size device = 35.2968789935112 GB, host = 0.0 GB
This is also totally expected: "h" is f32[32,64,8192,2048] (128 GiB in total) sharded across 8 GPUs, which equals 16 GB per GPU.

Ok, let's try to offload "h" and see what happens.

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden"],

Total size device = 19.75000447779894 GB, host = 16.0 GB - also totally expected: instead of saving 16GB on the GPU, we're offloading the activations to the host, so device memory is saved. Iterations also become a lot slower, which is expected as well.

Now that we're sure offloading works properly, I tried to offload the "residual" tensor (the output of the scanned block).

    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["residual"],

Aaaand nothing happens - Total size device = 20.26562874764204 GB, host = 0.0 GB: no change in memory usage, and iteration time is the same as with no offloading at all.

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.24.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='end-llm-computeinstance-e00yhypr7caccaxmct.priv.hw.nebius.yt', release='5.15.0-119-generic', version='#129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Sep 24 10:15:19 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8D:00.0 Off |                    0 |
| N/A   34C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:91:00.0 Off |                    0 |
| N/A   31C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:95:00.0 Off |                    0 |
| N/A   36C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:99:00.0 Off |                    0 |
| N/A   30C    P0            113W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   34C    P0            118W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:AF:00.0 Off |                    0 |
| N/A   31C    P0            116W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:B3:00.0 Off |                    0 |
| N/A   35C    P0            115W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:B7:00.0 Off |                    0 |
| N/A   30C    P0            114W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

JAX issue jax-ml/jax#23869

@qGentry
Copy link
Author

qGentry commented Sep 24, 2024

I've also tried the following implementation, inspired by
jax-ml/jax#23614 (comment),
wrapping the entire apply_module with Flax's custom_vjp, but it doesn't work properly.

class Model(nn.Module):
    """Alternative Model that tries to offload each block's output via a custom VJP.

    Each scanned MLP block is wrapped in ``nn.custom_vjp``:
    - the forward function moves the block output to ``pinned_host``;
    - the backward function moves its incoming value back to ``device``.
    The result is then wrapped in ``nn.remat`` and scanned.

    The author reports that this variant gives wrong loss/gradients and does
    not keep the offloaded activations. See the NOTE(review) comments below.

    NOTE(review): ``TransferToMemoryKind`` is not imported in the snippet
    shown. Presumably it comes from JAX's sharding implementation; confirm.
    """

    @nn.compact
    def __call__(self, x):
        # Scan body: apply one block to the carry; no per-step scan output.
        def apply_module(block, block_input, _):
            block_output = block(block_input)
            return block_output, None

        # Custom forward: run the block under nn.vjp, then move its output to
        # pinned host memory. Returns ((output, None), vjp_fn).
        # NOTE(review): by flax's custom_vjp convention, the first returned
        # element is the *primal output*. nn.scan threads it into the next
        # layer as the carry. So this offloads the carry itself, while the
        # residuals the backward pass needs stay captured in vjp_fn
        # (un-offloaded). This may explain the memory behavior and the wrong
        # results reported. Verify.
        # NOTE(review): `_` is passed to nn.vjp and then rebound by
        # `emb, _ = res`; harmless, but easy to misread.
        def apply_module_fwd(block, block_input, _):
            res, vjp_fn = nn.vjp(apply_module, block, block_input, _)
            emb, _ = res
            emb = jax.device_put(emb, TransferToMemoryKind("pinned_host"))
            return (emb, None), vjp_fn

        # Custom backward: move the incoming value to device, then apply the
        # saved vjp_fn.
        # NOTE(review): in flax's custom_vjp the second bwd argument is the
        # cotangent of the output (verify), not an offloaded residual. This
        # device_put therefore does not bring any saved activation back from
        # host.
        def apply_module_bwd(vjp_fn, res):
            emb, _ = res
            emb = jax.device_put(emb, TransferToMemoryKind("device"))
            res = (emb, None)
            return vjp_fn(res)

        apply_module_vjp = nn.custom_vjp(
            apply_module,
            forward_fn=apply_module_fwd,
            backward_fn=apply_module_bwd
        )

        # Same remat wrapper and policy as the first Model variant, applied on
        # top of the custom-VJP function.
        apply_module_vjp = nn.remat(
            apply_module_vjp,
            policy=CHECKPOINT_POLICY,
            prevent_cse=False,
        )

        # Scan the (remat'd, custom-VJP) block N_LAYERS times, with per-layer
        # params stacked along axis 0.
        x, _ = nn.scan(
            apply_module_vjp,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=N_LAYERS,
            metadata_params={nn.PARTITION_NAME: "layers"},
        )(MLP(), x, None)

        # The custom forward leaves each block output on pinned_host; move the
        # final carry back to device memory before the output head.
        x = jax.device_put(x, TransferToMemoryKind("device"))

        preds = Output()(x)
        return preds

Total size device = 21.00781624764204 GB, host = 0.5 GB.

Here's part of the trace
Screenshot 2024-09-24 at 14 53 29

As you can see, activations are indeed being offloaded during the forward pass, but during the backward pass they are not loaded back to the devices - it looks like these offloaded activations are immediately dropped on the CPU and the GPU activations are saved instead. That's why host memory is only 0.5GB - it is only reserved for the activations of one layer.

I've also noticed that this approach produces wrong loss and gradient values, but if I comment out all of the "jax.device_put" transfers, everything works as expected again.

@qGentry
Copy link
Author

qGentry commented Sep 25, 2024

Solved here
jax-ml/jax#23869 (comment)

@qGentry qGentry closed this as completed Sep 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant