Assertion error when trying to take jacobian of projected gradient solution using projection_polyhedron #338

Open
Basant1861 opened this issue Nov 2, 2022 · 1 comment


Basant1861 commented Nov 2, 2022

The output of compute_pg is a (12, 1) 2D array. This is a minimal reproducible example:

import jax
import jax.numpy as jnp
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_polyhedron

num_obs = 6

# Inequality constraint matrix (24 x 12 block matrix).
C_obs_1 = jnp.identity(num_obs)
C_obs_2 = -jnp.identity(num_obs)
C_obs = jnp.block([
    [C_obs_1, 0 * jnp.identity(num_obs)],
    [C_obs_2, 0 * jnp.identity(num_obs)],
    [0 * jnp.identity(num_obs), C_obs_1],
    [0 * jnp.identity(num_obs), C_obs_2],
])

# Trivial equality constraint A x = a.
A_obs = jnp.zeros((1, jnp.shape(C_obs)[1]))
a_obstacle = jnp.zeros((1, 1))

def compute_obstacle_penalty_temp(p):
    cost_obs_penalty = 1.0 * jnp.linalg.norm(p) ** 2
    return cost_obs_penalty

def proj(p, C):
    return projection_polyhedron(p, C, check_feasible=False)

def compute_pg(p):
    p = jnp.reshape(p, (jnp.shape(p)[0], 1))
    b_obs = jnp.ones((jnp.shape(C_obs)[0], 1))

    pg = ProjectedGradient(fun=compute_obstacle_penalty_temp, projection=proj, jit=True)
    pg_sol = pg.run(p, hyperparams_proj=(A_obs, a_obstacle, C_obs, b_obs)).params
    return pg_sol

def compute_bilevel():
    return jax.jacobian(compute_pg)(jnp.ones((12, 1)))

compute_bilevel()

This is the error I get:

File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 41, in <module>
  compute_bilevel()
File "/home/ims/ros2_ws/src/mpc_python/mpc_python/plot_test.py", line 39, in compute_bilevel
  return jax.jacobian(compute_pg)(jnp.ones((12,1)))
File "/home/ims/.local/lib/python3.8/site-packages/jax/_src/api.py", line 1362, in jacfun
  jac = vmap(pullback)(_std_basis(y))
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 236, in solver_fun_bwd
  vjps = root_vjp(optimality_fun=optimality_fun, sol=sol,
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py", line 69, in root_vjp
  u = solve(matvec, v)
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 193, in solve_normal_cg
  Ab = rmatvec(b)  # A.T b
File "/home/ims/.local/lib/python3.8/site-packages/jaxopt/_src/linear_solve.py", line 145, in <lambda>
  return lambda y: transpose(y)[0]
AssertionError

What I observed is that if I set implicit_diff=False in the ProjectedGradient then it works, but it is super slow. Kindly advise.
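For completeness, this is the workaround I mean, only changing the solver construction inside compute_pg (everything else stays as in the snippet above):

# Same solver as above, but with implicit differentiation disabled,
# so jax.jacobian differentiates through the solver iterations instead.
pg = ProjectedGradient(fun=compute_obstacle_penalty_temp,
                       projection=proj,
                       jit=True,
                       implicit_diff=False)  # works, but is very slow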

@Algue-Rythme
Collaborator

Hi Basant1861

When you differentiate a JAXopt solver, it will attempt, whenever possible, to use implicit differentiation. Implicit differentiation is only possible if the argument you are differentiating with respect to appears in the optimality conditions of your problem.

That's your issue: you are trying to differentiate with respect to p (i.e., the initialization of ProjectedGradient), but p does not appear in the optimality conditions; for a convex problem, the initialization plays no role in the solution that is reached. Hence, mathematically, the derivative should be zero anyway. Numerically, it's trickier.
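To make that concrete, here is a rough sketch of the fixed-point condition that implicit differentiation works with (not the exact residual JAXopt builds internally):

x_star = proj(x_star - step_size * grad_f(x_star), hyperparams_proj)

The initialization p appears nowhere in this equation, so there is nothing for implicit differentiation to propagate through.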

Indeed, with implicit differentiation it does not work, because JAXopt cannot handle differentiating with respect to variables that are not part of the optimality conditions. Without implicit differentiation, unrolling does return a value. In an ideal world this value would be exactly zero, but with numerical errors I cannot guarantee it (I just tried and it is around 1e-9). If you want to speed it up, you can try wrapping your whole compute_pg function in jax.jit and disabling implicit diff.
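Something along these lines (a sketch only; it reuses the names and definitions of compute_obstacle_penalty_temp, proj, C_obs, A_obs and a_obstacle from your snippet):

# Build the solver once, with implicit diff disabled so jax.jacobian
# differentiates through the unrolled iterations instead.
pg = ProjectedGradient(fun=compute_obstacle_penalty_temp,
                       projection=proj,
                       jit=True,
                       implicit_diff=False)

@jax.jit
def compute_pg(p):
    p = jnp.reshape(p, (jnp.shape(p)[0], 1))
    b_obs = jnp.ones((jnp.shape(C_obs)[0], 1))
    return pg.run(p, hyperparams_proj=(A_obs, a_obstacle, C_obs, b_obs)).params

# Jit the Jacobian computation as well so the unrolled graph is compiled once.
jac_fn = jax.jit(jax.jacobian(compute_pg))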

I suggest you take a step back and think about the meaning of the derivative you want to compute. For example, on non-convex problems the initialization does matter, because different initializations will yield different (local) optima. But even then, the function that maps p_0 to p_t (the optimum reached) is usually piecewise constant (each piece corresponding to a different basin of attraction), so its derivative is zero almost everywhere anyway.
