
Different results between jaxopt.LBFGS and jaxopt.ScipyMinimize(method='l-bfgs-b') #275

Open
richinex opened this issue Jul 14, 2022 · 4 comments
Labels
question Further information is requested

Comments

@richinex

I am trying to run a complex nonlinear optimization on multi-dimensional data by applying vmap to solver.run. Since I could not use the l-bfgs-b method of the ScipyMinimize wrapper that way, I resorted to jaxopt.LBFGS. However, I realized that the result from the latter is not correct. I would like to know why, and what I could do about it. My minimal working example is shown below. Thanks.

# Load libraries
import jax
import jaxopt
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

# Make data
F_arr = jnp.array([
    4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01, 3.800000e+01,
    4.650000e+01, 6.350000e+01, 8.050000e+01, 1.230000e+02, 1.570000e+02,
    1.995000e+02, 2.505000e+02, 3.185000e+02, 3.940000e+02, 4.965000e+02,
    6.260000e+02, 7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,
    1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03, 5.000000e+03,
    6.294500e+03, 7.924500e+03, 9.976500e+03, 1.255950e+04, 1.581150e+04,
    1.990550e+04, 2.505950e+04, 3.154800e+04, 3.971650e+04, 5.000000e+04,
])


Y_arr = jnp.array([
    0.00495074+0.00290374j, 0.00724701+0.00289439j, 0.00821288+0.00279885j,
    0.00877054+0.00276919j, 0.00921332+0.0027551j,  0.00953043+0.00274739j,
    0.01002155+0.00274946j, 0.01038829+0.00279736j, 0.01103745+0.00293741j,
    0.01143682+0.00304808j, 0.01185019+0.00321095j, 0.01222892+0.00340771j,
    0.01264666+0.00365856j, 0.01312294+0.00390083j, 0.01356835+0.00423682j,
    0.01414305+0.00459166j, 0.01475416+0.00502188j, 0.01544523+0.0054795j,
    0.01620464+0.00597393j, 0.01707565+0.00650766j, 0.01800564+0.00707323j,
    0.01907494+0.00766403j, 0.0202539+0.00824607j,  0.02156295+0.00882627j,
    0.02293967+0.0093636j,  0.02446602+0.00988404j, 0.02606663+0.01034258j,
    0.02778773+0.01073912j, 0.0295645+0.01105176j,  0.03142458+0.01130524j,
    0.03332406+0.01142638j, 0.03529196+0.01141756j, 0.03725344+0.01128458j,
    0.03917468+0.01100424j, 0.04104471+0.0105539j,
], dtype='complex64')


sigma_arr = jnp.array([
    2.43219802e-06, 3.84912892e-06, 4.65468565e-06, 5.23176095e-06,
    5.68508176e-06, 6.05872401e-06, 6.64642994e-06, 7.11385064e-06,
    7.95151846e-06, 8.43719499e-06, 8.92535354e-06, 9.37367440e-06,
    9.85436691e-06, 1.02790955e-05, 1.07571723e-05, 1.12416874e-05,
    1.17638756e-05, 1.22902720e-05, 1.28422944e-05, 1.34043157e-05,
    1.39690355e-05, 1.45196518e-05, 1.50516798e-05, 1.55538437e-05,
    1.60391119e-05, 1.65177389e-05, 1.69736650e-05, 1.74361339e-05,
    1.78881437e-05, 1.83461307e-05, 1.87868263e-05, 1.92436037e-05,
    1.96903675e-05, 2.01275998e-05, 2.05546330e-05,
])

params_init = jnp.array([1.84285135e+01, 1.71039097e-05, 6.98550706e-01,
             7.33632243e-01, 4.77681912e+02, 5.65632259e-04,
             1.34721147e+01, 1.34025052e+02, 3.93700063e+00,
             2.96283162e-01, 2.31503009e-01])

parameter_bounds = [[1e-1, 1e6], [1e-7, 1e-1], [1e-1, 1], [1e-1, 1e7],
                    [1e-1, 1e7], [1e-7, 1e-1], [1e-1, 1e6], [1e-1, 1e7],
                    [1e-1, 1e7], [1e-1, 1], [1e-1, 1e7]]
lb = jnp.array([i[0] for i in parameter_bounds])
ub = jnp.array([i[1] for i in parameter_bounds])

n_par = len(params_init)
n_data = 5
n_freq = len(F_arr)

# form matrices from the initial parameters and the bounds (one column per dataset)
par_mat = jnp.broadcast_to(params_init[:, None], (n_par, n_data))
lb_mat = jnp.broadcast_to(lb[:, None], (n_par, n_data))
ub_mat = jnp.broadcast_to(ub[:, None], (n_par, n_data))

# convert external (bounded) parameters to internal (unbounded) ones
par_log = jnp.log10((par_mat - lb_mat) / (1 - par_mat / ub_mat))
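# Sanity check: this is the inverse of the mapping used in cost_fun below,
# P_norm = (LB + 10**P) / (1 + 10**P / UB), so the round trip recovers par_mat:
assert jnp.allclose((lb_mat + 10**par_log) / (1 + 10**par_log / ub_mat), par_mat)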

# create matrices from F_arr, Y_arr and sigma_arr (one column per dataset)
F = jnp.broadcast_to(F_arr[:, None], (n_freq, n_data))
Y = jnp.broadcast_to(Y_arr[:, None], (n_freq, n_data))
sigma_Y = jnp.broadcast_to(sigma_arr[:, None], (n_freq, n_data))

# Define model: complex admittance, returned as stacked real and imaginary parts
@jax.jit
def fun(p, f):
    w = 2 * jnp.pi * f  # angular frequency
    Rs = p[0]
    Qh = p[1] * p[10]
    nh = p[2]
    Rad = p[3] / p[10]
    Wad = p[4] / p[10]
    Cad = p[5] * p[10]
    Rint = p[6] / p[10]
    Wint = p[7] / p[10]
    tau = p[8]
    alpha = p[9]
    Rp = p[10]
    Ct = (1 / Cad) ** -1  # equivalent to Cad
    Zad = Rad + Wad / jnp.sqrt(1j * w)
    Zint = Rint + Wint / ((1j * w * tau) ** (alpha / 2)) / jnp.tanh((1j * w * tau) ** (alpha / 2))
    Yf = (Zad + (1j * w * Ct) ** -1) / (Zad * Zint + (Zad + Zint) * (1j * w * Ct) ** -1)
    Ydl = Qh * ((1j * w) ** nh)
    Kl = jnp.sqrt(Ydl + Yf)
    Z = Rs + Rp / (Kl * jnp.tanh(Kl))
    Y = 1 / Z
    return jnp.concatenate([Y.real, Y.imag], axis=0)
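
# Quick shape check: fun stacks real and imaginary parts along axis 0, so a
# frequency array of shape (n,) yields an output of shape (2*n,)
assert fun(params_init, F_arr).shape == (2 * n_freq,)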

# weighted sum-of-squares residual for a single dataset
@jax.jit
def obj_fun(p, x, y, yerr):
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    sigma = jnp.concatenate([yerr, yerr], axis=0)
    y_model = fun(p, x)
    chi_sqr = jnp.linalg.norm((y_concat - y_model) / sigma) ** 2
    return chi_sqr

# Multidimensional cost function: vmap applies obj_fun over axis 1, i.e. over
# the n_data dataset columns at once
@jax.jit
def cost_fun(P, X, Y, YERR, LB, UB):
    dof = 2 * len(X[0]) * len(X) - len(P)  # degrees of freedom over all datasets
    # convert internal (unbounded) parameters back to external (bounded) ones
    P_norm = (LB + 10**P) / (1 + 10**P / UB)
    chi = jax.vmap(obj_fun, in_axes=1)(P_norm, X, Y, YERR)
    return jnp.sum(chi) / dof

# Run the optimization with SciPy's L-BFGS-B via the jaxopt wrapper
solver_1 = jaxopt.ScipyMinimize(method="l-bfgs-b", fun=cost_fun, tol=1e-12, options={'maxiter': 5000})
solver_1_sol = solver_1.run(par_log, F, Y, sigma_Y, lb_mat, ub_mat)
solver_1_sol.params[:, 0]

# Correct result
# DeviceArray([ 1.16801937e+00, -4.40860838e+00,  1.70230450e-01,
#              -1.91683037e+00,  2.81482271e+00, -3.33403710e+00,
#               1.58763377e+00,  1.96265752e+00, -9.80790594e+02,
#              -4.01386155e-01,  1.04828148e+00], dtype=float64)


solver_2 = jaxopt.LBFGS(fun=cost_fun, maxiter=5000)
solver_2_sol = solver_2.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)
solver_2_sol.params[:, 0]

# Incorrect result
# DeviceArray([ 5.13495057e+04, -1.12978597e+04, -1.15845934e+04,
#               6.74886483e+01,  2.41560694e+03, -1.42666112e+04,
#               3.63855005e+04,  1.28521409e+05, -1.94682351e+04,
#              -7.21753278e+04,  1.21914493e+02], dtype=float64)

@mblondel
Collaborator

Could you also try jaxopt.LBFGS(..., linesearch="zoom")?
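
For reference, a minimal sketch of that call, reusing solver_2's settings from the MWE above:

solver_zoom = jaxopt.LBFGS(fun=cost_fun, maxiter=5000, linesearch="zoom")
solver_zoom_sol = solver_zoom.run(par_log, X=F, Y=Y, YERR=sigma_Y, LB=lb_mat, UB=ub_mat)
solver_zoom_sol.params[:, 0]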

jaxopt.LBFGS and SciPy's L-BFGS-B don't use the same line-search technique, so when the function being minimized is nonconvex they can sometimes converge to different results.

@richinex
Author

I also tried the zoom line search and did not get the correct results. You're right, the problem is nonconvex.

@richinex
Author

Nevertheless, I found that I can use list(map(func, *args)) instead of vmap with jaxopt's ScipyMinimize, which solves my problem for now.
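
A minimal, untested sketch of that workaround (col_obj is a hypothetical per-column helper that applies the same bound transform as cost_fun; a list comprehension stands in for list(map(...))):

# per-column objective: transform internal parameters, then score one dataset
def col_obj(p, x, y, yerr, lb_col, ub_col):
    p_norm = (lb_col + 10**p) / (1 + 10**p / ub_col)
    return obj_fun(p_norm, x, y, yerr)

solver = jaxopt.ScipyMinimize(method="l-bfgs-b", fun=col_obj, tol=1e-12)

# solve each dataset (column) independently instead of vmapping a single solve
cols = [solver.run(par_log[:, i], F[:, i], Y[:, i], sigma_Y[:, i],
                   lb_mat[:, i], ub_mat[:, i]).params for i in range(n_data)]
params_all = jnp.stack(cols, axis=1)  # back to shape (n_par, n_data)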

@mblondel mblondel added the question Further information is requested label Aug 25, 2022
@zaccharieramzi
Contributor

@richinex this could be related; with the fixes in #323 and #350, the results might now be consistent with those of core JAX. Maybe you can give it a shot.
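
For a cross-check against core JAX, something along these lines might work (an untested sketch; jax.scipy.optimize.minimize supports method="BFGS" and expects a 1-D x0, so the parameter matrix is flattened):

from jax.scipy.optimize import minimize

def flat_cost(p_flat):
    # reshape the flat vector back into the (n_par, n_data) matrix cost_fun expects
    return cost_fun(p_flat.reshape(n_par, n_data), F, Y, sigma_Y, lb_mat, ub_mat)

res = minimize(flat_cost, par_log.ravel(), method="BFGS", tol=1e-12)
params_core_jax = res.x.reshape(n_par, n_data)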
