I've been using `jaxopt.implicit_diff.custom_root` to differentiate through a jax-md energy minimization routine, and I've noticed that if my solver uses a Python `for` loop, I get a `CustomVJPException` as well as a memory leak.
This memory leak only seems to show up when I get the `CustomVJPException`, not when I modify my code to avoid the exception. I believe the underlying cause of the exception is the same as in issue #31, and it seems to stem from how jax-md defines its energy functions.
I'd like to know how to change that part of jax-md so that the `CustomVJPException` doesn't happen in the first place, but I haven't managed to come up with a simplified version that would let me pinpoint the source of the error. I can give it another shot if that would help.
You definitely need the dependency on `params` inside the `optimality_fun` passed to `jaxopt.implicit_diff.custom_root`. Otherwise jaxopt is not going to calculate the gradients correctly.
I think something like this should work, but it results in a different strange error:
I'm a bit confused by the variable names above. In JAXopt, we usually use optimality_fun(params, hyperparams) and solver_fun(init_params, hyperparams), where params is what is optimized and hyperparams is what is differentiated. See e.g. this example.
Here's a Colab demo that demonstrates the issue:
https://colab.research.google.com/drive/1f_3EmFQpvW1p7A1AcNw8uqX5T79fjXRS?usp=sharing