Commit

Use jax.tree_util.tree_map in place of deprecated tree_multimap.
The latter is a simple alias of the former, so this change is a no-op.

PiperOrigin-RevId: 461229165
Jake VanderPlas authored and saran-t committed Jul 24, 2022
1 parent 956c4b5 commit 6fcb842
Showing 19 changed files with 42 additions and 42 deletions.
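
For context on the commit message above: jax.tree_multimap is a simple alias of jax.tree_util.tree_map, which already maps a function over one or several pytrees, so the rename is mechanical. The sketch below is illustrative only — it is not part of this commit and the names (params, grads) are hypothetical; depending on your JAX version, jax.tree_util.tree_map or jax.tree.map may be the preferred spelling.

    # Minimal sketch: jax.tree_map covers both the single-tree and the
    # multi-tree cases that jax.tree_multimap used to be called for.
    import jax
    import jax.numpy as jnp

    params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}            # hypothetical pytree
    grads = {'w': jnp.full((2, 2), 0.1), 'b': jnp.full((2,), 0.1)}  # matching pytree

    # Single pytree: the classic tree_map usage.
    halved = jax.tree_map(lambda p: 0.5 * p, params)

    # Multiple pytrees: what previously suggested tree_multimap.
    updated = jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)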
2 changes: 1 addition & 1 deletion adversarial_robustness/jax/experiment.py
@@ -373,7 +373,7 @@ def eval_epoch(self, params, state, rng):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars

2 changes: 1 addition & 1 deletion adversarial_robustness/jax/utils.py
@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
   def _weighted_average(p1, p2):
     d = decay.astype(p1.dtype)
     return (1 - d) * p1 + d * p2
-  return jax.tree_multimap(_weighted_average, new_params, avg_params)
+  return jax.tree_map(_weighted_average, new_params, avg_params)
 
 
 def cutmix(rng: chex.PRNGKey,

6 changes: 3 additions & 3 deletions byol/byol_experiment.py
@@ -323,7 +323,7 @@ def _update_fn(
 
     # cross-device grad and logs reductions
     grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
-    logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
+    logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
 
     learning_rate = schedules.learning_schedule(
         global_step,
@@ -339,7 +339,7 @@ def _update_fn(
         global_step,
         base_ema=self._base_target_ema,
         max_steps=self._max_steps)
-    target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
+    target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
                                       target_params, online_params)
     logs['tau'] = tau
     logs['learning_rate'] = learning_rate
@@ -518,7 +518,7 @@ def _eval_epoch(self, subset: Text, batch_size: int):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars

2 changes: 1 addition & 1 deletion byol/eval_experiment.py
@@ -471,7 +471,7 @@ def _eval_epoch(self, subset: Text, batch_size: int):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars

10 changes: 5 additions & 5 deletions byol/utils/optimizers.py
@@ -51,7 +51,7 @@ def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
     m = m.astype(g.dtype)
     return g * (1. - m) + t * m
 
-  return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
+  return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)
 
 
 class ScaleByLarsState(NamedTuple):
@@ -78,7 +78,7 @@ def scale_by_lars(
   """
 
   def init_fn(params: optax.Params) -> ScaleByLarsState:
-    mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
+    mu = jax.tree_map(jnp.zeros_like, params)  # momentum
     return ScaleByLarsState(mu=mu)
 
   def update_fn(updates: optax.Updates, state: ScaleByLarsState,
@@ -95,10 +95,10 @@ def lars_adaptation(
           jnp.where(update_norm > 0,
                     (eta * param_norm / update_norm), 1.0), 1.0)
 
-    adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
+    adapted_updates = jax.tree_map(lars_adaptation, updates, params)
     adapted_updates = _partial_update(updates, adapted_updates, params,
                                       filter_fn)
-    mu = jax.tree_multimap(lambda g, t: momentum * g + t,
+    mu = jax.tree_map(lambda g, t: momentum * g + t,
                            state.mu, adapted_updates)
     return mu, ScaleByLarsState(mu=mu)
 
@@ -130,7 +130,7 @@ def update_fn(
       state: AddWeightDecayState,
       params: optax.Params,
   ) -> Tuple[optax.Updates, AddWeightDecayState]:
-    new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
+    new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
                                     params)
     new_updates = _partial_update(updates, new_updates, params, filter_fn)
     return new_updates, state

2 changes: 1 addition & 1 deletion nfnets/agc_optax.py
@@ -72,7 +72,7 @@ def update_fn(updates, state, params):
     # Maximum allowable norm
     max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
     # If grad norm > clipping * param_norm, rescale
-    updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
+    updates = jax.tree_map(my_clip, g_norm, max_norm, updates)
     return updates, state
 
   return optax.GradientTransformation(init_fn, update_fn)

6 changes: 3 additions & 3 deletions nfnets/experiment.py
@@ -270,8 +270,8 @@ def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
     if ema_params is not None:
       ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
       ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_states = jax.tree_multimap(ema, ema_states, states)
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_states = jax.tree_map(ema, ema_states, states)
     return {
         'params': params,
         'states': states,
@@ -354,7 +354,7 @@ def _eval_epoch(self, params, state):
       if summed_metrics is None:
         summed_metrics = metrics
       else:
-        summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
+        summed_metrics = jax.tree_map(jnp.add, summed_metrics, metrics)
     mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
     return jax.device_get(mean_metrics)

2 changes: 1 addition & 1 deletion ogb_lsc/mag/batching_utils.py
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
 
   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)
 
   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),

4 changes: 2 additions & 2 deletions ogb_lsc/mag/experiment.py
@@ -487,8 +487,8 @@ def _update_func(
           ))
       ema_fn = (lambda x, y:  # pylint:disable=g-long-lambda
                 schedules.apply_ema_decay(x, y, ema_rate))
-      ema_params = jax.tree_multimap(ema_fn, ema_params, params)
-      ema_network_state = jax.tree_multimap(
+      ema_params = jax.tree_map(ema_fn, ema_params, params)
+      ema_network_state = jax.tree_map(
          ema_fn,
          ema_network_state,
          network_state,

2 changes: 1 addition & 1 deletion ogb_lsc/pcq/batching_utils.py
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
 
   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)
 
   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),

4 changes: 2 additions & 2 deletions ogb_lsc/pcq/experiment.py
@@ -229,8 +229,8 @@ def get_loss(*x, **graph):
     params = optax.apply_updates(params, updates)
     if ema_params is not None:
       ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_network_state = jax.tree_multimap(ema, ema_network_state,
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_network_state = jax.tree_map(ema, ema_network_state,
                                             network_state)
     return params, ema_params, network_state, ema_network_state, opt_state, scalars

2 changes: 1 addition & 1 deletion perceiver/train/experiment.py
@@ -527,7 +527,7 @@ def _eval_epoch(self, rng):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars

2 changes: 1 addition & 1 deletion perceiver/train/utils.py
@@ -192,7 +192,7 @@ def update_fn(updates, state, params):
 
     u_ex, u_in = hk.data_structures.partition(exclude, updates)
     _, p_in = hk.data_structures.partition(exclude, params)
-    u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
+    u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
     updates = hk.data_structures.merge(u_ex, u_in)
     return updates, state

4 changes: 2 additions & 2 deletions physics_inspired_models/integrators.py
@@ -204,7 +204,7 @@ def loop_body(y_: M, t_dt: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[M, M]:
   for t_and_dt_i in zip(t, dt):
     y.append(loop_body(y[-1], t_and_dt_i)[0])
   # Note that we do not return the initial point
-  return t_eval, jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
+  return t_eval, jax.tree_map(lambda *args: jnp.stack(args, axis=0),
                                    *y[1:])
 
 
@@ -252,7 +252,7 @@ def solve_ivp_dt_two_directions(
     )[1]
     yt.append(yt_fwd)
   if len(yt) > 1:
-    return jax.tree_multimap(lambda *a: jnp.concatenate(a, axis=0), *yt)
+    return jax.tree_map(lambda *a: jnp.concatenate(a, axis=0), *yt)
   else:
     return yt[0]

2 changes: 1 addition & 1 deletion physics_inspired_models/jaxline_train.py
@@ -192,7 +192,7 @@ def _jax_burnin_fn(self, params, state, rng_key, batch):
     new_state = utils.pmean_if_pmap(new_state, axis_name="i")
     new_state = hk.data_structures.to_mutable_dict(new_state)
     new_state = hk.data_structures.to_immutable_dict(new_state)
-    return jax.tree_multimap(jnp.add, new_state, state)
+    return jax.tree_map(jnp.add, new_state, state)
 
 # _
 # _____ ____ _| |

2 changes: 1 addition & 1 deletion physics_inspired_models/models/dynamics.py
@@ -829,7 +829,7 @@ def step(*args):
     if len(yt) == 1:
       yt = yt[0][:, None]
     else:
-      yt = jax.tree_multimap(lambda args: jnp.stack(args, 1), yt)
+      yt = jax.tree_map(lambda args: jnp.stack(args, 1), yt)
     if return_stats:
       return yt, dict()
     else:

8 changes: 4 additions & 4 deletions physics_inspired_models/utils.py
@@ -220,11 +220,11 @@ def add(self, averaged_values, num_samples):
       self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
       self._num_samples = num_samples
     else:
-      self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
+      self._obj_max = jax.tree_map(jnp.maximum, self._obj_max,
                                         averaged_values)
-      self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
+      self._obj_min = jax.tree_map(jnp.minimum, self._obj_min,
                                         averaged_values)
-      self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
+      self._obj = jax.tree_map(lambda x, y: x + y * num_samples, self._obj,
                                     averaged_values)
       self._num_samples += num_samples
 
@@ -249,7 +249,7 @@ def sum(self):
 
 
 def inner_product(x: Any, y: Any) -> jnp.ndarray:
-  products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
+  products = jax.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
   return sum(jax.tree_leaves(products))

2 changes: 1 addition & 1 deletion tandem_dqn/parts.py
@@ -277,7 +277,7 @@ def step(
       # 0 * NaN == NaN just replace self._statistics on the first step.
       self._statistics = dict(agent.statistics)
     else:
-      self._statistics = jax.tree_multimap(
+      self._statistics = jax.tree_map(
          lambda s, x: (1 - final_step_size) * s + final_step_size * x,
          self._statistics, agent.statistics)

20 changes: 10 additions & 10 deletions wikigraphs/updaters.py
@@ -37,7 +37,7 @@
 
 import haiku as hk
 import jax
-from jax.tree_util import tree_multimap
+from jax.tree_util import tree_map
 import numpy as np
 import optax
 
@@ -185,7 +185,7 @@ def update(self, state, data):
            'state', 'params', 'rng', 'replicated_step', 'opt_state']))
     state.update(extra_state)
     state['step'] += 1
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)
 
   def _eval(self, state, data, max_graph_size=None):
     """Evaluates the current state on the given data."""
@@ -209,7 +209,7 @@ def eval_return_state(self, state, data):
         self._eval_fn, state, [data], keys=set([
            'state', 'params', 'rng', 'replicated_step']))
     state.update(extra_state)
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)
 
   def eval(self, state, data):
     """Returns metrics without updating the model."""
@@ -230,35 +230,35 @@ def add_core_dimension(x):
       prefix = (self._num_devices, x.shape[0] // self._num_devices)
      return np.reshape(x, prefix + x.shape[1:])
 
-    multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
+    multi_inputs = tree_map(add_core_dimension, multi_inputs)
     return multi_inputs
 
   def params(self, state):
     """Returns model parameters."""
-    return tree_multimap(lambda x: x[0], state['params'])
+    return tree_map(lambda x: x[0], state['params'])
 
   def opt_state(self, state):
     """Returns the state of the optimiser."""
-    return tree_multimap(lambda x: x[0], state['opt_state'])
+    return tree_map(lambda x: x[0], state['opt_state'])
 
   def network_state(self, state):
     """Returns the model's state."""
-    return tree_multimap(lambda x: x[0], state['state'])
+    return tree_map(lambda x: x[0], state['state'])
 
   def to_checkpoint_state(self, state):
     """Transforms the updater state into a checkpointable state."""
     checkpoint_state = state.copy()
     # Wrapper around checkpoint_state['step'] so we can get [0].
     checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
     # Unstack the replicated contents.
-    checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
+    checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
     return checkpoint_state
 
   def from_checkpoint_state(self, checkpoint_state):
     """Initializes the updater state from the checkpointed state."""
     # Expand the checkpoint so we have a copy for each device.
-    state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
-                          checkpoint_state)
+    state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
+                     checkpoint_state)
     state['step'] = state['step'][0]  # Undo stacking for step.
     return state
