
CustomVJPException plus memory leak when using a for loop instead of a scan. #111

Open
MaxiLechner opened this issue Dec 3, 2021 · 3 comments

I've been using jaxopt.implicit_diff.custom_root to differentiate through a jax-md energy minimization routine, and I've noticed that if I use a Python for loop for my solver, I get a CustomVJPException as well as a memory leak.

The memory leak only shows up when the CustomVJPException is raised, not when I modify my code to prevent the exception. I believe the underlying cause of the exception is the same as in issue #31 and stems from how jax-md defines its energy functions.

I'd like to know how to change that part of jax-md to prevent the CustomVJPException from happening in the first place, but I haven't managed to come up with a simplified version that would let me pinpoint the source of the error. I can give it another shot if that would help.

Here's a colab notebook that demonstrates the issue:
https://colab.research.google.com/drive/1f_3EmFQpvW1p7A1AcNw8uqX5T79fjXRS?usp=sharing
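
For reference, here is a minimal, self-contained sketch of how a CustomVJPException of this kind typically arises. This is generic code, not my jax-md setup (all names are made up for illustration): a custom_vjp-decorated function closes over a value that we then try to differentiate with respect to.

import jax

def make_f(theta):
    # The custom_vjp function closes over theta instead of taking it
    # as an explicit argument -- this is what triggers the exception.
    @jax.custom_vjp
    def f(x):
        return theta * x

    f.defvjp(lambda x: (theta * x, x),     # fwd: primal output + residual
             lambda res, g: (theta * g,))  # bwd: cotangent for x only
    return f

def loss(theta):
    f = make_f(theta)  # under jax.grad, theta is a tracer here
    return f(2.0)

jax.grad(loss)(3.0)  # raises CustomVJPException: differentiation of a
                     # custom_vjp function w.r.t. a closed-over value

In the jax-md case, the role of theta is played by the energy parameters baked into the energy/force functions via closures.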


mblondel commented Dec 7, 2021

@shoyer may have an idea.


shoyer commented Dec 7, 2021

Yes, this looks like the same issue as #31.

You definitely need the dependency on params inside the optimality_fun passed to jaxopt.implicit_diff.custom_root. Otherwise jaxopt is not going to calculate the gradients correctly.
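
A sketch of the implicit-function-theorem argument behind this (writing F for optimality_fun, \theta for params, and x^*(\theta) for the solution returned by the solver):

F(x^*(\theta),\, \theta) = 0
\quad\Longrightarrow\quad
\frac{dx^*}{d\theta} = -\left(\frac{\partial F}{\partial x}\right)^{-1} \frac{\partial F}{\partial \theta}

If F has no explicit dependence on \theta, the \partial F / \partial \theta factor is identically zero, so the params gradient silently comes out as zero rather than the true sensitivity.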

I think something like this should work, but it results in a different strange error:

from jax import jit
import jax.numpy as jnp
from jax_md import energy, quantity
from jaxopt.implicit_diff import custom_root

# displacement, shift and run_minimization_scan are defined in the colab linked above.
def implicit_diff_3(params, R_init, box_size, use_for_loop=True):
    energy_fn = energy.soft_sphere_pair(displacement, **params)
    force_fn = jit(quantity.force(energy_fn))

    def optimality_fun(sol, params):
        # Rebuild the energy/force functions from params so that the
        # optimality condition depends on params explicitly.
        energy_fn = energy.soft_sphere_pair(displacement, **params)
        force_fn = jit(quantity.force(energy_fn))
        return force_fn(sol)

    def solver(params, x):
        del params  # the solver only sees params via the closed-over force_fn
        return run_minimization_scan(force_fn, x, shift, use_for_loop,
                                     num_steps=19400)

    decorated_solver = custom_root(optimality_fun)(solver)
    R_final = decorated_solver(params, R_init)

    return (energy_fn(R_final, **params),
            jnp.amax(jnp.abs(force_fn(R_final, **params))))

NotImplementedError: Differentiation rule for 'custom_lin' not implemented


mblondel commented Dec 8, 2021

I'm a bit confused by the variable names above. In JAXopt, we usually use optimality_fun(params, hyperparams) and solver_fun(init_params, hyperparams), where params is what is optimized and hyperparams is what is differentiated. See e.g. this example.
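
For illustration, a minimal sketch of that convention on a toy root-finding problem (nothing jax-md specific; the names follow the convention above):

import jax
from jaxopt.implicit_diff import custom_root

def optimality_fun(params, hyperparams):
    # At the solution, params**2 - hyperparams == 0, i.e. params = sqrt(hyperparams).
    return params ** 2 - hyperparams

@custom_root(optimality_fun)
def solver_fun(init_params, hyperparams):
    # Plain Newton iterations; implicit diff never differentiates through
    # this loop, only through optimality_fun at the solution.
    x = init_params
    for _ in range(20):
        x = 0.5 * (x + hyperparams / x)
    return x

# Differentiate the solution w.r.t. hyperparams (the second argument):
print(jax.grad(solver_fun, argnums=1)(1.0, 2.0))  # ~ 1/(2*sqrt(2)) ~ 0.354

The first argument of the decorated solver is the initialization and is not differentiated; gradients flow with respect to the remaining arguments via the optimality condition.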
