Errors with ScipyBoundedMinimize #403

Open
Smit-create opened this issue Feb 20, 2023 · 2 comments

@Smit-create

I tried the following:

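import jax
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

# state_action_value_jax(params, data) and `model` (with a `grid` attribute) are defined elsewhere.
solver = ScipyBoundedMinimize(fun=state_action_value_jax, method="l-bfgs-b")

def T_jax(v, model):
    def update_v(carry, y):
        # Bounds are a (lower, upper) pair with the same structure as the params.
        bounds = (1e-5, y)
        sol = solver.run(y, bounds=bounds, data=(y, v, model))
        # run() returns OptStep(params, state); the objective value and success flag live on the state.
        return carry + 1, (sol.params, -sol.state.fun_val, sol.state.success)
    _, v_values = jax.lax.scan(update_v, 0, model.grid)
    return v_values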

Calling T_jax raises the following error:

/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in jnp_to_onp(x_jnp, dtype)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in <listcomp>(.0)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/1)>
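A minimal way to see the root cause, independent of jaxopt: inside the body of `jax.lax.scan` (as under `jax.jit`), each `y` is an abstract tracer rather than a concrete array, so any attempt to convert it to a NumPy array, which is what `jnp_to_onp` does before handing values to SciPy, fails. The `update` function below is a hypothetical stand-in for `update_v`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def update(carry, y):
    # `y` is a tracer while lax.scan traces this body, so NumPy cannot
    # materialize it; this mirrors onp.asarray(leaf, dtype) in jnp_to_onp.
    np.asarray(y)
    return carry + 1, y


def reproduces_error():
    try:
        jax.lax.scan(update, 0, jnp.arange(3.0))
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(reproduces_error())
```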
@Smit-create
Author

Also, this line raised an error initially:

bounds = osp.optimize.Bounds(lb=jnp_to_onp(bounds[0], self.dtype),

which I fixed locally.

The error was:

AttributeError: module 'scipy' has no attribute 'optimize'
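This `AttributeError` is what you get when a module does `import scipy as osp` and then accesses `osp.optimize` without the `scipy.optimize` submodule having been imported anywhere (some SciPy versions do not load submodules on `import scipy` alone). The issue does not say what the local fix was; one likely fix, shown here as an assumption, is to import the submodule explicitly, which binds it as an attribute of the parent package. The bound values are illustrative.

```python
import numpy as onp
import scipy as osp
import scipy.optimize  # binds the submodule as the attribute osp.optimize

# Same call shape as the failing line, with concrete NumPy arrays as bounds.
bounds = osp.optimize.Bounds(lb=onp.array([1e-5]), ub=onp.array([2.0]))
print(bounds.lb, bounds.ub)
```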

@mblondel
Collaborator

mblondel commented Mar 1, 2023

It looks like you're trying to call ScipyMinimize from a jitted function (the body of jax.lax.scan is traced the same way), which is not currently supported. When #372 is done, it should work.
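Until calling the SciPy wrappers from traced code is supported, one workaround is to keep JAX for the objective and its gradient but drive the optimizer from a plain Python loop instead of `jax.lax.scan`. The sketch below uses `scipy.optimize.minimize` with L-BFGS-B directly; `objective` is a hypothetical stand-in for `state_action_value_jax`, and the grid replaces `model.grid`. Presumably the same loop structure would also work with `solver.run`, since nothing is traced there, but that is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # let SciPy see float64 values and gradients


def objective(c, y):
    """Hypothetical stand-in for state_action_value_jax, negated because SciPy minimizes."""
    return -(jnp.log(c) + jnp.log(y - c + 1.0))


# JAX compiles the objective and its gradient; only SciPy's driver runs un-traced.
value_and_grad = jax.jit(jax.value_and_grad(objective))


def T_python(grid):
    """Solve one bounded problem per grid point in a Python loop (no lax.scan)."""
    xs, vals, ok = [], [], []
    for y in np.asarray(grid, dtype=np.float64):
        def fun(c):
            # SciPy passes a 1-element float64 array; return (value, gradient).
            val, grad = value_and_grad(c[0], y)
            return float(val), np.array([float(grad)])

        res = minimize(fun, x0=np.array([y / 2]), jac=True,
                       method="L-BFGS-B", bounds=[(1e-5, y)])
        xs.append(res.x[0])
        vals.append(-res.fun)
        ok.append(res.success)
    return np.array(xs), np.array(vals), np.array(ok)


xs, vals, ok = T_python(jnp.array([1.0, 2.0, 3.0]))
print(xs, vals, ok)
```

For this toy objective the maximizer is c = (y + 1) / 2 whenever that lies inside the bounds, so the loop should return roughly 1.0, 1.5 and 2.0. The trade-off is that the loop itself is not compiled or differentiable; only the per-point objective is.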
