JAXopt with nonlinear optimization and neural networks #177
-
I would like to train a neural network architecture with a nonlinear optimization algorithm like IPOPT or SNOPT as a layer in the architecture, similar in spirit to the MPC layer in https://locuslab.github.io/mpc.pytorch/. It seems like JAXopt’s implicit differentiation for custom fixed point solvers (https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_fixed_point.html#jaxopt.implicit_diff.custom_fixed_point) is relevant, but I'm not too sure how to make use of it exactly. Or perhaps https://jaxopt.github.io/stable/_autosummary/jaxopt.implicit_diff.custom_root.html#jaxopt-implicit-diff-custom-root? Can you give some guidance as to where to start? Thanks!

Example copied from https://github.com/mechmotum/cyipopt/blob/master/examples/hs071_scipy_jax.py:

```python
import jax
from jax import jit, grad, jacfwd, jacrev
import jax.numpy as jnp
from jaxopt import implicit_diff
from cyipopt import minimize_ipopt
import numpy as np


def objective(x):
    return x[0]*x[3]*jnp.sum(x[:3]) + x[2]


def eq_constraints(x):
    return jnp.sum(x**2) - 40


def ineq_constraints(x):
    return jnp.prod(x) - 25


# jit the functions
obj_jit = jit(objective)
con_eq_jit = jit(eq_constraints)
con_ineq_jit = jit(ineq_constraints)

# build the derivatives and jit them
obj_grad = jit(grad(obj_jit))  # objective gradient
obj_hess = jit(jacrev(jacfwd(obj_jit)))  # objective hessian
con_eq_jac = jit(jacfwd(con_eq_jit))  # jacobian
con_ineq_jac = jit(jacfwd(con_ineq_jit))  # jacobian
con_eq_hess = jacrev(jacfwd(con_eq_jit))  # hessian
con_eq_hessvp = jit(lambda x, v: con_eq_hess(x) * v[0])  # hessian vector-product
con_ineq_hess = jacrev(jacfwd(con_ineq_jit))  # hessian
con_ineq_hessvp = jit(lambda x, v: con_ineq_hess(x) * v[0])  # hessian vector-product

# constraints
cons = [
    {'type': 'eq', 'fun': con_eq_jit, 'jac': con_eq_jac, 'hess': con_eq_hessvp},
    {'type': 'ineq', 'fun': con_ineq_jit, 'jac': con_ineq_jac, 'hess': con_ineq_hessvp},
]

# variable bounds: 1 <= x[i] <= 5
bnds = [(1, 5) for _ in range(4)]


# my guess as to what the objective should be
def complete_objective(x):
    return objective(x) + eq_constraints(x) + ineq_constraints(x)


@implicit_diff.custom_root(jax.grad(complete_objective))
def ipopt_solver(x0):
    # solve
    display_opts = 0  # 5 for full info
    tol = 1e-7
    sol = minimize_ipopt(obj_jit, jac=obj_grad, hess=obj_hess, x0=x0, bounds=bnds,
                         constraints=cons, options={"print_level": display_opts, "tol": tol})
    sol = jnp.array(sol.x)
    return sol

# x0 (input to ipopt_solver(x0)) is the initial value given to the solver, it is the output from the neural net
```
-
Hi,

It's almost that! But you cannot simply add the constraints and the objective the way you did:

`objective(x) + eq_constraints(x) + ineq_constraints(x)`

does not make sense. To enable implicit differentiation of constrained optimization problems, you need both the primal and the dual variables (see page 5 of https://arxiv.org/pdf/2105.15183.pdf). There is a function specially made to handle the KKT conditions: `jaxopt._src.implicit_diff.make_kkt_optimality_fun`. It can be found in jaxopt/jaxopt/_src/implicit_diff.py, line 330 (commit b77c9f5).

You will find an example of usage here. As you can see, the optimality fun is built from … In the notebook I sent you above, the function …

The problem you gave as an example does not exhibit such parameters, so there is currently nothing to differentiate (implicit diff is useless here). Maybe consider replacing 40 by …