[LSC] Ignore incorrect type annotations related to jax.numpy APIs
PiperOrigin-RevId: 568520617
Change-Id: I372051fa57315aa4710e842a0fc4582c685a78c6
Jake VanderPlas authored and copybara-github committed Sep 26, 2023
1 parent 46c57d9 commit ac668d5
Showing 6 changed files with 14 additions and 12 deletions.
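
Background on the suppressed errors: several functions below are annotated to return a Python float, but jax.numpy reductions such as jnp.mean return a jax.Array, which pytype reports as bad-return-type. A minimal, hypothetical sketch of the pattern being suppressed (the function and arguments are illustrative, not taken from this commit):

import jax.numpy as jnp

def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> float:
  # jnp.mean yields a zero-dimensional jax.Array rather than a Python float,
  # so pytype flags the return; the annotation is kept and the error is
  # suppressed inline, matching the changes in this commit.
  return jnp.mean(jnp.square(x - y))  # pytype: disable=bad-return-type  # jnp-type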
acme/agents/jax/ail/losses.py: 4 additions & 4 deletions
@@ -114,7 +114,7 @@ def loss_fn(
'entropy_loss': entropy_loss,
'classification_loss': classification_loss
}
- return total_loss, (metrics, discriminator_state)
+ return total_loss, (metrics, discriminator_state) # pytype: disable=bad-return-type # jnp-type

return loss_fn

@@ -166,7 +166,7 @@ def loss_fn(
'entropy_loss': entropy_loss,
'classification_loss': classification_loss
}
- return total_loss, (metrics, discriminator_state)
+ return total_loss, (metrics, discriminator_state) # pytype: disable=bad-return-type # jnp-type

return loss_fn

@@ -194,7 +194,7 @@ def _compute_gradient_penalty(gradient_penalty_data: types.Transition,
gradients.next_observation])
gradient_norms = jnp.linalg.norm(gradients + 1e-8)
k = gradient_penalty_target * jnp.ones_like(gradient_norms)
- return jnp.mean(jnp.square(gradient_norms - k))
+ return jnp.mean(jnp.square(gradient_norms - k)) # pytype: disable=bad-return-type # jnp-type


def add_gradient_penalty(base_loss: Loss,
@@ -231,6 +231,6 @@ def apply_discriminator_fn(transitions: types.Transition) -> float:
total_loss = partial_loss + gradient_penalty
losses['total_loss'] = total_loss

- return total_loss, (losses, discriminator_state)
+ return total_loss, (losses, discriminator_state) # pytype: disable=bad-return-type # jnp-type

return loss_fn
acme/agents/jax/ail/rewards.py: 1 addition & 1 deletion
@@ -71,6 +71,6 @@ def imitation_reward(logits: networks_lib.Logits) -> float:
# pylint: disable=invalid-unary-operand-type
rewards = jnp.clip(
rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude)
- return rewards
+ return rewards # pytype: disable=bad-return-type # jnp-type

return imitation_reward # pytype: disable=bad-return-type # jax-ndarray
acme/agents/jax/mpo/categorical_mpo.py: 2 additions & 2 deletions
@@ -165,7 +165,7 @@ def __call__(
loss = loss_policy + loss_kl + loss_dual

# Create statistics.
- stats = CategoricalMPOStats(
+ stats = CategoricalMPOStats( # pytype: disable=wrong-arg-types # jnp-type
# Dual Variables.
dual_alpha=jnp.mean(alpha),
dual_temperature=jnp.mean(temperature),
@@ -183,7 +183,7 @@ def __call__(
q_min=jnp.mean(jnp.min(q_values, axis=0)),
q_max=jnp.mean(jnp.max(q_values, axis=0)),
entropy_online=jnp.mean(online_action_distribution.entropy()),
- entropy_target=jnp.mean(target_action_distribution.entropy())
+ entropy_target=jnp.mean(target_action_distribution.entropy()),
)

return loss, stats
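
The wrong-arg-types suppressions above and below cover the analogous constructor case: fields annotated with Python scalar types receive jax.numpy values. A hedged sketch under that assumption (the Stats tuple here is hypothetical, not the real CategoricalMPOStats):

from typing import NamedTuple

import jax.numpy as jnp

class Stats(NamedTuple):
  dual_alpha: float
  entropy: float

values = jnp.array([0.1, 0.2, 0.3])
# jnp.mean returns a jax.Array, so passing it to fields annotated as float
# triggers wrong-arg-types; the call site is suppressed rather than
# widening the field annotations.
stats = Stats(  # pytype: disable=wrong-arg-types  # jnp-type
    dual_alpha=jnp.mean(values),
    entropy=jnp.mean(values),
)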
acme/agents/jax/mpo/utils.py: 3 additions & 2 deletions
@@ -99,10 +99,11 @@ def make_sequences_from_transitions(
transitions.next_observation)
reward = duplicate(transitions.reward)

- return adders.Step(
+ return adders.Step( # pytype: disable=wrong-arg-types # jnp-type
observation=observation,
action=duplicate(transitions.action),
reward=reward,
discount=duplicate(transitions.discount),
start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_),
- extras=jax.tree_map(duplicate, transitions.extras))
+ extras=jax.tree_map(duplicate, transitions.extras),
+ )
acme/agents/jax/r2d2/learning.py: 3 additions & 2 deletions
@@ -228,12 +228,13 @@ def update_priorities(
logging.info('Total number of params: %d',
sum(tree.flatten(sizes.values())))

- state = TrainingState(
+ state = TrainingState( # pytype: disable=wrong-arg-types # jnp-type
params=initial_params,
target_params=initial_params,
opt_state=opt_state,
steps=jnp.array(0),
- random_key=random_key)
+ random_key=random_key,
+ )
# Replicate parameters.
self._state = utils.replicate_in_all_devices(state)

Expand Down
acme/agents/jax/rnd/learning.py: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def rnd_loss(
predictor_output = networks.predictor.apply(predictor_params,
transitions.observation,
transitions.action)
- return jnp.mean(jnp.square(target_output - predictor_output))
+ return jnp.mean(jnp.square(target_output - predictor_output)) # pytype: disable=bad-return-type # jnp-type


class RNDLearner(acme.Learner):
