I'm using Jax/Jaxopt for a different problem than is likely intended: optics simulations based on transfer-matrix methods. Effectively, each element of a system is modeled as a matrix, and simulating a full system amounts to the cascaded multiplication of several matrices, making jax and jaxopt very intriguing libraries for speeding this up.
The problem I'm working with boils down to minimizing a complicated function of one parameter that essentially produces a Lorentzian resonance at a parameter value of ~10E-6 with a width of ~1E-9, so for my MVE I'm just using a Lorentzian function.
Now the problem: ideally I want to minimize using native jaxopt so that I can easily vmap and grad the result. However, all the native routines seem to completely fail at optimizing this function, while the scipy wrapper succeeds with ease. Here's some example code for minimizing a narrow Lorentzian dip with a few methods, given the same starting position and tolerance.
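A sketch of the kind of MVE I mean (the exact numbers — start point 9e-6, fixed step size, tolerance — are assumed values matching the scales above; the scipy call mirrors what jaxopt's ScipyMinimize wrapper does, i.e. scipy's optimizer driving a JAX value and gradient, and a plain fixed-step gradient loop stands in for the native first-order solvers):

```python
# Sketch of the MVE: a Lorentzian dip at ~10e-6 with width ~1e-9.
# Start point, step size, and tolerance are assumed illustration values.
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # 1e-9-scale features need float64

X0 = 10e-6  # resonance location
W = 1e-9    # resonance width

def lorentzian_dip(x):
    # 1 far from resonance, 0 at x == X0, half-width W
    return 1.0 - W**2 / ((x - X0) ** 2 + W**2)

grad_fn = jax.grad(lorentzian_dip)

# scipy-wrapper path: scipy's optimizer, with value and gradient from JAX
res = minimize(
    lambda x: float(lorentzian_dip(x[0])),
    x0=[9e-6],
    jac=lambda x: [float(grad_fn(x[0]))],
    method="L-BFGS-B",
    tol=1e-12,
)

# Fixed-step gradient descent standing in for a native first-order solver.
# The tail gradient is O(1) here (|f'| ~ 2 at a distance of 1e-6) while the
# dip itself is only ~1e-9 wide, so a step large enough to cross the gap in
# a reasonable number of iterations also jumps clean over the dip.
x = 9e-6
for _ in range(200):
    x = x - 1e-7 * float(grad_fn(x))

print("scipy:", res.x[0], "fixed-step GD:", x)
```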
Regardless of method, scipy nails the dip, while the other methods blow right past it.
To note, I'm using projected gradient because ideally I want to be able to include bounds, since the real problem becomes periodic in the parameter. However, the gradient should always be well behaved enough to point the solver to the nearest dip, so AFAIK I shouldn't actually need bounds, which is why bfgs/lbfgs is also viable.
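For reference, the projection I have in mind is just a box: clip the parameter to one period. A minimal sketch (the endpoints below are made-up illustration values; the real ones would come from the transfer-matrix model — and as far as I can tell this clipping is essentially what jaxopt's projection_box provides):

```python
# Box projection for keeping the parameter inside one period.
# The bounds below are made-up illustration values.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

lower, upper = 9.5e-6, 10.5e-6  # assumed: one period around the dip

def projection_box(x):
    # clip back into the box; essentially jaxopt.projection.projection_box
    return jnp.clip(x, lower, upper)

def projected_gradient_step(x, stepsize, grad_fn):
    # one projected-gradient step: descend, then project into the feasible box
    return projection_box(x - stepsize * grad_fn(x))
```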
This might just be a question of needing to fine-tune some parameters, but I haven't been able to get anything to work better than these examples demonstrate.
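For what it's worth, the difference may be less about tolerances than about step-size control. A toy Armijo backtracking loop on the same dip (hand-rolled sketch, not a jaxopt API, with assumed constants) can overshoot to the far side of the dip but can never end up farther from the resonance than it started, because every accepted step must strictly decrease the objective — unlike a fixed step, which jumps over it for good:

```python
# Toy Armijo backtracking on the same dip (hand-rolled sketch, not a
# jaxopt API; start point and constants are assumed illustration values).
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

X0 = 10e-6  # resonance location
W = 1e-9    # resonance width

def f(x):
    return 1.0 - W**2 / ((x - X0) ** 2 + W**2)

df = jax.grad(f)

def backtracking_gd(x, iters=50, c=1e-4):
    for _ in range(iters):
        g = float(df(x))
        t = 1.0
        # halve the step until the Armijo sufficient-decrease test passes
        while f(x - t * g) > f(x) - c * t * g * g and t > 1e-30:
            t *= 0.5
        x = x - t * g
    return x

x_final = backtracking_gd(9e-6)
print(x_final)
```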