Add @no_remat decorator #94

Merged · 2 commits · Oct 4, 2023
42 changes: 41 additions & 1 deletion axlearn/common/base_layer.py
@@ -19,6 +19,7 @@
PartitionSpec,
Tensor,
TensorSpec,
check_jax_type,
flatten_items,
get_or_none,
)
@@ -185,9 +186,24 @@ def _call_thunk(self, *args, method_fn, **kwargs) -> Callable[[], Any]:
)

nullary = super()._call_thunk(*args, method_fn=method_fn, **kwargs)
if current_context() is None or cfg.remat_spec is None or not self.is_training:
if (
current_context() is None
or cfg.remat_spec is None
or not self.is_training
or getattr(method_fn, "_no_remat", False)
):
return nullary

# Remat always uses abstract tracers even if concrete information is available.
# This means that all inputs and outputs to a remat function need to be JAX types.
# We raise a descriptive error if they are not.
check_jax_type(
args=args,
kwargs=kwargs,
msg=f"Attempt to use remat on {self}.{method_fn} "
"failed. Consider decorating with @no_remat.",
)

def nullary_with_remat():
def fn(*args, **kwargs):
"""Unlike self.method, fn returns (outputs, output_collection)."""
@@ -396,3 +412,27 @@ def _add_activation_summary(
self.add_summary(f"activations/{name}_mean", WeightedScalar(activations_mean, weights))
# Average of per hidden unit norm.
self.add_summary(f"activations/{name}_norm", WeightedScalar(activations_norm_mean, weights))


def no_remat(fn: Callable) -> Callable:
"""Annotates fn so that remat will not be applied to it.

This can be used to prevent tracers from leaking into helper methods that depend
only on data available at compile time when using `remat_spec`. For example, the following
method cannot be used in a layer that sets `remat_spec` unless it is decorated with @no_remat:

```
def fn(self, st: str):
if st == 'three':
return 3
```

Args:
fn: The method to annotate.

Returns:
The input `fn` after having been annotated.
"""
# pylint: disable=protected-access
fn._no_remat = True
return fn
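
For context, here is a minimal usage sketch of the decorator added above: a layer whose config sets `remat_spec`, with a helper method that takes a plain Python string and therefore opts out of remat. The `ScaledLayer` class is hypothetical, and the import paths are assumed to match those used by the tests in this PR.

```
from typing import Optional

import jax
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

from axlearn.common.base_layer import BaseLayer, RematSpec, no_remat
from axlearn.common.config import config_class
from axlearn.common.module import functional as F


class ScaledLayer(BaseLayer):
    """A hypothetical layer that scales its input by a factor chosen at trace time."""

    @config_class
    class Config(BaseLayer.Config):
        # Remat the forward pass, discarding all intermediate activations.
        remat_spec: Optional[RematSpec] = RematSpec(policy=jax_remat_policies.nothing_saveable)

    def forward(self, x):
        # `scale` takes a plain Python string; without @no_remat, the remat wrapper
        # would reject the non-JAX argument when this call is traced.
        return self.scale("three") * x

    @no_remat
    def scale(self, name: str) -> int:
        # Pure Python on compile-time data; no tracers are involved.
        return 3 if name == "three" else 1


layer = ScaledLayer.default_config().set(name="scaled").instantiate(parent=None)
outputs, _ = F(
    layer,
    prng_key=jax.random.PRNGKey(0),
    state={},  # The layer defines no parameters.
    inputs=[jax.numpy.ones(4)],
    is_training=True,
)
```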
56 changes: 55 additions & 1 deletion axlearn/common/base_layer_test.py
@@ -3,8 +3,9 @@
"""Tests BaseLayer."""
import math
from functools import partial
from typing import Dict, List
from typing import Dict, List, Optional

import jax.ad_checkpoint
import jax.core
import jax.interpreters.ad
import jax.random
@@ -20,6 +21,7 @@
ParameterNoise,
ParameterSpec,
RematSpec,
no_remat,
)
from axlearn.common.config import config_class
from axlearn.common.module import Module, OutputCollection
@@ -404,6 +406,58 @@ def test_activation_summary_toy_example(self, with_paddings):
rtol=1e-6,
)

def test_no_remat_inheritance(self):
# Check that @no_remat is preserved by inheritance unless the method
# is explicitly overridden by one without @no_remat.
class AnotherTestLayer(BaseLayer):
@no_remat
def fn(self, st: str):
pass

class Subclass1(AnotherTestLayer):
pass

class Subclass2(AnotherTestLayer):
def fn(self, st: str):
pass

self.assertTrue(hasattr(AnotherTestLayer.fn, "_no_remat"))
self.assertTrue(hasattr(Subclass1.fn, "_no_remat"))
self.assertFalse(hasattr(Subclass2.fn, "_no_remat"))

def test_no_remat(self):
# pylint: disable=missing-class-docstring
# Checks that using @no_remat allows calling a function with a non-JAX type.
class AnotherTestLayer(BaseLayer):
@config_class
class Config(BaseLayer.Config):
remat_spec: Optional[RematSpec] = RematSpec(
policy=jax_remat_policies.nothing_saveable
)

def forward(self, x):
b = self.fn("three")
x = b * x
return x.sum()

@no_remat
def fn(self, st: str):
if st == "three":
return 3

# Pytype doesn't like us directly accessing the _no_remat attribute, so we use getattr.
self.assertTrue(getattr(AnotherTestLayer.fn, "_no_remat", False))

layer = AnotherTestLayer.default_config().set(name="tmp").instantiate(parent=None)
params = {}
rng = jax.random.PRNGKey(0)
jit_value_and_grad = jax.jit(
lambda *args, inputs, **kwargs: jax.value_and_grad(
lambda inputs: F(layer, *args, inputs=inputs, is_training=True, **kwargs)[0]
)(inputs)
)
_ = jit_value_and_grad(prng_key=rng, state=params, inputs=[jax.numpy.ones(5)])


class ComputeFanAxesTest(TestCase):
"""Tests compute_fan_axes."""
37 changes: 37 additions & 0 deletions axlearn/common/utils.py
@@ -778,6 +778,43 @@ def check_param_shape_alignment(
return "\n".join(output_str)


def check_jax_type(
*,
args: Optional[Sequence] = None,
kwargs: Optional[dict] = None,
pretty_named_args: Optional[dict] = None,
msg: Optional[str] = None,
):
"""Checks that the supplied arguments are valid JAX types and raise ValueError if not.

Args:
args: Positional arguments of a function call to check.
kwargs: Keyword arguments of a function call to check.
pretty_named_args: Arguments that already have a human readable name to check.
msg: An optional prefix that is prepended, followed by a line break, to the error message produced by this function.

Raises:
ValueError: If the supplied arguments are not valid JAX types.
"""
if pretty_named_args is None:
pretty_named_args = {}
if args is not None:
pretty_named_args.update({f"args[{i}]": args[i] for i in range(len(args))})
if kwargs is not None:
pretty_named_args.update({f"kwargs[{key}]": kwargs[key] for key in kwargs})

for name, arg in pretty_named_args.items():
values, _ = jax.tree_util.tree_flatten(arg)
for value in values:
if not isinstance(value, (type(None), jax.Array, int, float)):
if msg is None:
msg = ""
else:
msg += "\n"
msg += f"Argument {name} has leaf with non-JAX type {type(value)}"
raise ValueError(msg)


def validate_float_dtype(dtype: jnp.dtype):
"""Validates if the provided dtype is both a float and amongst the set supported.

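For reference, a minimal sketch of how the new `check_jax_type` helper behaves from a hypothetical caller, including how the optional `msg` prefix is prepended to the raised error (the import path is assumed from this PR):

```
import jax

from axlearn.common.utils import check_jax_type

# Valid JAX leaves (arrays, ints, floats, None) pass silently, even when nested.
check_jax_type(args=(1, 2.0, jax.numpy.ones(3), None, [{"hidden": jax.numpy.zeros(2)}]))

# A string leaf fails; the caller-supplied `msg` comes first, then a line break,
# then a description of the offending argument.
try:
    check_jax_type(
        kwargs={"vocab": "not-a-tensor"},
        msg="Attempt to use remat on MyLayer.forward failed. Consider decorating with @no_remat.",
    )
except ValueError as e:
    print(e)
    # Attempt to use remat on MyLayer.forward failed. Consider decorating with @no_remat.
    # Argument kwargs[vocab] has leaf with non-JAX type <class 'str'>
```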
10 changes: 10 additions & 0 deletions axlearn/common/utils_test.py
@@ -44,6 +44,7 @@
as_numpy_array,
as_tensor,
cast_floats,
check_jax_type,
check_param_shape_alignment,
complete_partition_spec_tree,
copy_recursively,
@@ -531,6 +532,15 @@ def test_check_param_shape_alignment(self):
error_msg = "(linear1/weight/0) shape is different: source: (32), target: (15)."
self.assertEqual(error_msg, check_param_shape_alignment(target_tree, misalign_target_tree))

def test_check_jax_type(self):
check_jax_type(args=(1, 1.0, jax.numpy.ones(1), None, [{"key": 1}]))
with self.assertRaisesRegex(ValueError, "non-JAX type"):
check_jax_type(args=([{"key": "1"}],))
with self.assertRaisesRegex(ValueError, "non-JAX type"):
check_jax_type(kwargs={"key": "1"})
with self.assertRaisesRegex(ValueError, "^Argument key has leaf with non-JAX type"):
check_jax_type(pretty_named_args={"key": "1"})


class SimilarNamesTest(TestCase):
@parameterized.parameters(