OptaxSolver how to proceed? #65

Closed
jecampagne opened this issue Oct 11, 2021 · 6 comments

Labels
question Further information is requested

Comments

@jecampagne

jecampagne commented Oct 11, 2021

Hi,
Here is a use-case

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the float64 results below assume x64 is enabled

# Volume of a box
def vol(x):
    return x[0]*x[1]*x[2]

# Surface of the box
def surf(x):
    return 2.*(x[0]*x[1]+x[0]*x[2]+x[1]*x[2])

# Constraint on total surface
def g(x): return surf(x) - 24

# Lagrangian: p[0:3] = (x1, x2, x3), p[3] = Lagrange multiplier
@jax.jit
def Lag(p):
    return vol(p[0:3]) - p[3]*g(p[0:3])

@jax.jit
def neg_Lag(p):
    return -Lag(p)

I can solve this Lagrangian-based optimisation problem by hand like this:

#Gradient Lagrangien
gLag = jax.jacfwd(Lag)
hLag = jax.hessian(Lag)

def solveLagrangian(p,lr=0.1): 
    return p - lr*jnp.linalg.inv(hLag(p)) @ gLag(p)
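
# (sketch, not in the original) an equivalent variant solves the linear
# system instead of explicitly forming the inverse Hessian, which is
# cheaper and numerically more stable:
def solveLagrangian_solve(p, lr=0.1):
    return p - lr*jnp.linalg.solve(hLag(p), gLag(p))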

p_cur = jnp.array([1.5,0.5,1.0,0.1])

for t in range(200):

    if (t % 10) == 0:
        print(t, p_cur, Lag(p_cur))

    new_p = solveLagrangian(p_cur)
    
    rel_err = jnp.max(jnp.abs(p_cur - new_p))
    if rel_err < 1e-6:
        print(f"Converged after {t} epochs")
        break
    
    p_cur = new_p

p_fin=p_cur
v_fin = vol(p_fin[0:3])
s_fin = surf(p_fin[0:3])

print("p_fin: ",p_fin,": True x=y=z=2, lambda=0.5" )
print("v_fin: ",v_fin,": True vol  = 2^3")
print("s_fin: ",s_fin,": True surf = 24")

I get

0 [1.5 0.5 1.  0.1] 2.6
10 [1.83358314 1.55167781 1.77716864 0.40679326] 7.609872257244211
20 [1.94433372 1.84887174 1.92882398 0.46981662] 7.95680842439189
30 [1.98087313 1.94785567 1.97583383 0.48971189] 7.9948966103715815
40 [1.99336478 1.98188252 1.99164835 0.49643995] 7.999385438210385
50 [1.99769054 1.99369051 1.99709684 0.49876193] 7.999925528308225
60 [1.99919524 1.99780094 1.9989888  0.4995687 ] 7.999990956279981
70 [1.99971946 1.99923335 1.99964755 0.49984966] 7.9999989009303
80 [1.99990219 1.9997327  1.99987712 0.49994759] 7.99999986639723
90 [1.9999659  1.9999068  1.99995716 0.49998173] 7.999999983757804
100 [1.99998811 1.9999675  1.99998506 0.49999363] 7.999999998025361
110 [1.99999585 1.99998867 1.99999479 0.49999778] 7.999999999759932
Converged after 112 epochs
p_fin:  [1.99999664 1.99999082 1.99999578 0.4999982 ] : True x=y=z=2, lambda=0.5
v_fin:  7.9999329788029625 : True vol  = 2^3
s_fin:  23.999865957438498 : True surf = 24

Okay, now: is it possible to get the result with OptaxSolver?

import optax
import jaxopt

opt = optax.adagrad(0.01)
solver = jaxopt.OptaxSolver(opt=opt, fun=neg_Lag, maxiter=2000)
init_params = jnp.array([1.5,0.5,1.0,0.1])
params, state = solver.init(init_params)
print('init', params, neg_Lag(params))
for i in range(2000):
    params, state = solver.update(params=params, state=state)
    if i % 100 == 0:
        print(i, params, neg_Lag(params))

Here I get:

init [1.5 0.5 1.  0.1] -2.6
0 [1.50534522 0.50953463 1.00741998 0.10999854] -2.797381684399013
100 [1.42831204 0.63479103 1.0126782  0.28539948] -6.057683328070636
200 [1.28418844 0.59294737 0.86638545 0.37066902] -7.7856194581887825
300 [1.18203126 0.5047476  0.75689433 0.43842879] -9.33122178685714
400 [1.1023296  0.4200589  0.67028698 0.49659404] -10.755253639874878
500 [1.03695639 0.34601625 0.598102   0.54828646] -12.072981877227912
600 [0.98179502 0.28104647 0.53587287 0.59520355] -13.298704241023824
700 [0.93451151 0.22319941 0.48100081 0.63840745] -14.444779039627818
800 [0.89367761 0.17099112 0.43183736 0.67861745] -15.521400626325187
900 [0.85837707 0.12333851 0.38726128 0.71634778] -16.536982715800075
1000 [0.82800999 0.07943679 0.34646534 0.75198133] -17.49857403108653
1100 [0.80218737 0.03866963 0.30883957 0.78581251] -18.412193035926062
1200 [7.80670563e-01 5.49785239e-04 2.73902885e-01 8.18073632e-01] -19.28308214582214
1300 [ 0.76333447 -0.0353201   0.24126048  0.84895208] -20.115900881866448
1400 [ 0.75014358 -0.06927386  0.21057603  0.87860182] -20.914875993347422
1500 [ 0.7411345  -0.10159995  0.18155273  0.90715143] -21.683921834137884
1600 [ 0.7364005  -0.13255408  0.15392011  0.93470985] -22.426739725247586
1700 [ 0.73607461 -0.16236803  0.12742484  0.96137061] -23.146900944390644
1800 [ 0.74030784 -0.19125565  0.10182456  0.987215  ] -23.847914006566935
1900 [ 0.74924029 -0.21941633  0.07688466  1.01231443] -24.533272887793775

which clearly is not the right way to go.
(nb. using Lag as the objective instead does not change the problem: no convergence; same with sgd/adam...)

Is there a way to get the Optax solver working?
Thanks

@Algue-Rythme
Collaborator

The optimizers of Optax are meant for minimization, not root finding. If we take a look at your function neg_Lag, we can see that its minimum does not exist: the Lagrange multiplier p[3] is allowed to take any value in $\mathbb{R}$, so as long as your surface constraint g is not zero, the multiplier term can be used to reach any real value. Optax is behaving correctly by diverging, since there is no minimum to converge to.
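
A quick numerical check (a sketch reusing your own functions): fix x off the constraint surface and grow the multiplier; neg_Lag decreases without bound.

x = jnp.array([1.5, 0.5, 1.0])   # here g(x) = -18.5 != 0
for lam in [1.0, 10.0, 100.0]:
    # neg_Lag = -vol(x) + lam*g(x) -> -inf as lam -> +inf
    print(lam, neg_Lag(jnp.append(x, lam)))   # -19.25, -185.75, -1850.75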

So you must:

  • either use an optimizer to perform constrained optimization of your function vol; here ProjectedGradient might be disappointing since your feasible set is not convex;
  • or use a root-finding algorithm to find the zero of jax.grad(Lag). Unfortunately, we currently lack options for multidimensional root finding (the best we have is ScipyRootFinding; a minimal sketch follows this list). Other options based on fixed-point finding will be available soon;
  • or reformulate the multidimensional root-finding problem into an optimization problem: instead of finding $p$ such that $\nabla \mathcal{L}(p) = 0$, we seek to minimize $\|\nabla \mathcal{L}(p)\|^2$.
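
For the second option, a minimal sketch (assuming import jaxopt and gLag = jax.jacfwd(Lag) from your snippet):

rf = jaxopt.ScipyRootFinding(optimality_fun=gLag, method='hybr')
params, info = rf.run(jnp.array([1.5, 0.5, 1.0, 0.1]))
print(params)  # should approach [2., 2., 2., 0.5]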

The third option works:

opt = optax.adagrad(0.1)

@jax.jit
def objective_fun(p):
  delta = gLag(p)  # gLag = jax.jacfwd(Lag), as in your snippet
  return jnp.sum(delta**2)  # minimize the squared gradient norm

solver = jaxopt.OptaxSolver(opt=opt, fun=objective_fun, maxiter=2000)
init_params = jnp.array([1.5,0.5,1.0,0.1])
params, state = solver.init(init_params)
print('init', params, neg_Lag(params))

@jax.jit
def jitted_update(params, state):
  return solver.update(params=params, state=state)

for i in range(20*1000):
    params, state = jitted_update(params, state)
    if i % 100 == 0:
        print(i, params, Lag(params), objective_fun(params))
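
nb. a shorter variant (a sketch, assuming the standard jaxopt solver interface) lets the solver drive the iterations itself:

params, state = solver.run(init_params)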

However, this algorithm is far from efficient: squaring the gradient roughly squares the eigenvalues of the problem's Hessian, so its conditioning worsens and first-order updates converge slowly.

@mblondel added the pull ready and question (Further information is requested) labels and removed the pull ready label on Oct 11, 2021
@jecampagne
Author

@mblondel you are absolutely right, my problem is a root-finding one,

$$\nabla \mathcal{L} = 0$$

and not a minimization. Sorry, I had forgotten this point: my (old) solveLagrangian, which I JAX-ified, does exactly that root search via a Newton step.

Thanks also for the discussion of the different methods and for the snippet. I am not sure that I can contribute, but your library is really nice and I'm looking forward to the new implementations.

@jecampagne
Author

@mblondel

rf = jaxopt.ScipyRootFinding(optimality_fun=gLag, method='hybr')
rf.run(jnp.array([1.5,0.5,1.0,0.1]))

gives

OptStep(params=DeviceArray([2. , 2. , 2. , 0.5], dtype=float64),
        state=ScipyRootInfo(fun_val=DeviceArray([-2.05051975e-10,  3.02247116e-11,
                                                 -3.61966457e-10,  6.25949070e-10], dtype=float64),
                            success=True, status=1))

So I wonder why you rejected this method in your comment? Maybe I have misunderstood something.

@Algue-Rythme
Collaborator

So I wonder why you rejected this method in your comment? Maybe I have misunderstood something.

I don't reject it; I was just mentioning that we have nothing else for this purpose.

For some reason you wanted to use Optax, so I showed you an example with Optax.

But ScipyRootFinding is fine too.

@mblondel
Collaborator

Note: I am not @Algue-Rythme :)

@jecampagne
Author

Oh, sorry @mblondel, I was also in contact with @Algue-Rythme in another thread :)
Now, your code with Optax was really nice.
Thanks a lot.
