Opening this as a nota bene. When optimizing over complex parameters, the gradient must be conjugated before the update is applied. As a result, all current jaxopt optimizers would be incorrect on complex parameters.
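A minimal sketch of the issue in plain JAX (the loss and step size here are just an illustration, not anything from jaxopt): for a real-valued loss of a complex parameter, the descent direction is the conjugate of what `jax.grad` returns, so an update rule that applies the raw gradient takes a wrong step.

```python
import jax
import jax.numpy as jnp

def loss(z):
    # Real-valued loss of a complex parameter: |z - (1 + 2j)|^2,
    # minimized at z = 1 + 2j.
    return jnp.abs(z - (1.0 + 2.0j)) ** 2

z = jnp.asarray(0.0 + 0.0j)
lr = 0.1
for _ in range(100):
    g = jax.grad(loss)(z)
    z = z - lr * jnp.conj(g)  # conjugate: the correct descent direction
print(z)  # converges to ~(1 + 2j)
```

Dropping the `jnp.conj` flips the sign of the update on the imaginary part, so the imaginary component of the parameter diverges instead of converging.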
Moreover, any optimizer that relies on second-order moments (e.g. Adam) must accumulate the complex modulus squared, |g|², rather than the plain square g². Current jaxopt solvers might be affected as well. I'm unsure what implicit diff does with complex parameters, but perhaps we could emit a warning that it is currently probably incorrect.
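A rough sketch of what a complex-aware Adam step could look like (a hypothetical helper written for this issue, not jaxopt's actual internals): the second-moment accumulator tracks E[|g|²] = E[g·conj(g)], which stays real, whereas E[g²] goes complex and breaks the per-coordinate step size.

```python
import jax.numpy as jnp

def adam_update(params, grads, mu, nu, step,
                lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    grads = jnp.conj(grads)                        # fix 1: conjugate the gradient
    mu = b1 * mu + (1 - b1) * grads                # first moment (complex is fine)
    nu = b2 * nu + (1 - b2) * jnp.abs(grads) ** 2  # fix 2: |g|^2, stays real
    mu_hat = mu / (1 - b1 ** step)                 # standard bias correction
    nu_hat = nu / (1 - b2 ** step)
    new_params = params - lr * mu_hat / (jnp.sqrt(nu_hat) + eps)
    return new_params, mu, nu
```

With `g ** 2` instead of `jnp.abs(g) ** 2`, `nu` becomes complex and `jnp.sqrt(nu_hat)` no longer acts as a per-coordinate magnitude, so the preconditioning is meaningless.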
I realized this while using Optax on a model with complex weights; I thought it might be good to handle this in the jaxopt solvers as well, since users might 1) not be aware of the issue and 2) find it really hard to debug on their side.