Does XLA auto-sharding implement inter-op parallelism? #19103
Auto-sharding does implement SPMD pipeline parallelism if it decides to shard that way.
Thanks for your reply, @ptoulme-aws! Does JAX support XLA SPMD auto-sharding? I have looked at the …
Once again, I appreciate your response, @ptoulme-aws. Technically this issue is now solved, and my follow-up question is beyond its scope, but I would greatly appreciate it if you could respond. Given an arbitrary function …
But I get errors that `x` was not sharded properly. Is there a way in JAX to run auto-SPMD and automatically shard the inputs as well?
Yes, everything should be auto-partitioned, even the parameters. At this point you should probably take your question to the JAX GitHub repository.
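A minimal sketch of the general pattern for avoiding "not sharded properly" errors, assuming a recent JAX that has `jax.jit(..., in_shardings=...)` and the ahead-of-time `lower()`/`compile()` API: compile first, read the input shardings the executable expects from `compiled.input_shardings`, then place the arguments with `jax.device_put` before calling it. To keep it runnable on a single CPU, the shardings below are fixed by hand; with JAX's experimental auto-sharding they would instead be chosen by the compiler (how to request that, and which backends support it, depends on the JAX version, so treat that part as an assumption to check against the JAX repository), and the same placement step should apply. The function `layer` and the arrays `x` and `w` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D device mesh over whatever devices are available (a single CPU works).
n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def layer(x, w):
    # Hypothetical toy computation: a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)

# Hand-picked shardings standing in for what an auto-sharding pass might choose.
x_sharding = NamedSharding(mesh, P("data", None))  # rows split across devices
w_sharding = NamedSharding(mesh, P())               # weights replicated

jitted = jax.jit(layer, in_shardings=(x_sharding, w_sharding))

# Lower and compile ahead of time using only shapes and dtypes.
x_spec = jax.ShapeDtypeStruct((8 * n_dev, 4), jnp.float32)
w_spec = jax.ShapeDtypeStruct((4, 2), jnp.float32)
compiled = jitted.lower(x_spec, w_spec).compile()

# input_shardings mirrors the (args, kwargs) structure of the call.
(x_in, w_in), _ = compiled.input_shardings

# Place the real inputs exactly as the compiled executable expects.
x = jax.device_put(jnp.ones((8 * n_dev, 4), jnp.float32), x_in)
w = jax.device_put(jnp.full((4, 2), 0.5, jnp.float32), w_in)
out = compiled(x, w)
```

Reading the shardings back from the compiled executable, rather than guessing them, is what keeps the caller consistent with whatever layout the compiler settled on.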
I found that XLA auto-sharding is based on the Alpa paper (https://arxiv.org/abs/2201.12023), which proposed an algorithm for both inter-op and intra-op parallelism. However, it appears to implement only intra-op parallelism (I may be wrong), not pipeline/inter-op parallelism. Is this true? Does XLA's experimental auto-sharding implement only intra-op parallelism? Furthermore, is this XLA auto-sharding feature available in JAX?