Original JAX issue: jax-ml/jax#21367. I'm using Python 3.12.1 and jaxlib 0.4.28, running on a TPU v4-8 VM.
In my use case I need to do batches of `dynamic_update_slice` operations, and have been using JAX's `vmap` for that, but was getting extremely slow runtimes (between 10 and 50x slower than expected). I profiled the code and found that the vmapped `dynamic_update_slice`, which I expected to be doing a `scatter`, was actually doing a while-loop of `dynamic_update_slice` ops, looping over the batch axis. I think this while-loop may be preventing parallelization and causing the very slow runtimes.

This JAX code demonstrates the slow-down by comparing the vmapped `dynamic_update_slice`, which lowers to a single `scatter`, with an equivalent unrolled Python loop of `dynamic_update_slice`s:
```python
from timeit import timeit

from jax import jit, lax, vmap, make_jaxpr
import jax.numpy as jnp

# For f which outputs a single array, this simulates vmap using Python map
pymap = lambda f: lambda *args: jnp.stack(list(map(f, *args)))

operands = jnp.ones((100, 32))
updates = jnp.ones((100, 2))
starts = jnp.ones((100, 1), dtype='int32')

f = lax.dynamic_update_slice
f_vmapped = jit(vmap(f))
f_pymapped = jit(pymap(f))

# Ensure compiled
f_vmapped(operands, updates, starts)
f_pymapped(operands, updates, starts)

t_vmapped = timeit(
    lambda: f_vmapped(operands, updates, starts).block_until_ready(), number=100
) / 100
t_pymapped = timeit(
    lambda: f_pymapped(operands, updates, starts).block_until_ready(), number=100
) / 100

print(f"Time vmap(f): {t_vmapped:.2}s")
print(f"Time pymap(f): {t_pymapped:.2}s")
```
Running it on a TPU v4-8 VM I get:
```
Time vmap(f): 0.00088s
Time pymap(f): 0.00036s
```
So, to be clear, what I think could be happening is that the unrolled Python loop is faster than the `scatter` because it can be parallelized (the loop iterations have no dependence on each other), whereas the `scatter` is (for some reason) compiling to a while-loop which cannot be parallelized.
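To illustrate what I mean by the iterations being independent, the same batched update can also be written with plain advanced indexing (just a sketch for illustration, not code from my workload; `rows`, `cols`, and `same_result` are names I'm introducing here, and this ignores `dynamic_update_slice`'s out-of-bounds clamping):

```python
# Each row's update touches only that row, so there is no cross-iteration
# dependence; an explicit index-based formulation makes this visible.
rows = jnp.arange(operands.shape[0])[:, None]           # shape (100, 1)
cols = starts + jnp.arange(updates.shape[1])[None, :]   # shape (100, 2)
same_result = operands.at[rows, cols].set(updates)      # matches f_vmapped(...) here
```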
The lowered StableHLO of `f_vmapped` does contain a `scatter` and no loop, as expected. Note that the `unique_indices` flag of the `scatter` is true.
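For reference, this is roughly how I'm inspecting the lowered StableHLO (a minimal sketch, continuing from the snippet above):

```python
# Lower (but don't optimize) the vmapped function and print its StableHLO;
# this is where the single `scatter` with unique_indices = true shows up.
lowered = jit(vmap(lax.dynamic_update_slice)).lower(operands, updates, starts)
print(lowered.as_text())
```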
However, after optimization/compilation, the HLO contains a while loop with a `dynamic-update-slice` in the body.
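And this is roughly how I'm getting the post-optimization HLO, again as a minimal sketch continuing from the snippet above:

```python
# Compile the lowered function and print the optimized HLO; this is where a
# while loop with a dynamic-update-slice in its body appears instead of a
# parallel scatter.
compiled = jit(vmap(lax.dynamic_update_slice)).lower(operands, updates, starts).compile()
print(compiled.as_text())
```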