
Possible memory leak when calling solver.run multiple times #380

Open
alucantonio opened this issue Jan 18, 2023 · 17 comments

@alucantonio

I am trying to solve a problem where solver.run is called multiple times to minimize a series of functions while varying a parameter. Using memory_profiler I can see that the allocated memory increases each time the function solver.run is called and never decreases.

Here is a minimal example to reproduce the issue:

import jax.numpy as jnp
import jaxopt
from memory_profiler import profile

@profile
def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)

And here is the corresponding plot of the allocated memory:
[Figure_1: plot of allocated memory, increasing with each call to solver.run and never decreasing]

Can you please confirm the issue or suggest a solution? Thanks, Alessandro.

@mblondel
Collaborator

Is it specific to LBFGS or does it happen with any solver?

@alucantonio
Author

Hi, I have experienced the issue with LBFGS and GradientDescent. The increase in memory is less evident with GradientDescent, but it is still there. I believe the issue does not depend on the solver.

@mblondel
Collaborator

Can you also check if the jit and unroll options to LBFGS have any impact on this? Depending on these options, a different loop implementation is used.

Normally, solver objects don't store anything, so I'm not sure where this could come from...

CC @fabianp, @froystig

@alucantonio
Author

Setting jit=False and unroll=True, or jit=True and unroll=False, and using LBFGS still produces an increase in memory after each call to solver.run.

@fabianp
Collaborator

fabianp commented Feb 8, 2023

I don't have a solution for this, but I can confirm that it also happens with the update API, i.e., when updates are run inside a for loop:

import jax
import jax.numpy as jnp
import jaxopt

def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, implicit_diff=False, maxiter=100)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        params, state = jitted_update(params, state, min=mm)

@fabianp
Collaborator

fabianp commented Feb 14, 2023

Some updates on my investigations.

  1. Following @mblondel's suggestion, I set eq=True in the definition of LBFGS. It didn't help.
  2. I also modified the LBFGS class to remove the dataclass decorator. It didn't help.
  3. I'm inclined to think the issue is in the update method. The following code, which constructs the solver but doesn't perform the updates, doesn't have the memory leak:
import jax.numpy as jnp
import jaxopt
import jax
import gc
import time


def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj)
    state = solver.init_state(x0, min=mm)
    jitted_update = jax.jit(solver.update)
    params = x0
    for _ in range(solver.maxiter):
        pass
    #     params, state = jitted_update(params, state, min=mm)
    time.sleep(1)

for i in range(10):
    optimize(i)
    gc.collect()

However, if I uncomment the line inside the for loop (even for just one iteration), the leak comes back.

@froystig
Member

In your example, is there still a leak if the update is not jitted?

@fabianp
Collaborator

fabianp commented Feb 14, 2023

Yeah, although there's a small decrease at the end, which could mean it's recovering some memory.

This is without jitting:
[plot: memory profile without jitting]

and with jitting:
[plot: memory profile with jitting]

As you can see, it also uses a lot more memory when not jitting. I'm not sure what to make of that.

@alucantonio
Author

Thanks for the investigations. I would like to know whether this behavior can be considered a bug and whether there is any plan to fix it.

@fabianp
Collaborator

fabianp commented Feb 23, 2023 via email

@alucantonio
Author

Hi, has there been any progress on this issue?

@fabianp
Collaborator

fabianp commented May 19, 2023

This behavior can be avoided using the newly implemented jax.clear_caches() in jax (thanks @froystig!).

For example, the code below doesn't show the ever-increasing profile. Instead, it has the expected initial increase followed by a plateau:

[Figure_1: memory profile showing an initial increase followed by a plateau]

import jax.numpy as jnp
import jaxopt
import jax


def optimize(min):

    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
    jax.clear_caches()

@fabianp
Collaborator

fabianp commented May 19, 2023

I'm going to close the issue for now, but please reopen if the problem persists. (BTW, you might need the development version of jax for the clear_caches() function.)

@fabianp fabianp closed this as completed May 19, 2023
@mblondel
Collaborator

It's nice to have a workaround, but shouldn't garbage collection be able to do this automatically?

@fabianp
Collaborator

fabianp commented May 19, 2023 via email

@mblondel
Collaborator

Agreed!

@fabianp
Collaborator

fabianp commented May 22, 2023

@froystig made a good point in private conversation: this might be symptomatic of jaxopt not using the cache properly and/or generating too many fresh functions instead of reusing cached ones.

I don't have the bandwidth to look into it right now, but I'm leaving this open in case someone can look into it more deeply.
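The point about fresh functions can be sketched concretely (a minimal illustration of my own, not jaxopt code): jax.jit keys its compilation cache on the identity of the wrapped Python function, so a closure that is re-defined on every call, as in the optimize(min) examples above, is retraced each time and gets its own cache entry, and the old entries are not reclaimed until the cache is cleared.

```python
import jax
import jax.numpy as jnp

# Case 1: one jitted function, reused across calls.
# The counter increments only while JAX traces the function,
# so repeated same-shape calls hit the cache after the first trace.
reused_traces = 0

def obj(x):
    global reused_traces
    reused_traces += 1  # runs during tracing, not on cached executions
    return jnp.square(x - 1.0).sum()

jitted_obj = jax.jit(obj)
for _ in range(5):
    jitted_obj(jnp.zeros(1))

# Case 2: a fresh closure is defined (and jitted) on every call,
# mimicking the structure of the optimize(min) repro above.
# Each new function object is a new cache key, so each call retraces.
fresh_traces = 0

def optimize_fresh():
    def obj(x):
        global fresh_traces
        fresh_traces += 1
        return jnp.square(x - 1.0).sum()
    jax.jit(obj)(jnp.zeros(1))

for _ in range(5):
    optimize_fresh()

print(reused_traces)  # 1
print(fresh_traces)   # 5
```

This would be consistent with why jax.clear_caches() flattens the memory profile: it drops the cache entries accumulated for closures that are otherwise unreachable.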
