
Does XLA auto-sharding implement inter-op parallelism? #19103

Closed
man2machine opened this issue Nov 6, 2024 · 5 comments

Comments

@man2machine commented Nov 6, 2024

I found that XLA auto-sharding is based on the Alpa paper (https://arxiv.org/abs/2201.12023), which proposes an algorithm for both inter-op and intra-op parallelism. However, it appears to implement only intra-op parallelism (I may be wrong), and not pipeline/inter-op parallelism. Is this true? Does XLA's experimental auto-sharding only implement intra-op parallelism? Furthermore, is this XLA auto-sharding feature available in JAX?

@patrick-toulme (Contributor)

Auto-sharding does implement SPMD pipeline parallelism if it decides to shard that way.

@man2machine (Author)

Thanks @ptoulme-aws for your reply! Does JAX support XLA SPMD auto-sharding? I have looked at the shard_map function as well as pjit with AUTO arguments, but it is unclear whether any of these implement the SPMD pipeline + shard parallelism.
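
For reference, here is roughly the kind of usage I am asking about. This is only a sketch: the AUTO marker has moved between jax.experimental.pjit and jax.sharding across JAX releases, and whether it is passed bare or as AUTO(mesh) depends on the version, so the import and call below may need adjusting.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from jax.experimental.pjit import pjit, AUTO  # assumed import path; varies by JAX version

def f(x):
    # Toy computation standing in for an arbitrary model function
    return jnp.tanh(x @ x.T)

# Arrange all local devices into a 1 x N mesh
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, axis_names=("replica", "model"))

x = jnp.ones((4096, 4096))

with mesh:
    # Ask XLA's auto-SPMD pass to choose the input/output shardings
    f_auto = pjit(f, in_shardings=AUTO(mesh), out_shardings=AUTO(mesh))
    y = f_auto(x)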

@man2machine (Author) commented Nov 13, 2024

Once again, I appreciate your response @ptoulme-aws. Technically this issue is now solved, and my next question is beyond its scope, but I would greatly appreciate it if you could respond.

Given an arbitrary function f(x), I tried the following code:

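# Context assumed from the snippet (not shown above): xb refers to jax's
# xla_bridge module, device_mesh is a 2-D numpy array of jax.Device objects,
# and x is a concrete input array for f.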
backend = xb.get_backend()
options = xb.get_compile_options(  # type: ignore
    num_replicas=device_mesh.shape[0],
    num_partitions=device_mesh.shape[1],
    device_assignment=device_mesh,
    use_spmd_partitioning=True,
    use_auto_spmd_partitioning=True,
    auto_spmd_partitioning_mesh_shape=list(device_mesh.shape),
    auto_spmd_partitioning_mesh_ids=[d.id for d in device_mesh.flatten()]
)
input_dtype_struct = jax.ShapeDtypeStruct(x.shape, x.dtype)  # type: ignore
f_compiled = backend.compile(
    str(jax.jit(f).lower(input_dtype_struct).compiler_ir()),  # type: ignore
    compile_options=options
)
out = f_compiled.execute([x])

But I get an error that x was not sharded properly: jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Expected args to execute_sharded_on_local_devices to have 6 shards, got: [1]
(I have 6 available GPUs on my machine)

Is there a way in JAX to do auto-SPMD partitioning and automatically shard the inputs as well?
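
For reference, a sketch of how the input could be pre-sharded by hand with jax.device_put over a 1 x 6 mesh (the mesh axis names and array shapes here are illustrative, not from the snippet above), so that the executable sees 6 shards instead of 1; ideally the auto-SPMD pass would pick this placement itself:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1 x 6 mesh over the 6 local GPUs
devices = np.array(jax.devices()).reshape(1, 6)
mesh = Mesh(devices, axis_names=("replica", "model"))

x = jnp.ones((4096, 4096))
# Split x's second dimension across the "model" axis, one shard per device
x_sharded = jax.device_put(x, NamedSharding(mesh, P(None, "model")))

# x_sharded now consists of 6 single-device shards, matching what the
# compiled executable expects for its arguments
print(len(x_sharded.addressable_shards))  # 6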

@patrick-toulme (Contributor)

Yes, everything should be auto-partitioned, even the parameters. At this point you should probably take your question to the JAX GitHub.
