diff --git a/adversarial_robustness/jax/experiment.py b/adversarial_robustness/jax/experiment.py
index 3b87faf7..d85eb31d 100644
--- a/adversarial_robustness/jax/experiment.py
+++ b/adversarial_robustness/jax/experiment.py
@@ -373,7 +373,7 @@ def eval_epoch(self, params, state, rng):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
diff --git a/adversarial_robustness/jax/utils.py b/adversarial_robustness/jax/utils.py
index 00728b0c..2545d5f0 100644
--- a/adversarial_robustness/jax/utils.py
+++ b/adversarial_robustness/jax/utils.py
@@ -124,7 +124,7 @@ def ema_update(step: chex.Array,
   def _weighted_average(p1, p2):
     d = decay.astype(p1.dtype)
     return (1 - d) * p1 + d * p2
-  return jax.tree_multimap(_weighted_average, new_params, avg_params)
+  return jax.tree_map(_weighted_average, new_params, avg_params)
 
 
 def cutmix(rng: chex.PRNGKey,
diff --git a/byol/byol_experiment.py b/byol/byol_experiment.py
index d6b76aed..4494e354 100644
--- a/byol/byol_experiment.py
+++ b/byol/byol_experiment.py
@@ -323,7 +323,7 @@ def _update_fn(
 
     # cross-device grad and logs reductions
     grads = jax.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
-    logs = jax.tree_multimap(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
+    logs = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)
 
     learning_rate = schedules.learning_schedule(
         global_step,
@@ -339,7 +339,7 @@
         global_step,
         base_ema=self._base_target_ema,
         max_steps=self._max_steps)
-    target_params = jax.tree_multimap(lambda x, y: x + (1 - tau) * (y - x),
+    target_params = jax.tree_map(lambda x, y: x + (1 - tau) * (y - x),
                                       target_params, online_params)
     logs['tau'] = tau
     logs['learning_rate'] = learning_rate
@@ -518,7 +518,7 @@ def _eval_epoch(self, subset: Text, batch_size: int):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
diff --git a/byol/eval_experiment.py b/byol/eval_experiment.py
index 83b4fbf7..92d54ff4 100644
--- a/byol/eval_experiment.py
+++ b/byol/eval_experiment.py
@@ -471,7 +471,7 @@ def _eval_epoch(self, subset: Text, batch_size: int):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
diff --git a/byol/utils/optimizers.py b/byol/utils/optimizers.py
index f80423c7..caade06a 100644
--- a/byol/utils/optimizers.py
+++ b/byol/utils/optimizers.py
@@ -51,7 +51,7 @@ def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
     m = m.astype(g.dtype)
     return g * (1. - m) + t * m
 
-  return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)
+  return jax.tree_map(_update_fn, updates, new_updates, params_to_filter)
 
 
 class ScaleByLarsState(NamedTuple):
@@ -78,7 +78,7 @@ def scale_by_lars(
   """
 
   def init_fn(params: optax.Params) -> ScaleByLarsState:
-    mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
+    mu = jax.tree_map(jnp.zeros_like, params)  # momentum
     return ScaleByLarsState(mu=mu)
 
   def update_fn(updates: optax.Updates, state: ScaleByLarsState,
@@ -95,10 +95,10 @@ def lars_adaptation(
           jnp.where(update_norm > 0,
                     (eta * param_norm / update_norm), 1.0), 1.0)
 
-    adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
+    adapted_updates = jax.tree_map(lars_adaptation, updates, params)
     adapted_updates = _partial_update(updates, adapted_updates,
                                       params, filter_fn)
-    mu = jax.tree_multimap(lambda g, t: momentum * g + t,
+    mu = jax.tree_map(lambda g, t: momentum * g + t,
                            state.mu, adapted_updates)
     return mu, ScaleByLarsState(mu=mu)
 
@@ -130,7 +130,7 @@ def update_fn(
       state: AddWeightDecayState,
       params: optax.Params,
   ) -> Tuple[optax.Updates, AddWeightDecayState]:
-    new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
+    new_updates = jax.tree_map(lambda g, p: g + weight_decay * p, updates,
                                     params)
     new_updates = _partial_update(updates, new_updates, params, filter_fn)
     return new_updates, state
diff --git a/nfnets/agc_optax.py b/nfnets/agc_optax.py
index 2b435c02..00ecf41b 100644
--- a/nfnets/agc_optax.py
+++ b/nfnets/agc_optax.py
@@ -72,7 +72,7 @@ def update_fn(updates, state, params):
     # Maximum allowable norm
    max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
     # If grad norm > clipping * param_norm, rescale
-    updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
+    updates = jax.tree_map(my_clip, g_norm, max_norm, updates)
     return updates, state
 
   return optax.GradientTransformation(init_fn, update_fn)
diff --git a/nfnets/experiment.py b/nfnets/experiment.py
index 0cf137e8..b2984256 100644
--- a/nfnets/experiment.py
+++ b/nfnets/experiment.py
@@ -270,8 +270,8 @@ def _train_fn(self, params, states, opt_states, inputs, rng, global_step,
     if ema_params is not None:
       ema_fn = getattr(utils, self.config.get('which_ema', 'tf1_ema'))
       ema = lambda x, y: ema_fn(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_states = jax.tree_multimap(ema, ema_states, states)
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_states = jax.tree_map(ema, ema_states, states)
     return {
         'params': params,
         'states': states,
@@ -354,7 +354,7 @@ def _eval_epoch(self, params, state):
       if summed_metrics is None:
         summed_metrics = metrics
       else:
-        summed_metrics = jax.tree_multimap(jnp.add, summed_metrics, metrics)
+        summed_metrics = jax.tree_map(jnp.add, summed_metrics, metrics)
 
     mean_metrics = jax.tree_map(lambda x: x / num_samples, summed_metrics)
     return jax.device_get(mean_metrics)
diff --git a/ogb_lsc/mag/batching_utils.py b/ogb_lsc/mag/batching_utils.py
index 88392055..b3da1fc5 100644
--- a/ogb_lsc/mag/batching_utils.py
+++ b/ogb_lsc/mag/batching_utils.py
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
 
   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)
 
   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),
diff --git a/ogb_lsc/mag/experiment.py b/ogb_lsc/mag/experiment.py
index 7f9e6102..02f86545 100644
--- a/ogb_lsc/mag/experiment.py
+++ b/ogb_lsc/mag/experiment.py
@@ -487,8 +487,8 @@ def _update_func(
           ))
       ema_fn = (lambda x, y:  # pylint:disable=g-long-lambda
                 schedules.apply_ema_decay(x, y, ema_rate))
-      ema_params = jax.tree_multimap(ema_fn, ema_params, params)
-      ema_network_state = jax.tree_multimap(
+      ema_params = jax.tree_map(ema_fn, ema_params, params)
+      ema_network_state = jax.tree_map(
          ema_fn,
          ema_network_state,
          network_state,
diff --git a/ogb_lsc/pcq/batching_utils.py b/ogb_lsc/pcq/batching_utils.py
index b06dcbda..a7e316bf 100644
--- a/ogb_lsc/pcq/batching_utils.py
+++ b/ogb_lsc/pcq/batching_utils.py
@@ -107,7 +107,7 @@ def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
 
   def _map_concat(nests):
     concat = lambda *args: np.concatenate(args)
-    return tree.tree_multimap(concat, *nests)
+    return tree.tree_map(concat, *nests)
 
   return jraph.GraphsTuple(
       n_node=np.concatenate([g.n_node for g in graphs]),
diff --git a/ogb_lsc/pcq/experiment.py b/ogb_lsc/pcq/experiment.py
index 5e66a868..c9305617 100644
--- a/ogb_lsc/pcq/experiment.py
+++ b/ogb_lsc/pcq/experiment.py
@@ -229,8 +229,8 @@ def get_loss(*x, **graph):
     params = optax.apply_updates(params, updates)
     if ema_params is not None:
       ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
-      ema_params = jax.tree_multimap(ema, ema_params, params)
-      ema_network_state = jax.tree_multimap(ema, ema_network_state,
+      ema_params = jax.tree_map(ema, ema_params, params)
+      ema_network_state = jax.tree_map(ema, ema_network_state,
                                             network_state)
 
     return params, ema_params, network_state, ema_network_state, opt_state, scalars
diff --git a/perceiver/train/experiment.py b/perceiver/train/experiment.py
index 591a3bb1..ce06787f 100644
--- a/perceiver/train/experiment.py
+++ b/perceiver/train/experiment.py
@@ -527,7 +527,7 @@ def _eval_epoch(self, rng):
       if summed_scalars is None:
         summed_scalars = scalars
       else:
-        summed_scalars = jax.tree_multimap(jnp.add, summed_scalars, scalars)
+        summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)
 
     mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
     return mean_scalars
diff --git a/perceiver/train/utils.py b/perceiver/train/utils.py
index 9935101d..84c9cef2 100644
--- a/perceiver/train/utils.py
+++ b/perceiver/train/utils.py
@@ -192,7 +192,7 @@ def update_fn(updates, state, params):
 
     u_ex, u_in = hk.data_structures.partition(exclude, updates)
     _, p_in = hk.data_structures.partition(exclude, params)
-    u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
+    u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
     updates = hk.data_structures.merge(u_ex, u_in)
     return updates, state
diff --git a/physics_inspired_models/integrators.py b/physics_inspired_models/integrators.py
index c41719f3..ef1a9e5f 100644
--- a/physics_inspired_models/integrators.py
+++ b/physics_inspired_models/integrators.py
@@ -204,7 +204,7 @@ def loop_body(y_: M, t_dt: Tuple[jnp.ndarray, jnp.ndarray]) -> Tuple[M, M]:
   for t_and_dt_i in zip(t, dt):
     y.append(loop_body(y[-1], t_and_dt_i)[0])
   # Note that we do not return the initial point
-  return t_eval, jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
+  return t_eval, jax.tree_map(lambda *args: jnp.stack(args, axis=0),
                                    *y[1:])
 
 
@@ -252,7 +252,7 @@ def solve_ivp_dt_two_directions(
     )[1]
     yt.append(yt_fwd)
   if len(yt) > 1:
-    return jax.tree_multimap(lambda *a: jnp.concatenate(a, axis=0), *yt)
+    return jax.tree_map(lambda *a: jnp.concatenate(a, axis=0), *yt)
   else:
     return yt[0]
 
diff --git a/physics_inspired_models/jaxline_train.py b/physics_inspired_models/jaxline_train.py
index 3b736ee8..4616ea1c 100644
--- a/physics_inspired_models/jaxline_train.py
+++ b/physics_inspired_models/jaxline_train.py
@@ -192,7 +192,7 @@ def _jax_burnin_fn(self, params, state, rng_key, batch):
     new_state = utils.pmean_if_pmap(new_state, axis_name="i")
     new_state = hk.data_structures.to_mutable_dict(new_state)
     new_state = hk.data_structures.to_immutable_dict(new_state)
-    return jax.tree_multimap(jnp.add, new_state, state)
+    return jax.tree_map(jnp.add, new_state, state)
 
 #                 _
 #  _____ ____ _| |
diff --git a/physics_inspired_models/models/dynamics.py b/physics_inspired_models/models/dynamics.py
index 25e22b4f..d5e08ca3 100644
--- a/physics_inspired_models/models/dynamics.py
+++ b/physics_inspired_models/models/dynamics.py
@@ -829,7 +829,7 @@ def step(*args):
     if len(yt) == 1:
       yt = yt[0][:, None]
     else:
-      yt = jax.tree_multimap(lambda args: jnp.stack(args, 1), yt)
+      yt = jax.tree_map(lambda args: jnp.stack(args, 1), yt)
     if return_stats:
       return yt, dict()
     else:
diff --git a/physics_inspired_models/utils.py b/physics_inspired_models/utils.py
index bde4af49..7d2aad11 100644
--- a/physics_inspired_models/utils.py
+++ b/physics_inspired_models/utils.py
@@ -220,11 +220,11 @@ def add(self, averaged_values, num_samples):
       self._obj = jax.tree_map(lambda y: y * num_samples, averaged_values)
       self._num_samples = num_samples
     else:
-      self._obj_max = jax.tree_multimap(jnp.maximum, self._obj_max,
+      self._obj_max = jax.tree_map(jnp.maximum, self._obj_max,
                                         averaged_values)
-      self._obj_min = jax.tree_multimap(jnp.minimum, self._obj_min,
+      self._obj_min = jax.tree_map(jnp.minimum, self._obj_min,
                                         averaged_values)
-      self._obj = jax.tree_multimap(lambda x, y: x + y * num_samples, self._obj,
+      self._obj = jax.tree_map(lambda x, y: x + y * num_samples, self._obj,
                                     averaged_values)
       self._num_samples += num_samples
 
@@ -249,7 +249,7 @@ def sum(self):
 
 
 def inner_product(x: Any, y: Any) -> jnp.ndarray:
-  products = jax.tree_multimap(lambda x_, y_: jnp.sum(x_ * y_), x, y)
+  products = jax.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
   return sum(jax.tree_leaves(products))
 
 
diff --git a/tandem_dqn/parts.py b/tandem_dqn/parts.py
index 57d29e1d..9038956d 100644
--- a/tandem_dqn/parts.py
+++ b/tandem_dqn/parts.py
@@ -277,7 +277,7 @@ def step(
       # 0 * NaN == NaN just replace self._statistics on the first step.
       self._statistics = dict(agent.statistics)
     else:
-      self._statistics = jax.tree_multimap(
+      self._statistics = jax.tree_map(
           lambda s, x: (1 - final_step_size) * s + final_step_size * x,
          self._statistics, agent.statistics)
diff --git a/wikigraphs/updaters.py b/wikigraphs/updaters.py
index 8bb258fe..4e64e6a5 100644
--- a/wikigraphs/updaters.py
+++ b/wikigraphs/updaters.py
@@ -37,7 +37,7 @@
 import haiku as hk
 import jax
-from jax.tree_util import tree_multimap
+from jax.tree_util import tree_map
 import numpy as np
 import optax
 
@@ -185,7 +185,7 @@ def update(self, state, data):
         'state', 'params', 'rng', 'replicated_step', 'opt_state']))
     state.update(extra_state)
     state['step'] += 1
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)
 
   def _eval(self, state, data, max_graph_size=None):
     """Evaluates the current state on the given data."""
@@ -209,7 +209,7 @@ def eval_return_state(self, state, data):
         self._eval_fn, state, [data], keys=set([
            'state', 'params', 'rng', 'replicated_step']))
     state.update(extra_state)
-    return state, tree_multimap(lambda x: x[0], out)
+    return state, tree_map(lambda x: x[0], out)
 
   def eval(self, state, data):
     """Returns metrics without updating the model."""
@@ -230,20 +230,20 @@ def add_core_dimension(x):
       prefix = (self._num_devices, x.shape[0] // self._num_devices)
       return np.reshape(x, prefix + x.shape[1:])
 
-    multi_inputs = tree_multimap(add_core_dimension, multi_inputs)
+    multi_inputs = tree_map(add_core_dimension, multi_inputs)
     return multi_inputs
 
   def params(self, state):
     """Returns model parameters."""
-    return tree_multimap(lambda x: x[0], state['params'])
+    return tree_map(lambda x: x[0], state['params'])
 
   def opt_state(self, state):
     """Returns the state of the optimiser."""
-    return tree_multimap(lambda x: x[0], state['opt_state'])
+    return tree_map(lambda x: x[0], state['opt_state'])
 
   def network_state(self, state):
     """Returns the model's state."""
-    return tree_multimap(lambda x: x[0], state['state'])
+    return tree_map(lambda x: x[0], state['state'])
 
   def to_checkpoint_state(self, state):
     """Transforms the updater state into a checkpointable state."""
@@ -251,14 +251,14 @@ def to_checkpoint_state(self, state):
     # Wrapper around checkpoint_state['step'] so we can get [0].
     checkpoint_state['step'] = checkpoint_state['step'][np.newaxis]
     # Unstack the replicated contents.
-    checkpoint_state = tree_multimap(lambda x: x[0], checkpoint_state)
+    checkpoint_state = tree_map(lambda x: x[0], checkpoint_state)
     return checkpoint_state
 
   def from_checkpoint_state(self, checkpoint_state):
     """Initializes the updater state from the checkpointed state."""
     # Expand the checkpoint so we have a copy for each device.
-    state = tree_multimap(lambda x: np.stack(jax.local_device_count() * [x]),
-                          checkpoint_state)
+    state = tree_map(lambda x: np.stack(jax.local_device_count() * [x]),
+                     checkpoint_state)
     state['step'] = state['step'][0]  # Undo stacking for step.
     return state
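Note: every hunk above is a mechanical rename. jax.tree_multimap was deprecated in newer JAX releases in favour of jax.tree_map, which accepts the same multi-tree signature, so behaviour is unchanged. A minimal sketch of the equivalence (the example trees and values here are illustrative, not taken from the patched code):

    import jax
    import jax.numpy as jnp

    # Two pytrees with identical structure (made-up example values).
    a = {'w': jnp.ones(3), 'b': jnp.zeros(3)}
    b = {'w': jnp.full(3, 2.0), 'b': jnp.full(3, 5.0)}

    # jax.tree_map maps a function over one or more trees leaf-wise,
    # matching the old jax.tree_multimap call signature, so the rename
    # in the diff is a drop-in replacement.
    summed = jax.tree_map(jnp.add, a, b)
    # summed == {'b': Array([5., 5., 5.]), 'w': Array([3., 3., 3.])}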