
Native optimization methods failing on sharp function which scipy wrapper succeeds on #430

Open
ry-dgel opened this issue May 18, 2023 · 0 comments


Hello!

I'm using Jax/Jaxopt for a different problem than it was likely intended for: optics simulations based on transfer matrix methods. Each element of a system is modeled as a matrix, and simulating a full system amounts to the cascaded multiplication of several matrices, which makes jax and jaxopt very attractive libraries for speeding this up.
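For context, here's a toy sketch of what that cascaded multiplication looks like. The element matrices below are made up for illustration and are not my real model:

```python
import jax.numpy as jnp
from functools import reduce

def propagation(phi):
    # Free-space propagation accumulating a phase phi (illustrative form)
    return jnp.array([[jnp.exp(1j * phi), 0.0],
                      [0.0, jnp.exp(-1j * phi)]])

def mirror(r):
    # Partially reflective interface with amplitude reflectivity r (illustrative)
    t = jnp.sqrt(1.0 - r**2)
    return jnp.array([[1.0, r],
                      [r, 1.0]]) / t

# A cavity-like system: mirror / propagation / mirror, cascaded by matmul
elements = [mirror(0.99), propagation(0.3), mirror(0.99)]
M = reduce(jnp.matmul, elements)
```

The real system just has many more elements, each depending on the parameter being optimized.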

The problem I'm working with boils down to minimizing a complicated function of one parameter, which essentially produces a Lorentzian resonance at a parameter value of ~10E-6 with a width of ~1E-9, so for my MVE I'm just using a Lorentzian function.

Now the problem. Ideally I want to minimize using native jaxopt so that I can easily vmap and grad the result; however, all the native routines seem to completely fail at optimizing this function, while the scipy wrapper succeeds with ease. Here's some example code for minimizing a narrow Lorentzian dip with a few methods, given the same starting position and tolerance.

from jax import jit,grad,vmap
import jax.numpy as jnp
import jaxopt as jopt
from jaxopt.projection import projection_box
import matplotlib.pyplot as plt

@jit
def lorentz(x,amp,mean,gamma,offset):
    delta = x-mean
    return amp * gamma**2/(gamma**2+4*delta**2) + offset

@jit
def lor(x):
    return lorentz(x,-0.8,10.0E-6,10E-10,1.0)

x_init = 10.1E-6
x_min = 9.5E-6
x_max = 10.5E-6
tol = 1E-8

xs = jnp.linspace(x_min,x_max,10000)
plt.plot(xs,lor(xs))
plt.scatter(x_init,lor(x_init),label='initial')

# Projected Gradient
solver = jopt.ProjectedGradient(lor,projection_box,tol=tol)
params = solver.run(x_init,
                    hyperparams_proj=(x_min,x_max))
plt.scatter(params.params,lor(params.params),
            label='projected',zorder=10)

# BFGS
# Stepsize avoids pinging off to large values
solver = jopt.BFGS(lor,tol=tol,stepsize=1E-9)
params = solver.run(x_init)
plt.scatter(params.params,lor(params.params),label='bfgs',zorder=10)

# Scipy Bounded LBFGSB
solver = jopt.ScipyBoundedMinimize(fun=lor,
                                   tol=tol,
                                   method="l-bfgs-b")
params = solver.run(x_init,bounds=(x_min,x_max))
plt.scatter(params.params,lor(params.params),label='scipy - lbfgsb')
plt.legend()

# Scipy BFGS
solver = jopt.ScipyMinimize(fun=lor,
                            tol=tol,
                            method="bfgs")
params = solver.run(x_init)
plt.scatter(params.params,
            lor(params.params),
            label='scipy - bfgs',
            marker='x')
plt.legend()

# Scipy Newton
solver = jopt.ScipyMinimize(fun=lor,
                            tol=tol,
                            method="newton-cg")
params = solver.run(x_init)
plt.scatter(params.params,
            lor(params.params),
            label='scipy - newton',
            marker='+')
plt.legend()

And the resulting plot:

[screenshot: the scipy methods all land in the dip; the native methods do not]

Regardless of method, scipy nails the dip, while the native methods blow right past it.

To note: I'm using projected gradient because ideally I want to include bounds, since the real problem becomes periodic in the parameter. However, the gradient should always be well-behaved enough to point the solver toward the nearest dip, so AFAIK bounds shouldn't strictly be necessary, which is why BFGS/L-BFGS is also viable.
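As a quick sanity check on that claim (my own diagnostic, with x64 enabled to rule out float32 rounding at these scales), the gradient at the starting point is finite, nonzero, and points back toward the dip:

```python
import jax
jax.config.update("jax_enable_x64", True)  # these scales are hostile to float32
from jax import grad

def lorentz(x, amp, mean, gamma, offset):
    delta = x - mean
    return amp * gamma**2 / (gamma**2 + 4 * delta**2) + offset

def lor(x):
    return lorentz(x, -0.8, 10.0e-6, 10e-10, 1.0)

# At x_init = 10.1e-6 the gradient is roughly +400: positive, so a
# descent step moves toward the dip at 1e-5 rather than away from it.
g = grad(lor)(10.1e-6)
print(g)
```

So the failure doesn't look like a vanishing or NaN gradient at the starting point.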

This might just be a question of fine-tuning some parameters, but I haven't been able to get anything to work better than these examples demonstrate.
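One possible workaround (a sketch, not something I've tuned carefully): seed the local solver with a coarse vmapped grid scan. This only helps if the grid spacing is finer than the linewidth, which I happen to know here:

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

def lorentz(x, amp, mean, gamma, offset):
    delta = x - mean
    return amp * gamma**2 / (gamma**2 + 4 * delta**2) + offset

def lor(x):
    return lorentz(x, -0.8, 10.0e-6, 10e-10, 1.0)

# Grid spacing is 1e-11, well below the ~1e-9 linewidth, so the argmin
# is guaranteed to land inside the dip; the result can then seed
# ProjectedGradient/BFGS as in the snippets above.
xs = jnp.linspace(9.5e-6, 10.5e-6, 100_001)
x0 = xs[jnp.argmin(jax.vmap(lor)(xs))]
```

But this defeats the point somewhat, since it trades the solver's line search for an O(range/linewidth) scan, and I'd still expect the native solvers to manage what scipy manages.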
