File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 41, in <module>
compute_bilevel()
File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 39, in compute_bilevel
return jax.jacobian(compute_pg)(jnp.ones((12,1)))
File "/home/ims/.local/lib/python3.8/site-packages/jax/_src/api.py", line 1362, in jacfun
jac = vmap(pullback)(_std_basis(y))
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
u = solve(matvec, v)
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
Ab = rmatvec(b) # A.T b
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 145, in <lambda>
return lambda y: transpose(y)[0]
AssertionError
What I observed was that if I set implicit_diff=False in the ProjectedGradient, then it works but is super slow. Kindly advise.
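For reference, here is a minimal sketch of the kind of setup that produces this error. The objective, the box projection, and the bounds below are assumptions made for illustration, not the original code; the essential point is that p is used only as the initialization of ProjectedGradient while jax.jacobian differentiates with respect to it.

```python
import jax
import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_box

def objective(x):
    # Hypothetical convex objective; note that p does not appear here,
    # it only enters through the initialization below.
    return jnp.sum((x - 1.0) ** 2)

def compute_pg(p):
    # With implicit_diff=True (the default), differentiating run() with
    # respect to the initialization p triggers the AssertionError above.
    pg = ProjectedGradient(fun=objective, projection=projection_box,
                           implicit_diff=True)
    return pg.run(p, hyperparams_proj=(-2.0, 2.0)).params

def compute_bilevel():
    return jax.jacobian(compute_pg)(jnp.ones((12, 1)))
```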
When you differentiate a Jaxopt solver, it will attempt, whenever possible, to use implicit differentiation. Implicit differentiation is only possible if the argument you are differentiating with respect to is part of the optimality conditions of your problem.
That's your issue: you are trying to differentiate with respect to p (i.e., the initialization of ProjectedGradient), but p does not appear in the optimality conditions; for a convex problem the initialization plays no role in the solution. Mathematically, the derivative should therefore be zero anyway. Numerically, it's trickier.
Indeed, with implicit differentiation it fails because Jaxopt cannot differentiate with respect to variables that are not part of the optimality conditions. Without implicit differentiation, unrolling can return a value. In an ideal world this value would be zero, but with numerical errors I cannot guarantee it (I just tried and it is around 1e-9). If you want to speed it up, you can try wrapping your whole compute_pg function in jax.jit and disabling implicit diff, as sketched below.
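A rough sketch of that workaround, reusing the hypothetical objective and projection from the earlier snippet (so the names are assumptions): disable implicit differentiation so the iterations are unrolled, and jit the whole function so the unrolled computation is compiled once.

```python
def compute_pg_unrolled(p):
    # Same hypothetical problem as above, but with implicit diff disabled so
    # that reverse-mode differentiation unrolls the solver iterations instead.
    pg = ProjectedGradient(fun=objective, projection=projection_box,
                           implicit_diff=False)
    return pg.run(p, hyperparams_proj=(-2.0, 2.0)).params

# Jitting the whole function avoids re-tracing the unrolled loop on every call.
jac = jax.jacobian(jax.jit(compute_pg_unrolled))(jnp.ones((12, 1)))
# The result should be numerically close to zero (around the 1e-9 level).
```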
I suggest you take a step back and think about the meaning of the derivative you want to compute. For example, on non-convex problems the initialization does matter, because different initializations yield different (local) optima. But in that case the function that maps p_0 to p_t (the optimum) is usually piecewise constant (each piece corresponding to a different basin of attraction), so its derivative is still zero almost everywhere.
The output of compute_pg is a (12, 1) 2D array. This is a minimal reproducible example; the error I get is the AssertionError shown in the traceback above.