Skip to content

Commit

Permalink
Use jax.tree_util.tree_map in place of deprecated tree_multimap.
Browse files Browse the repository at this point in the history
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461045645
  • Loading branch information
Jake VanderPlas authored and saran-t committed Jul 24, 2022
1 parent 11c2ab5 commit 956c4b5
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions kfac_ferminet_alpha/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ def multiply_matpower(
raise ValueError("Neither `has_scale` nor `has_shift`.")
factors = jax.tree_map(lambda x: x + diagonal_weight, factors)
if exp == 1:
return jax.tree_multimap(jnp.multiply, vec, factors)
return jax.tree_map(jnp.multiply, vec, factors)
elif exp == -1:
return jax.tree_multimap(jnp.divide, vec, factors)
return jax.tree_map(jnp.divide, vec, factors)
else:
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion kfac_ferminet_alpha/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def vec_block_apply(
"""Executes func for each approximation block on vectors."""
per_block_vectors = self.vectors_to_blocks(parameter_structured_vector)
assert len(per_block_vectors) == len(self.blocks)
results = jax.tree_multimap(func, tuple(self.blocks.values()),
results = jax.tree_map(func, tuple(self.blocks.values()),
per_block_vectors)
parameter_structured_result = self.blocks_to_vectors(results)
utils.check_structure_shapes_and_dtype(parameter_structured_vector,
Expand Down
4 changes: 2 additions & 2 deletions kfac_ferminet_alpha/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def _step(
)

# Update parameters: params = params + delta
params = jax.tree_multimap(jnp.add, params, delta)
params = jax.tree_map(jnp.add, params, delta)

# Optionally compute the reduction ratio and update the damping
self.estimator.damping = None
Expand Down Expand Up @@ -607,5 +607,5 @@ def velocities_and_delta(
assert len(vectors) == len(coefficients)
delta = utils.scalar_mul(vectors[0], coefficients[0])
for vi, wi in zip(vectors[1:], coefficients[1:]):
delta = jax.tree_multimap(jnp.add, delta, utils.scalar_mul(vi, wi))
delta = jax.tree_map(jnp.add, delta, utils.scalar_mul(vi, wi))
return delta, delta
2 changes: 1 addition & 1 deletion kfac_ferminet_alpha/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def extract_func_outputs(
def inner_product(obj1: T, obj2: T) -> jnp.ndarray:
if jax.tree_structure(obj1) != jax.tree_structure(obj2):
raise ValueError("The two structures are not identical.")
elements_product = jax.tree_multimap(lambda x, y: jnp.sum(x * y), obj1, obj2)
elements_product = jax.tree_map(lambda x, y: jnp.sum(x * y), obj1, obj2)
return sum(jax.tree_flatten(elements_product)[0])


Expand Down

0 comments on commit 956c4b5

Please sign in to comment.