Skip to content

Commit bea6d6b

Browse files
Jake VanderPlascopybara-github
Jake VanderPlas
authored andcommitted
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 635420186 Change-Id: Ie71a2deb905622b947a9b075ce55bcb1bff46462
1 parent 8d4cc04 commit bea6d6b

14 files changed

+60
-46
lines changed

acme/agents/jax/ars/builder.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def apply(
5050
normalized_obs = normalization_apply_fn(obs, normalization_params)
5151
action = policy_network.apply(policy_params, normalized_obs)
5252
return action, {
53-
'params_key':
54-
jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), params_key)
53+
'params_key': jax.tree.map(
54+
lambda x: jnp.expand_dims(x, axis=0), params_key
55+
)
5556
}
5657

5758
return apply

acme/agents/jax/bc/learning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def get_variables(self, names: List[str]) -> List[networks_lib.Params]:
194194

195195
def save(self) -> TrainingState:
196196
# Serialize only the first replica of parameters and optimizer state.
197-
return jax.tree_map(utils.get_from_first_device, self._state)
197+
return jax.tree.map(utils.get_from_first_device, self._state)
198198

199199
def restore(self, state: TrainingState):
200200
self._state = utils.replicate_in_all_devices(state)

acme/agents/jax/cql/learning.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,11 @@ def update_step(
337337
critic_grads, state.critic_optimizer_state)
338338
critic_params = optax.apply_updates(state.critic_params, critic_update)
339339

340-
new_target_critic_params = jax.tree_map(
341-
lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params,
342-
critic_params)
340+
new_target_critic_params = jax.tree.map(
341+
lambda x, y: x * (1 - tau) + y * tau,
342+
state.target_critic_params,
343+
critic_params,
344+
)
343345

344346
metrics = {
345347
'critic_loss': critic_loss,

acme/agents/jax/mbop/dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,11 @@ def get_normalization_stats(
210210
"""
211211
# Set up normalization:
212212
example = next(iterator)
213-
unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example)
213+
unbatched_single_example = jax.tree.map(lambda x: x[0, PREVIOUS, :], example)
214214
mean_std = running_statistics.init_state(unbatched_single_example)
215215

216216
for batch in itertools.islice(iterator, num_normalization_batches - 1):
217-
example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch)
217+
example = jax.tree.map(lambda x: x[:, PREVIOUS, :], batch)
218218
mean_std = running_statistics.update(mean_std, example)
219219

220220
return mean_std

acme/agents/jax/mbop/ensemble.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,17 @@ def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any],
100100
num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]
101101

102102
# Reshape args and kwargs for the round-robin:
103-
args = jax.tree_map(
104-
functools.partial(_split_batch_dimension, num_networks), args)
105-
kwargs = jax.tree_map(
106-
functools.partial(_split_batch_dimension, num_networks), kwargs)
103+
args = jax.tree.map(
104+
functools.partial(_split_batch_dimension, num_networks), args
105+
)
106+
kwargs = jax.tree.map(
107+
functools.partial(_split_batch_dimension, num_networks), kwargs
108+
)
107109
# `out.shape` is `(num_networks, initial_batch_size/num_networks, ...)
108110
out = jax.vmap(base_apply)(params, *args, **kwargs)
109111
# Reshape to [initial_batch_size, <remaining dimensions>]. Using the 'F' order
110112
# forces the original values to the last dimension.
111-
return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out)
113+
return jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out)
112114

113115

114116
def apply_all(base_apply: Callable[[networks.Params, Any], Any],
@@ -133,8 +135,8 @@ def apply_all(base_apply: Callable[[networks.Params, Any], Any],
133135
# `num_networks` is the size of the batch dimension in `params`.
134136
num_networks = jax.tree_util.tree_leaves(params)[0].shape[0]
135137

136-
args = jax.tree_map(functools.partial(_repeat_n, num_networks), args)
137-
kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs)
138+
args = jax.tree.map(functools.partial(_repeat_n, num_networks), args)
139+
kwargs = jax.tree.map(functools.partial(_repeat_n, num_networks), kwargs)
138140
# `out` is of shape `(num_networks, batch_size, <remaining dimensions>)`.
139141
return jax.vmap(base_apply)(params, *args, **kwargs)
140142

@@ -155,7 +157,7 @@ def apply_mean(base_apply: Callable[[networks.Params, Any], Any],
155157
Output shape will be [batch_size, <network output_dims>]
156158
"""
157159
out = apply_all(base_apply, params, *args, **kwargs)
158-
return jax.tree_map(functools.partial(jnp.mean, axis=0), out)
160+
return jax.tree.map(functools.partial(jnp.mean, axis=0), out)
159161

160162

161163
def make_ensemble(base_network: networks.FeedForwardNetwork,

acme/agents/jax/mbop/ensemble_test.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork:
5252
"""Like params_adding_ffn, but with pytree inputs, preserves structure."""
5353

5454
def init_fn(key, sx=sx):
55-
return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx)
55+
return jax.tree.map(lambda x: jax.random.uniform(key, x.shape), sx)
5656

5757
def apply_fn(params, x):
58-
return jax.tree_map(lambda p, v: p + v, params, x)
58+
return jax.tree.map(lambda p, v: p + v, params, x)
5959

6060
return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn)
6161

@@ -291,9 +291,10 @@ def test_round_robin_random(self):
291291
for i in range(9):
292292
np.testing.assert_allclose(
293293
out[i],
294-
ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]),
295-
atol=1E-5,
296-
rtol=1E-5)
294+
ffn.apply(jax.tree.map(lambda p, i=i: p[i % 3], params), bx[i]),
295+
atol=1e-5,
296+
rtol=1e-5,
297+
)
297298

298299
def test_mean_random(self):
299300
x = jnp.ones(10)
@@ -318,7 +319,8 @@ def test_mean_random(self):
318319
# Check results explicitly:
319320
all_members = jnp.concatenate([
320321
jnp.expand_dims(
321-
ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0)
322+
ffn.apply(jax.tree.map(lambda p, i=i: p[i], params), bx), axis=0
323+
)
322324
for i in range(3)
323325
])
324326
batch_means = jnp.mean(all_members, axis=0)

acme/agents/jax/mbop/losses.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,13 @@ def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action],
5353
Returns:
5454
A scalar loss value as jnp.ndarray.
5555
"""
56-
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
57-
steps.observation)
56+
observation_t = jax.tree.map(
57+
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
58+
)
5859
action_t = steps.action[:, dataset.CURRENT, ...]
59-
observation_tp1 = jax.tree_map(lambda obs: obs[:, dataset.NEXT, ...],
60-
steps.observation)
60+
observation_tp1 = jax.tree.map(
61+
lambda obs: obs[:, dataset.NEXT, ...], steps.observation
62+
)
6163
reward_t = steps.reward[:, dataset.CURRENT, ...]
6264
(predicted_observation_tp1,
6365
predicted_reward_t) = apply_fn(observation_t, action_t)
@@ -86,8 +88,9 @@ def policy_prior_loss(
8688
Returns:
8789
A scalar loss value as jnp.ndarray.
8890
"""
89-
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
90-
steps.observation)
91+
observation_t = jax.tree.map(
92+
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
93+
)
9194
action_tm1 = steps.action[:, dataset.PREVIOUS, ...]
9295
action_t = steps.action[:, dataset.CURRENT, ...]
9396

@@ -109,8 +112,9 @@ def return_loss(apply_fn: Callable[[networks.Observation, networks.Action],
109112
Returns:
110113
A scalar loss value as jnp.ndarray.
111114
"""
112-
observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...],
113-
steps.observation)
115+
observation_t = jax.tree.map(
116+
lambda obs: obs[:, dataset.CURRENT, ...], steps.observation
117+
)
114118
action_t = steps.action[:, dataset.CURRENT, ...]
115119
n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...]
116120

acme/agents/jax/mbop/mppi.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,9 @@ def mppi_planner(
183183
policy_prior_state = policy_prior.init(random_key)
184184

185185
# Broadcast so that we have n_trajectories copies of each:
186-
observation_t = jax.tree_map(
187-
functools.partial(_repeat_n, config.n_trajectories), observation)
186+
observation_t = jax.tree.map(
187+
functools.partial(_repeat_n, config.n_trajectories), observation
188+
)
188189
action_tm1 = jnp.broadcast_to(action_trajectory_tm1[0],
189190
(config.n_trajectories,) +
190191
action_trajectory_tm1[0].shape)

acme/agents/jax/mpo/learning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def _sgd_step(
681681
dual_params.log_penalty_temperature)
682682
elif isinstance(dual_params, discrete_losses.CategoricalMPOParams):
683683
dual_metrics['params/dual/log_alpha_avg'] = dual_params.log_alpha
684-
metrics.update(jax.tree_map(jnp.mean, dual_metrics))
684+
metrics.update(jax.tree.map(jnp.mean, dual_metrics))
685685

686686
return new_state, metrics
687687

@@ -733,7 +733,7 @@ def get_variables(self, names: List[str]) -> network_lib.Params:
733733
return [variables[name] for name in names]
734734

735735
def save(self) -> TrainingState:
736-
return jax.tree_map(mpo_utils.get_from_first_device, self._state)
736+
return jax.tree.map(mpo_utils.get_from_first_device, self._state)
737737

738738
def restore(self, state: TrainingState):
739739
self._state = utils.replicate_in_all_devices(state, self._local_devices)

acme/agents/jax/mpo/networks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def critic_fn(observation: types.NestedArray,
259259
def add_batch(nest, batch_size: Optional[int]):
260260
"""Adds a batch dimension at axis 0 to the leaves of a nested structure."""
261261
broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
262-
return jax.tree_map(broadcast, nest)
262+
return jax.tree.map(broadcast, nest)
263263

264264

265265
def w_init_identity(shape: Sequence[int], dtype) -> jnp.ndarray:

acme/agents/jax/mpo/utils.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _slice_and_maybe_to_numpy(x):
4242
x = x[0]
4343
return _fetch_devicearray(x) if as_numpy else x
4444

45-
return jax.tree_map(_slice_and_maybe_to_numpy, nest)
45+
return jax.tree.map(_slice_and_maybe_to_numpy, nest)
4646

4747

4848
def rolling_window(x: jnp.ndarray,
@@ -80,11 +80,11 @@ def tree_map_distribution(
8080
if isinstance(x, distrax.Distribution):
8181
safe_f = lambda y: f(y) if isinstance(y, jnp.ndarray) else y
8282
nil, tree_data = x.tree_flatten()
83-
new_tree_data = jax.tree_map(safe_f, tree_data)
83+
new_tree_data = jax.tree.map(safe_f, tree_data)
8484
new_x = x.tree_unflatten(new_tree_data, nil)
8585
return new_x
8686
elif isinstance(x, tfd.Distribution):
87-
return jax.tree_map(f, x)
87+
return jax.tree.map(f, x)
8888
else:
8989
return f(x)
9090

@@ -95,8 +95,9 @@ def make_sequences_from_transitions(
9595
"""Convert a batch of transitions into a batch of 1-step sequences."""
9696
stack = lambda x, y: jnp.stack((x, y), axis=num_batch_dims)
9797
duplicate = lambda x: stack(x, x)
98-
observation = jax.tree_map(stack, transitions.observation,
99-
transitions.next_observation)
98+
observation = jax.tree.map(
99+
stack, transitions.observation, transitions.next_observation
100+
)
100101
reward = duplicate(transitions.reward)
101102

102103
return adders.Step( # pytype: disable=wrong-arg-types # jnp-type
@@ -105,5 +106,5 @@ def make_sequences_from_transitions(
105106
reward=reward,
106107
discount=duplicate(transitions.discount),
107108
start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_),
108-
extras=jax.tree_map(duplicate, transitions.extras),
109+
extras=jax.tree.map(duplicate, transitions.extras),
109110
)

acme/agents/jax/r2d2/learning.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ def loss(
101101

102102
# Maybe burn the core state in.
103103
if burn_in_length:
104-
burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation)
104+
burn_obs = jax.tree.map(lambda x: x[:burn_in_length], data.observation)
105105
key_grad, key1, key2 = jax.random.split(key_grad, 3)
106106
_, online_state = networks.unroll(params, key1, burn_obs, online_state)
107107
_, target_state = networks.unroll(target_params, key2, burn_obs,
108108
target_state)
109109

110110
# Only get data to learn on from after the end of the burn in period.
111-
data = jax.tree_map(lambda seq: seq[burn_in_length:], data)
111+
data = jax.tree.map(lambda seq: seq[burn_in_length:], data)
112112

113113
# Unroll on sequences to get online and target Q-Values.
114114
key1, key2 = jax.random.split(key_grad)

acme/agents/jax/sac/learning.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,9 @@ def update_step(
176176
critic_grads, state.q_optimizer_state)
177177
q_params = optax.apply_updates(state.q_params, critic_update)
178178

179-
new_target_q_params = jax.tree_map(lambda x, y: x * (1 - tau) + y * tau,
180-
state.target_q_params, q_params)
179+
new_target_q_params = jax.tree.map(
180+
lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params
181+
)
181182

182183
metrics = {
183184
'critic_loss': critic_loss,

acme/datasets/tfds.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(self,
137137
# we capture the whole dataset.
138138
size = _dataset_size_upperbound(dataset)
139139
data = next(dataset.batch(size).as_numpy_iterator())
140-
self._dataset_size = jax.tree_flatten(
140+
self._dataset_size = jax.tree.flatten(
141141
jax.tree_util.tree_map(lambda x: x.shape[0], data)
142142
)[0][0]
143143
device = jax_utils._pmap_device_order()

0 commit comments

Comments
 (0)