
Does jaxopt.LBFGS require double precision? #608

Closed
joacorapela opened this issue Aug 27, 2024 · 4 comments

Comments

@joacorapela

The PyTorch implementation of LBFGS requires double precision (please refer to this issue).

jaxopt.ScipyMinimize uses double precision by default.

Does jaxopt.LBFGS require double precision?

@mblondel
Collaborator

No, it uses the same precision as the parameters.

@joacorapela
Author

joacorapela commented Aug 28, 2024

Thanks @mblondel .

The jaxopt.ScipyMinimize documentation says:

Note that some methods relying on FORTRAN code, such as the L-BFGS-B solver for scipy.optimize.minimize, require casting to float64.

Can I assume that this does not hold for jaxopt.LBFGS, and that there should not be very large differences in the result of a jaxopt.LBFGS optimization whether I use single- or double-precision parameters?

@mblondel
Collaborator

  • jaxopt.ScipyMinimize is a wrapper around SciPy, which uses a FORTRAN implementation for L-BFGS-B. The FORTRAN code uses double precision.
  • jaxopt.LBFGS is a pure JAX implementation, which supports any dtype.

@joacorapela
Author

Thanks @mblondel
