diff --git a/axlearn/common/base_layer.py b/axlearn/common/base_layer.py
index 0e7f32ac6..b1faa20ee 100644
--- a/axlearn/common/base_layer.py
+++ b/axlearn/common/base_layer.py
@@ -19,6 +19,7 @@
     PartitionSpec,
     Tensor,
     TensorSpec,
+    check_jax_type,
     flatten_items,
     get_or_none,
 )
@@ -185,9 +186,24 @@ def _call_thunk(self, *args, method_fn, **kwargs) -> Callable[[], Any]:
             )
 
         nullary = super()._call_thunk(*args, method_fn=method_fn, **kwargs)
-        if current_context() is None or cfg.remat_spec is None or not self.is_training:
+        if (
+            current_context() is None
+            or cfg.remat_spec is None
+            or not self.is_training
+            or getattr(method_fn, "_no_remat", False)
+        ):
             return nullary
 
+        # Remat always uses abstract tracers even if concrete information is available.
+        # This means that all inputs and outputs to a remat function need to be JAX types.
+        # We raise a readable error if the inputs are not.
+        check_jax_type(
+            args=args,
+            kwargs=kwargs,
+            msg=f"Attempt to use remat on {self}.{method_fn} "
+            "failed. Consider decorating with @no_remat.",
+        )
+
         def nullary_with_remat():
             def fn(*args, **kwargs):
                 """Unlike self.method, fn returns (outputs, output_collection)."""
@@ -396,3 +412,27 @@ def _add_activation_summary(
         self.add_summary(f"activations/{name}_mean", WeightedScalar(activations_mean, weights))
         # Average of per hidden unit norm.
         self.add_summary(f"activations/{name}_norm", WeightedScalar(activations_norm_mean, weights))
+
+
+def no_remat(fn: Callable) -> Callable:
+    """Annotates fn so that remat will not be applied to it.
+
+    This can be used to prevent tracers from leaking into helper methods that depend
+    only on data available at compile time when using `remat_spec`. For example, the following
+    method cannot be used in a class that uses remat_spec unless it is decorated with @no_remat:
+
+    ```
+    def fn(self, st: str):
+        if st == 'three':
+            return 3
+    ```
+
+    Args:
+        fn: The method to annotate.
+
+    Returns:
+        The input `fn` after having been annotated.
+    """
+    # pylint: disable=protected-access
+    fn._no_remat = True
+    return fn
diff --git a/axlearn/common/base_layer_test.py b/axlearn/common/base_layer_test.py
index 214793de0..ca2eefa71 100644
--- a/axlearn/common/base_layer_test.py
+++ b/axlearn/common/base_layer_test.py
@@ -3,8 +3,9 @@
 """Tests BaseLayer."""
 import math
 from functools import partial
-from typing import Dict, List
+from typing import Dict, List, Optional
 
+import jax.ad_checkpoint
 import jax.core
 import jax.interpreters.ad
 import jax.random
@@ -20,6 +21,7 @@
     ParameterNoise,
     ParameterSpec,
     RematSpec,
+    no_remat,
 )
 from axlearn.common.config import config_class
 from axlearn.common.module import Module, OutputCollection
@@ -404,6 +406,58 @@ def test_activation_summary_toy_example(self, with_paddings):
             rtol=1e-6,
         )
 
+    def test_no_remat_inheritance(self):
+        # Check that @no_remat is preserved by inheritance unless the method
+        # is explicitly overridden by one without @no_remat.
+        class AnotherTestLayer(BaseLayer):
+            @no_remat
+            def fn(self, st: str):
+                pass
+
+        class Subclass1(AnotherTestLayer):
+            pass
+
+        class Subclass2(AnotherTestLayer):
+            def fn(self, st: str):
+                pass
+
+        self.assertTrue(hasattr(AnotherTestLayer.fn, "_no_remat"))
+        self.assertTrue(hasattr(Subclass1.fn, "_no_remat"))
+        self.assertFalse(hasattr(Subclass2.fn, "_no_remat"))
+
+    def test_no_remat(self):
+        # pylint: disable=missing-class-docstring
+        # Checks that using @no_remat allows calling a function with a non-JAX type.
+        class AnotherTestLayer(BaseLayer):
+            @config_class
+            class Config(BaseLayer.Config):
+                remat_spec: Optional[RematSpec] = RematSpec(
+                    policy=jax_remat_policies.nothing_saveable
+                )
+
+            def forward(self, x):
+                b = self.fn("three")
+                x = b * x
+                return x.sum()
+
+            @no_remat
+            def fn(self, st: str):
+                if st == "three":
+                    return 3
+
+        # Pytype doesn't like us directly accessing the _no_remat attribute, so we use getattr.
+        self.assertTrue(getattr(AnotherTestLayer.fn, "_no_remat", False))
+
+        layer = AnotherTestLayer.default_config().set(name="tmp").instantiate(parent=None)
+        params = {}
+        rng = jax.random.PRNGKey(0)
+        jit_value_and_grad = jax.jit(
+            lambda *args, inputs, **kwargs: jax.value_and_grad(
+                lambda inputs: F(layer, *args, inputs=inputs, is_training=True, **kwargs)[0]
+            )(inputs)
+        )
+        _ = jit_value_and_grad(prng_key=rng, state=params, inputs=[jax.numpy.ones(5)])
+
 
 class ComputeFanAxesTest(TestCase):
     """Tests compute_fan_axes."""
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 577af565d..ed52c7767 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -778,6 +778,43 @@ def check_param_shape_alignment(
     return "\n".join(output_str)
 
 
+def check_jax_type(
+    *,
+    args: Optional[Sequence] = None,
+    kwargs: Optional[dict] = None,
+    pretty_named_args: Optional[dict] = None,
+    msg: Optional[str] = None,
+):
+    """Checks that the supplied arguments are valid JAX types and raises ValueError if not.
+
+    Args:
+        args: Positional arguments of a function call to check.
+        kwargs: Keyword arguments of a function call to check.
+        pretty_named_args: Arguments that already have a human-readable name to check.
+        msg: A prefix to prepend, followed by a line break, to the error message produced here.
+
+    Raises:
+        ValueError: If the supplied arguments are not valid JAX types.
+    """
+    if pretty_named_args is None:
+        pretty_named_args = {}
+    if args is not None:
+        pretty_named_args.update({f"args[{i}]": args[i] for i in range(len(args))})
+    if kwargs is not None:
+        pretty_named_args.update({f"kwargs[{key}]": kwargs[key] for key in kwargs})
+
+    for name, arg in pretty_named_args.items():
+        values, _ = jax.tree_util.tree_flatten(arg)
+        for value in values:
+            if not isinstance(value, (type(None), jax.Array, int, float)):
+                if msg is None:
+                    msg = ""
+                else:
+                    msg += "\n"
+                msg += f"Argument {name} has leaf with non-JAX type {type(value)}"
+                raise ValueError(msg)
+
+
 def validate_float_dtype(dtype: jnp.dtype):
     """Validates if the provided dtype is both a float and amongst the set supported.
 
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 8beae6e8d..c0a75d18c 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -44,6 +44,7 @@
     as_numpy_array,
     as_tensor,
     cast_floats,
+    check_jax_type,
     check_param_shape_alignment,
     complete_partition_spec_tree,
     copy_recursively,
@@ -531,6 +532,15 @@ def test_check_param_shape_alignment(self):
         error_msg = "(linear1/weight/0) shape is different: source: (32), target: (15)."
         self.assertEqual(error_msg, check_param_shape_alignment(target_tree, misalign_target_tree))
 
+    def test_check_jax_type(self):
+        check_jax_type(args=(1, 1.0, jax.numpy.ones(1), None, [{"key": 1}]))
+        with self.assertRaisesRegex(ValueError, "non-JAX type"):
+            check_jax_type(args=([{"key": "1"}],))
+        with self.assertRaisesRegex(ValueError, "non-JAX type"):
+            check_jax_type(kwargs={"key": "1"})
+        with self.assertRaisesRegex(ValueError, "^Argument key has leaf with non-JAX type"):
+            check_jax_type(pretty_named_args={"key": "1"})
+
 
 class SimilarNamesTest(TestCase):
     @parameterized.parameters(
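
As a quick illustration of the `check_jax_type` semantics added above (a standalone sketch, not part of the patch; the `msg` text is arbitrary): only `None`, `jax.Array`, `int`, and `float` leaves are accepted, and the raised error names the offending argument.

```python
import jax.numpy as jnp

from axlearn.common.utils import check_jax_type

# Valid: every leaf of every pytree argument is a JAX-compatible type.
check_jax_type(args=(1, 2.0, jnp.ones(3), None, [{"k": 1}]))

# Invalid: a string leaf. The optional `msg` is prepended, followed by a
# line break, to the generated error message.
try:
    check_jax_type(kwargs={"mode": "three"}, msg="Demo call failed.")
except ValueError as e:
    print(e)
    # Demo call failed.
    # Argument kwargs[mode] has leaf with non-JAX type <class 'str'>
```

This is the error path that `_call_thunk` now exercises before applying remat; decorating a compile-time helper with `@no_remat` (as in `test_no_remat` above) keeps it outside the remat boundary so it may take non-JAX arguments such as Python strings.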