Skip to content

Commit

Permalink
JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 425896973
Jake VanderPlas authored and diegolascasas committed Feb 16, 2022
1 parent bc869b2 commit 35be4a0
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions kfac_ferminet_alpha/tests/graph_matcher_test.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ def tagged_autoencoder(all_params, x_in):
return [[h1, t2], [h2, t2]]


@jtu.with_config(jax_numpy_rank_promotion="allow")
class TestGraphMatcher(jtu.JaxTestCase):
"""Class for running all of the tests for integrating the systems."""

1 change: 1 addition & 0 deletions kfac_ferminet_alpha/tests/tracer_test.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@ def autoencoder_aux(all_aux, all_params, x_in):
return [l1, l2 * 0.1], layers_values


@jtu.with_config(jax_numpy_rank_promotion="allow")
class TestTracer(jtu.JaxTestCase):
"""Class for running all of the tests for integrating the systems."""

0 comments on commit 35be4a0

Please sign in to comment.