Infinities and NaNs in quadratic_prog when c=0 #95
In case it helps, for my research code I've manually implemented the same call that QP would've done except using
Hi Ferran, that's weird: I tried to reproduce your bug and did not succeed: https://colab.research.google.com/drive/1-IS1MIkkXfVuON5IhAT2gxt2pw-5IVz8?usp=sharing. Can you give more details on the versions of Python/Jax/Jaxopt you are using? Are you working on CPU or GPU? In the default If you have a notebook in which you can consistently reproduce the bug, that would be great. Is it diverging in the primal/dual computation in run(), or during implicit differentiation?
My bad, it was because I had the Thanks!
Great that the problem disappeared. Not sure what fixed it, maybe ac1bdcd. See https://github.com/google/jaxopt/commits/main/jaxopt/_src/quadratic_prog.py for the history of this file.
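To answer the question of whether NaNs come from the forward solve in `run()` or from implicit differentiation, one option is to check the two passes separately. Below is a minimal sketch, assuming `solve_fn` is a hypothetical wrapper that maps the equality right-hand side `b` to the primal solution (for example, a small function around the QP solver call); here a closed-form minimum-norm solve stands in for it. Another option is `jax.config.update("jax_debug_nans", True)`, which makes JAX raise an error at the first operation that produces a NaN.

```python
import jax
import jax.numpy as jnp


def locate_nonfinite(solve_fn, b):
    """Check whether solve_fn(b) or its gradient w.r.t. b contains NaN/Inf.

    `solve_fn` is a hypothetical stand-in for the QP solver call: any
    differentiable function from the equality right-hand side b to the primal.
    """
    # Forward pass: the primal solution itself.
    x = solve_fn(b)
    forward_finite = bool(jnp.all(jnp.isfinite(x)))

    # Backward pass: differentiate a scalar loss of the solution w.r.t. b,
    # which exercises whatever differentiation rule solve_fn uses.
    grad_b = jax.grad(lambda bb: jnp.sum(solve_fn(bb) ** 2))(b)
    backward_finite = bool(jnp.all(jnp.isfinite(grad_b)))

    return {"forward_finite": forward_finite, "backward_finite": backward_finite}


# Toy stand-in: minimum-norm solution of A x = b (the c = 0, Q = I case).
A = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, -1.0]])


def min_norm_solve(b):
    return A.T @ jnp.linalg.solve(A @ A.T, b)


print(locate_nonfinite(min_norm_solve, jnp.array([1.0, 2.0])))
```

If the forward flag is already `False`, the problem is in the solve itself; if only the backward flag is `False`, it points at the differentiation step.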
@mblondel @Algue-Rythme During my research I just ran into more NaNs with the default QP solver, so the issue may not be fully fixed in 0.1.1. It happens for roughly 5% of the examples, and I haven't found any pattern in when it occurs. On the other 95% it matches (up to small precision errors) the result of using The failure case is quite entangled with my research code, so I can't send it at the moment. If I find a failure pattern, I'll design a minimal example and send it to you.
My guess is that the issue is in In #98, @Algue-Rythme is working on a new class
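To narrow down which of the roughly 5% of examples fail, and to turn one of them into a minimal reproducer, a batched check can help. This is a sketch under the assumption that the per-example solve can be written as a function `solve_fn` of a single right-hand side `b` (a hypothetical name) and vectorized with `jax.vmap`; it is not part of jaxopt.

```python
import jax
import jax.numpy as jnp


def find_failing_examples(solve_fn, batch_b):
    """Return the indices of examples whose solution contains NaN or Inf.

    `solve_fn` is a hypothetical per-example solver mapping one right-hand
    side b (shape [p]) to a primal solution (shape [n]).
    """
    xs = jax.vmap(solve_fn)(batch_b)             # shape [batch, n]
    finite = jnp.all(jnp.isfinite(xs), axis=-1)  # one flag per example
    return jnp.nonzero(~finite)[0]               # indices of failing examples


# Toy solver that yields NaN on negative entries, just to exercise the helper.
batch_b = jnp.array([[1.0, 4.0], [-1.0, 4.0], [9.0, 16.0]])
print(find_failing_examples(jnp.sqrt, batch_b))
```

The returned indices can then be used to pull out the exact `Q`, `A`, `b` of a failing example and inspect them in isolation.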
Hi,
I'm using `QuadraticProgramming` in the special case of `c=0` (the all-zeros vector). AFAIK this is still well-defined: it just minimizes the squared l2 norm of the primal variable subject to some equality constraints (I don't have inequalities).
However, both my research code and the following modification of this test diverge even for a single step (`maxiter=1`). The modification just involves setting `c=0`, so:
Is there a way to fix it? If it involves calling another linear solver, is there a way to specify the solver from the high-level QP function? I haven't seen one.
Thanks!
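As an independent sanity check on the claim that the `c=0` problem is well-defined, the equality-constrained QP min_x 0.5 x^T Q x + c^T x subject to A x = b can be solved directly from its KKT system [[Q, A^T], [A, 0]] [x; y] = [-c; b]. With `Q = I` and `c = 0` (assumed here to match the "squared l2 norm" description), the solution is the minimum-norm point x = A^T (A A^T)^{-1} b, which exists whenever `A` has full row rank. The sketch below uses a dense `jnp.linalg.solve` on the KKT matrix; it is a reference computation, not what jaxopt does internally, and the toy `A`, `b` are made up for illustration.

```python
import jax.numpy as jnp


def solve_eq_qp_kkt(Q, c, A, b):
    """Solve min_x 0.5 x^T Q x + c^T x  s.t.  A x = b via the dense KKT system.

    Assumes A has full row rank and Q is positive definite on the null space
    of A, so the KKT matrix below is nonsingular.
    """
    n, p = Q.shape[0], A.shape[0]
    # Stationarity: Q x + c + A^T y = 0; primal feasibility: A x = b.
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((p, p), dtype=Q.dtype)]])
    rhs = jnp.concatenate([-c, b])
    sol = jnp.linalg.solve(kkt, rhs)
    return sol[:n], sol[n:]  # primal x, dual y


# c = 0 and Q = I: minimize 0.5 * ||x||^2 subject to A x = b.
A = jnp.array([[1.0, 2.0, 3.0], [0.0, 1.0, -1.0]])
b = jnp.array([1.0, 2.0])
x, y = solve_eq_qp_kkt(jnp.eye(3), jnp.zeros(3), A, b)

# Closed-form minimum-norm solution for comparison.
x_min_norm = A.T @ jnp.linalg.solve(A @ A.T, b)
print(x, x_min_norm)
```

If a solver returns NaNs on the same `c=0` inputs while this dense reference is finite, that points at the solver rather than at the problem formulation.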