Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 635420186
Change-Id: Ie71a2deb905622b947a9b075ce55bcb1bff46462
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 20, 2024
1 parent 8d4cc04 commit bea6d6b
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 46 deletions.
5 changes: 3 additions & 2 deletions acme/agents/jax/ars/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def apply(
normalized_obs = normalization_apply_fn(obs, normalization_params)
action = policy_network.apply(policy_params, normalized_obs)
return action, {
'params_key':
jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), params_key)
'params_key': jax.tree.map(
lambda x: jnp.expand_dims(x, axis=0), params_key
)
}

return apply
Expand Down
2 changes: 1 addition & 1 deletion acme/agents/jax/bc/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def get_variables(self, names: List[str]) -> List[networks_lib.Params]:

def save(self) -> TrainingState:
# Serialize only the first replica of parameters and optimizer state.
return jax.tree_map(utils.get_from_first_device, self._state)
return jax.tree.map(utils.get_from_first_device, self._state)

def restore(self, state: TrainingState):
self._state = utils.replicate_in_all_devices(state)
8 changes: 5 additions & 3 deletions acme/agents/jax/cql/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,11 @@ def update_step(
critic_grads, state.critic_optimizer_state)
critic_params = optax.apply_updates(state.critic_params, critic_update)

new_target_critic_params = jax.tree_map(
lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params,
critic_params)
new_target_critic_params = jax.tree.map(
lambda x, y: x * (1 - tau) + y * tau,
state.target_critic_params,
critic_params,
)

metrics = {
'critic_loss': critic_loss,
Expand Down
4 changes: 2 additions & 2 deletions acme/agents/jax/mbop/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ def get_normalization_stats(
"""
# Set up normalization:
example = next(iterator)
unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example)
unbatched_single_example = jax.tree.map(lambda x: x[0, PREVIOUS, :], example)
mean_std = running_statistics.init_state(unbatched_single_example)

for batch in itertools.islice(iterator, num_normalization_batches - 1):
example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch)
example = jax.tree.map(lambda x: x[:, PREVIOUS, :], batch)
mean_std = running_statistics.update(mean_std, example)

return mean_std
18 changes: 10 additions & 8 deletions acme/agents/jax/mbop/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,17 @@ def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any],
num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]

# Reshape args and kwargs for the round-robin:
args = jax.tree_map(
functools.partial(_split_batch_dimension, num_networks), args)
kwargs = jax.tree_map(
functools.partial(_split_batch_dimension, num_networks), kwargs)
args = jax.tree.map(
functools.partial(_split_batch_dimension, num_networks), args
)
kwargs = jax.tree.map(
functools.partial(_split_batch_dimension, num_networks), kwargs
)
# `out.shape` is `(num_networks, initial_batch_size/num_networks, ...)
out = jax.vmap(base_apply)(params, *args, **kwargs)
# Reshape to [initial_batch_size, <remaining dimensions>]. Using the 'F' order
# forces the original values to the last dimension.
return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out)
return jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out)


def apply_all(base_apply: Callable[[networks.Params, Any], Any],
Expand All @@ -133,8 +135,8 @@ def apply_all(base_apply: Callable[[networks.Params, Any], Any],
# `num_networks` is the size of the batch dimension in `params`.
num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]

args = jax.tree_map(functools.partial(_repeat_n, num_networks), args)
kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs)
args = jax.tree.map(functools.partial(_repeat_n, num_networks), args)
kwargs = jax.tree.map(functools.partial(_repeat_n, num_networks), kwargs)
# `out` is of shape `(num_networks, batch_size, <remaining dimensions>)`.
return jax.vmap(base_apply)(params, *args, **kwargs)

Expand All @@ -155,7 +157,7 @@ def apply_mean(base_apply: Callable[[networks.Params, Any], Any],
Output shape will be [batch_size, <network output_dims>]
"""
out = apply_all(base_apply, params, *args, **kwargs)
return jax.tree_map(functools.partial(jnp.mean, axis=0), out)
return jax.tree.map(functools.partial(jnp.mean, axis=0), out)


def make_ensemble(base_network: networks.FeedForwardNetwork,
Expand Down
14 changes: 8 additions & 6 deletions acme/agents/jax/mbop/ensemble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork:
"""Like params_adding_ffn, but with pytree inputs, preserves structure."""

def init_fn(key, sx=sx):
return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)
return jax.tree.map(lambda x: jax.random.uniform(key, x.shape), sx)

def apply_fn(params, x):
return jax.tree_map(lambda p, v: p + v, params, x)
return jax.tree.map(lambda p, v: p + v, params, x)

return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn)

Expand Down Expand Up @@ -291,9 +291,10 @@ def test_round_robin_random(self):
for i in range(9):
np.testing.assert_allclose(
out[i],
ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]),
atol=1E-5,
rtol=1E-5)
ffn.apply(jax.tree.map(lambda p, i=i: p[i % 3], params), bx[i]),
atol=1e-5,
rtol=1e-5,
)

def test_mean_random(self):
x = jnp.ones(10)
Expand All @@ -318,7 +319,8 @@ def test_mean_random(self):
# Check results explicitly:
all_members = jnp.concatenate([
jnp.expand_dims(
ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0)
ffn.apply(jax.tree.map(lambda p, i=i: p[i], params), bx), axis=0
)
for i in range(3)
])
batch_means = jnp.mean(all_members, axis=0)
Expand Down
20 changes: 12 additions & 8 deletions acme/agents/jax/mbop/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action],
Returns:
A scalar loss value as jnp.ndarray.
"""
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
steps.observation)
observation_t = jax.tree.map(
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
)
action_t = steps.action[:, dataset.CURRENT, ...]
observation_tp1 = jax.tree_map(lambda obs: obs[:, dataset.NEXT, ...],
steps.observation)
observation_tp1 = jax.tree.map(
lambda obs: obs[:, dataset.NEXT, ...], steps.observation
)
reward_t = steps.reward[:, dataset.CURRENT, ...]
(predicted_observation_tp1,
predicted_reward_t) = apply_fn(observation_t, action_t)
Expand Down Expand Up @@ -86,8 +88,9 @@ def policy_prior_loss(
Returns:
A scalar loss value as jnp.ndarray.
"""
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
steps.observation)
observation_t = jax.tree.map(
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
)
action_tm1 = steps.action[:, dataset.PREVIOUS, ...]
action_t = steps.action[:, dataset.CURRENT, ...]

Expand All @@ -109,8 +112,9 @@ def return_loss(apply_fn: Callable[[networks.Observation, networks.Action],
Returns:
A scalar loss value as jnp.ndarray.
"""
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
steps.observation)
observation_t = jax.tree.map(
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
)
action_t = steps.action[:, dataset.CURRENT, ...]
n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...]

Expand Down
5 changes: 3 additions & 2 deletions acme/agents/jax/mbop/mppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,9 @@ def mppi_planner(
policy_prior_state = policy_prior.init(random_key)

# Broadcast so that we have n_trajectories copies of each:
observation_t = jax.tree_map(
functools.partial(_repeat_n, config.n_trajectories), observation)
observation_t = jax.tree.map(
functools.partial(_repeat_n, config.n_trajectories), observation
)
action_tm1 = jnp.broadcast_to(action_trajectory_tm1[0],
(config.n_trajectories,) +
action_trajectory_tm1[0].shape)
Expand Down
4 changes: 2 additions & 2 deletions acme/agents/jax/mpo/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def _sgd_step(
dual_params.log_penalty_temperature)
elif isinstance(dual_params, discrete_losses.CategoricalMPOParams):
dual_metrics['params/dual/log_alpha_avg'] = dual_params.log_alpha
metrics.update(jax.tree_map(jnp.mean, dual_metrics))
metrics.update(jax.tree.map(jnp.mean, dual_metrics))

return new_state, metrics

Expand Down Expand Up @@ -733,7 +733,7 @@ def get_variables(self, names: List[str]) -> network_lib.Params:
return [variables[name] for name in names]

def save(self) -> TrainingState:
return jax.tree_map(mpo_utils.get_from_first_device, self._state)
return jax.tree.map(mpo_utils.get_from_first_device, self._state)

def restore(self, state: TrainingState):
self._state = utils.replicate_in_all_devices(state, self._local_devices)
Expand Down
2 changes: 1 addition & 1 deletion acme/agents/jax/mpo/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def critic_fn(observation: types.NestedArray,
def add_batch(nest, batch_size: Optional[int]):
"""Adds a batch dimension at axis 0 to the leaves of a nested structure."""
broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
return jax.tree_map(broadcast, nest)
return jax.tree.map(broadcast, nest)


def w_init_identity(shape: Sequence[int], dtype) -> jnp.ndarray:
Expand Down
13 changes: 7 additions & 6 deletions acme/agents/jax/mpo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _slice_and_maybe_to_numpy(x):
x = x[0]
return _fetch_devicearray(x) if as_numpy else x

return jax.tree_map(_slice_and_maybe_to_numpy, nest)
return jax.tree.map(_slice_and_maybe_to_numpy, nest)


def rolling_window(x: jnp.ndarray,
Expand Down Expand Up @@ -80,11 +80,11 @@ def tree_map_distribution(
if isinstance(x, distrax.Distribution):
safe_f = lambda y: f(y) if isinstance(y, jnp.ndarray) else y
nil, tree_data = x.tree_flatten()
new_tree_data = jax.tree_map(safe_f, tree_data)
new_tree_data = jax.tree.map(safe_f, tree_data)
new_x = x.tree_unflatten(new_tree_data, nil)
return new_x
elif isinstance(x, tfd.Distribution):
return jax.tree_map(f, x)
return jax.tree.map(f, x)
else:
return f(x)

Expand All @@ -95,8 +95,9 @@ def make_sequences_from_transitions(
"""Convert a batch of transitions into a batch of 1-step sequences."""
stack = lambda x, y: jnp.stack((x, y), axis=num_batch_dims)
duplicate = lambda x: stack(x, x)
observation = jax.tree_map(stack, transitions.observation,
transitions.next_observation)
observation = jax.tree.map(
stack, transitions.observation, transitions.next_observation
)
reward = duplicate(transitions.reward)

return adders.Step( # pytype: disable=wrong-arg-types # jnp-type
Expand All @@ -105,5 +106,5 @@ def make_sequences_from_transitions(
reward=reward,
discount=duplicate(transitions.discount),
start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_),
extras=jax.tree_map(duplicate, transitions.extras),
extras=jax.tree.map(duplicate, transitions.extras),
)
4 changes: 2 additions & 2 deletions acme/agents/jax/r2d2/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ def loss(

# Maybe burn the core state in.
if burn_in_length:
burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
burn_obs = jax.tree.map(lambda x: x[:burn_in_length], data.observation)
key_grad, key1, key2 = jax.random.split(key_grad, 3)
_, online_state = networks.unroll(params, key1, burn_obs, online_state)
_, target_state = networks.unroll(target_params, key2, burn_obs,
target_state)

# Only get data to learn on from after the end of the burn in period.
data = jax.tree_map(lambda seq: seq[burn_in_length:], data)
data = jax.tree.map(lambda seq: seq[burn_in_length:], data)

# Unroll on sequences to get online and target Q-Values.
key1, key2 = jax.random.split(key_grad)
Expand Down
5 changes: 3 additions & 2 deletions acme/agents/jax/sac/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def update_step(
critic_grads, state.q_optimizer_state)
q_params = optax.apply_updates(state.q_params, critic_update)

new_target_q_params = jax.tree_map(lambda x, y: x * (1 - tau) + y * tau,
state.target_q_params, q_params)
new_target_q_params = jax.tree.map(
lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params
)

metrics = {
'critic_loss': critic_loss,
Expand Down
2 changes: 1 addition & 1 deletion acme/datasets/tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(self,
# we capture the whole dataset.
size = _dataset_size_upperbound(dataset)
data = next(dataset.batch(size).as_numpy_iterator())
self._dataset_size = jax.tree_flatten(
self._dataset_size = jax.tree.flatten(
jax.tree_util.tree_map(lambda x: x.shape[0], data)
)[0][0]
device = jax_utils._pmap_device_order()
Expand Down

0 comments on commit bea6d6b

Please sign in to comment.