Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Native pytree support in zoom line search. #260

Merged
merged 1 commit into from
Jun 28, 2022

Conversation

mblondel
Copy link
Collaborator

No description provided.

@mblondel mblondel requested a review from junpenglao June 28, 2022 09:00

# FIXME: directly accept a value_and_grad function to avoid recompilations.
phi, g = jax.value_and_grad(f)(xkp1)
dphi = tree_real(tree_vdot(g, pk))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How often do we have this pattern? I am not sure how expensive is the additional tree map, maybe it is better to a single function that we map to the input?

real_vdot = lambda x, y: jnp.real(_vdot_safe(x, y))
tree_real_vdot = lambda x, y: tree_map(real_vdot, tree)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I wonder if JAX's jit is able to optimize this automatically. If not, happy to add the utility you suggest. BTW, we don't have good support for complex numbers in JAXopt, see #169.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I actually just realized that tree_real isn't necessary, just jnp.real will be enough, since tree_vdot returns a scalar.

@copybara-service copybara-service bot merged commit ed48fe9 into google:main Jun 28, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants