Skip to content

Commit

Permalink
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 512349622
  • Loading branch information
hawkinsp authored and saran-t committed Jun 2, 2023
1 parent c051e6a commit 6f0ddef
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions physics_inspired_models/models/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _models_core(
is_training=is_training)
return p_x, z0, decoder_z

def training_objectives(
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
self,
params: hk.Params,
state: hk.State,
Expand Down Expand Up @@ -300,7 +300,7 @@ def reconstruct(
include_z0=False,
)[0]

def gt_state_and_latents(
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
self,
params: hk.Params,
rng: jnp.ndarray,
Expand Down Expand Up @@ -336,7 +336,7 @@ def _init_non_model_params_and_state(
) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
return dict(), dict()

def _init_latent_system(
def _init_latent_system( # pytype: disable=signature-mismatch # jax-ndarray
self,
rng: jnp.ndarray,
z: jnp.ndarray,
Expand Down
2 changes: 1 addition & 1 deletion physics_inspired_models/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _init(
params = hk.data_structures.to_immutable_dict(params)
state = hk.data_structures.to_immutable_dict(state)

return params, state
return params, state # pytype: disable=bad-return-type # jax-ndarray

def init(
self,
Expand Down
10 changes: 5 additions & 5 deletions physics_inspired_models/models/deterministic_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def process_latents_for_dynamics(self, z: jnp.ndarray) -> _ArrayOrPhase:

def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray:
if self.latent_dynamics_type == "Physics":
return z.q if self.render_from_q_only else z.single_state
return z
return z.q if self.render_from_q_only else z.single_state # pytype: disable=attribute-error # jax-ndarray
return z # pytype: disable=bad-return-type # jax-ndarray

@property
def inferred_index(self) -> int:
Expand Down Expand Up @@ -327,7 +327,7 @@ def verify_unroll_args(
if num_steps_backward > 0 and not self.can_run_backwards:
raise ValueError("This model can not be unrolled backward in time.")

def unroll_latent_dynamics(
def unroll_latent_dynamics( # pytype: disable=signature-mismatch # jax-ndarray
self,
z: phase_space.PhaseSpace,
params: hk.Params,
Expand Down Expand Up @@ -393,7 +393,7 @@ def _models_core(
z = z.single_state if isinstance(z, phase_space.PhaseSpace) else z
return p_x, q_z, self.prior(), z0, z, dyn_stats

def training_objectives(
def training_objectives( # pytype: disable=signature-mismatch # jax-ndarray
self,
params: utils.Params,
state: hk.State,
Expand Down Expand Up @@ -532,7 +532,7 @@ def reconstruct(
include_z0=True,
)[0]

def gt_state_and_latents(
def gt_state_and_latents( # pytype: disable=signature-mismatch # jax-ndarray
self,
params: hk.Params,
rng: jnp.ndarray,
Expand Down
4 changes: 2 additions & 2 deletions physics_inspired_models/models/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,12 +610,12 @@ def simulate(
y.q, y.p, **nets_kwargs)
# Special Haiku magic to avoid tracer issues
if hk.running_init():
return self.lagrangian(y0, **nets_kwargs)
return self.lagrangian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray
else:
hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
if hk.running_init():
return self.hamiltonian(y0, **nets_kwargs)
return self.hamiltonian(y0, **nets_kwargs) # pytype: disable=bad-return-type # jax-ndarray

# Optionally switch coordinate frame
if self.input_space == "velocity" and self.simulation_space == "momentum":
Expand Down

0 comments on commit 6f0ddef

Please sign in to comment.