Support customized mesh rules to support different HWs #696

Merged (13 commits) on Sep 13, 2024
8 changes: 8 additions & 0 deletions axlearn/common/config.py
@@ -926,3 +926,11 @@ def maybe_set_config(cfg: _ConfigBase, **kwargs) -> _ConfigBase:
         if hasattr(cfg, key):
             setattr(cfg, key, value)
     return cfg
+
+
+class ConfigModifier(Configurable):
+    """A class that takes a config and returns a modified config."""
+
+    def __call__(self, cfg: InstantiableConfig[T]) -> InstantiableConfig[T]:
+        """A function that modifies the input config; should be defined by subclasses."""
+        return cfg
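
The `ConfigModifier` base class is intentionally minimal: subclasses override `__call__` to return an edited config. As a quick illustration (not part of this diff), the sketch below shows the pattern the concrete modifiers in this PR follow; the class name `SetMaxStep` is hypothetical, while `max_step` is an existing `SpmdTrainer.Config` field.

# Illustrative sketch only, not part of this diff.
from axlearn.common.config import REQUIRED, ConfigModifier, Required, config_class
from axlearn.common.trainer import SpmdTrainer


class SetMaxStep(ConfigModifier):
    """Sets the trainer's max_step to a fixed value (hypothetical example)."""

    @config_class
    class Config(ConfigModifier.Config):
        max_step: Required[int] = REQUIRED

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        # Read the modifier's own config and apply it to the trainer config.
        cfg.max_step = self.config.max_step
        return cfg
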
14 changes: 10 additions & 4 deletions axlearn/common/trainer.py
@@ -1069,9 +1069,15 @@ def select_mesh_config(trainer_config: SpmdTrainer.Config, *, mesh_selector: str
         mesh_selector: A string used to select the mesh rule to apply.
     """
     if trainer_config.mesh_rules:
-        mesh = match_regex_rules(
+        mesh_rule = match_regex_rules(
             mesh_selector, rules=trainer_config.mesh_rules, default_value=REQUIRED
         )
-        logging.info("Mesh selector %s matches mesh rule %s", mesh_selector, mesh)
-        if mesh is not REQUIRED:
-            trainer_config.mesh_shape = mesh
+        logging.info("Mesh selector %s matches mesh rule %s", mesh_selector, mesh_rule)
+        if mesh_rule is not REQUIRED:
+            # The matched rule may be a mesh shape (tuple or HybridMeshShape) or None.
+            if isinstance(mesh_rule, (tuple, HybridMeshShape)) or mesh_rule is None:
+                trainer_config.mesh_shape = mesh_rule
+            else:
+                # Otherwise it is a ConfigModifier config: instantiate it and apply it
+                # to override the trainer config.
+                mesh_rule_fn = maybe_instantiate(mesh_rule)
+                trainer_config = mesh_rule_fn(trainer_config)
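
With this change a mesh rule value may be either a plain mesh shape or a ConfigModifier config. The sketch below shows both forms; the selector regexes and the (1, 1, 64, 1) shape are illustrative values (not from this diff), while the ChainConfigModifier entry mirrors the unit test added in trainer_test.py.

# Sketch of the two rule forms select_mesh_config() accepts after this change.
from axlearn.common.trainer import SpmdTrainer, select_mesh_config
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulation,
    MeshShapeModifier,
)

mesh_rules = (
    # A plain mesh shape (or HybridMeshShape): assigned directly to mesh_shape.
    ("tpu-v4-128", (1, 1, 64, 1)),
    # A ConfigModifier config: instantiated and applied to the whole trainer config.
    (
        "tpu-v4-64",
        ChainConfigModifier.default_config().set(
            config_modifiers=[
                MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)),
                GradientAccumulation.default_config().set(grad_acc_steps=4),
            ]
        ),
    ),
)
trainer_config = SpmdTrainer.default_config().set(mesh_rules=mesh_rules)
select_mesh_config(trainer_config, mesh_selector="tpu-v4-64")
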
150 changes: 150 additions & 0 deletions axlearn/common/trainer_config_modifier.py
@@ -0,0 +1,150 @@
# Copyright © 2023 Apple Inc.

Contributor: Suggested change (here and elsewhere): "# Copyright © 2023 Apple Inc." -> "# Copyright © 2024 Apple Inc."
Contributor Author: Done. Thanks for the review, I need one more approval after the nit change.

"""Defines trainer config modifiers, which will be used in model definitions."""
markblee marked this conversation as resolved.
Show resolved Hide resolved

from typing import Dict, List, Optional, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import REQUIRED, ConfigModifier, ConfigOr, Required, config_class
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape


class GradientAccumulation(ConfigModifier):
markblee marked this conversation as resolved.
Show resolved Hide resolved
"""Accumulate gradients for grad_acc_steps steps."""

@config_class
class Config(ConfigModifier.Config):
grad_acc_steps: Required[int] = REQUIRED
metric_accumulator: Required[MetricAccumulator.Config] = MetricAccumulator.default_config()
markblee marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._grad_acc_steps = cfg.grad_acc_steps
self._metric_accumulator = cfg.metric_accumulator

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Overwrite the forward_fn_transformation to accumulate gradients for grad_acc_steps steps.

Note this would not affect the global batch size or the logical training steps.
The optimization step is applied each time after grad_acc_steps steps of
forward and backward passes on mini-batches.

global_bs=mini_bs*grad_acc_steps
train_steps=mini_steps/grad_acc_steps

Args:
cfg: the trainer config to be modified.
markblee marked this conversation as resolved.
Show resolved Hide resolved

Returns:
The modified trainer config.
"""
cfg.learner.forward_fn_transformation = config.config_for_function(
with_minibatch_steps
).set(
steps=self._grad_acc_steps,
metric_accumulator=self._metric_accumulator,
)
return cfg
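
To make the relationship above concrete: with a configured global batch size of 1024 and grad_acc_steps=4, each forward/backward pass runs on a mini-batch of 256 and the optimizer update fires once per 4 mini-batches, so the global batch size and the logical step count are unchanged. A minimal usage sketch, mirroring the unit test added in this PR (the trainer's model/input fields are omitted for brevity):

# Usage sketch: accumulate gradients over 4 mini-batches per train step.
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import GradientAccumulation

trainer_cfg = SpmdTrainer.default_config()  # model/input fields omitted for brevity
cfg_modifier = GradientAccumulation.default_config().set(grad_acc_steps=4).instantiate()
trainer_cfg = cfg_modifier(trainer_cfg)
assert trainer_cfg.learner.forward_fn_transformation.steps == 4
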


class RematSpecModifier(ConfigModifier):
    """Update the remat policies for specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        remat_policies: Optional[Dict[str, RematSpec]] = None

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._remat_policies = cfg.remat_policies

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the remat policy for the specified modules.

        Args:
            cfg (SpmdTrainer.Config): the trainer config to be modified.

        Raises:
            ValueError: the target module is not found.
            ValueError: the remat_spec attribute is not found.

        Returns:
            The modified trainer config.
        """
        if self._remat_policies is None:
            return cfg

        for module_name, remat_spec in self._remat_policies.items():
            # Here we assume x.y.z format.
            # One example would be model.decoder.transformer.layer.
            target_modules = module_name.split(".")
            curr_module = cfg
            for target_module in target_modules:
                if not hasattr(curr_module, target_module):
                    raise ValueError(f"{target_module} is not found in {curr_module}.")
                curr_module = getattr(curr_module, target_module)
            # Here we assume all modules have remat_spec attribute.
            if not hasattr(curr_module, "remat_spec"):
                raise ValueError(f"{curr_module} does not have remat_spec attribute")
            curr_module.remat_spec = remat_spec
        return cfg
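
A usage sketch for the dotted-path convention described in the comments above. The path "model.decoder.transformer.layer" is the example given in the code and assumes the trainer config actually exposes that sub-config with a remat_spec attribute; trainer_cfg is assumed to be an existing SpmdTrainer.Config.

# Usage sketch; remat_policies keys are attribute paths into the trainer config,
# and each target must expose a remat_spec attribute or __call__ raises ValueError.
import jax

from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer_config_modifier import RematSpecModifier

cfg_modifier = (
    RematSpecModifier.default_config()
    .set(
        remat_policies={
            "model.decoder.transformer.layer": RematSpec(
                prevent_cse=True,
                policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
            ),
        }
    )
    .instantiate()
)
trainer_cfg = cfg_modifier(trainer_cfg)  # trainer_cfg: an assumed pre-existing SpmdTrainer.Config
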


class MeshShapeModifier(ConfigModifier):
    """Update the mesh_shape for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        mesh_shape: Optional[Union[MeshShape, HybridMeshShape]] = None

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._mesh_shape = cfg.mesh_shape

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the mesh shape.

        Args:
            cfg (SpmdTrainer.Config): the trainer config to be modified.

        Returns:
            The modified trainer config.
        """
        cfg.mesh_shape = self._mesh_shape
        return cfg


class ChainConfigModifier(ConfigModifier):
    """Chain multiple config modifiers together."""

    @config_class
    class Config(ConfigModifier.Config):
        config_modifiers: Required[List[ConfigOr[ConfigModifier]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._config_modifiers = [
            cfg_modifier.instantiate() for cfg_modifier in cfg.config_modifiers
        ]

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Chain multiple config modifiers together.

        The config modifiers will be applied in the order they are provided.

        Args:
            cfg (SpmdTrainer.Config): the trainer config to be modified.

        Returns:
            The modified trainer config.
        """
        for config_modifier_fn in self._config_modifiers:
            cfg = config_modifier_fn(cfg)
        return cfg
93 changes: 93 additions & 0 deletions axlearn/common/trainer_config_modifier_test.py
@@ -0,0 +1,93 @@
# Copyright © 2023 Apple Inc.

"""Test various ConfigModifier classes in trainer_config_modifier.py."""

import jax
from absl.testing import absltest

from axlearn.common import test_utils
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulation,
    MeshShapeModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel


class GradientAccumulationTest(test_utils.TestCase):
    def test_gradient_accumulation_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = GradientAccumulation.default_config().set(grad_acc_steps=4).instantiate()
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4)


class RematSpecModifierTest(test_utils.TestCase):
    def test_remat_policy_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            RematSpecModifier.default_config()
            .set(
                remat_policies={
                    "model.linear": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                }
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertRegex(str(cfg.model.linear), "dots_saveable")
        cfg_modifier = (
            RematSpecModifier.default_config()
            .set(
                remat_policies={
                    "model.linear": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                    "model.unknown": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                }
            )
            .instantiate()
        )
        # Ensure that the exception is working.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)).instantiate()
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1))


class ChainConfigModifierTest(test_utils.TestCase):
    def test_chain_config_modifier(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            ChainConfigModifier.default_config()
            .set(
                config_modifiers=[
                    GradientAccumulation.default_config().set(grad_acc_steps=4),
                    MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)),
                ]
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1))
        self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4)


if __name__ == "__main__":
    absltest.main()
55 changes: 53 additions & 2 deletions axlearn/common/trainer_test.py
@@ -32,7 +32,7 @@
     struct_test,
     test_utils,
 )
-from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec
+from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, RematSpec
 from axlearn.common.base_model import BaseModel
 from axlearn.common.checkpointer import (
     Checkpointer,
@@ -47,6 +47,12 @@
 from axlearn.common.module import Module
 from axlearn.common.state_builder import Builder as TrainerStateBuilder
 from axlearn.common.trainer import SpmdTrainer, TrainerState, select_mesh_config
+from axlearn.common.trainer_config_modifier import (
+    ChainConfigModifier,
+    GradientAccumulation,
+    MeshShapeModifier,
+    RematSpecModifier,
+)
 from axlearn.common.utils import (
     Nested,
     NestedTensor,
@@ -172,13 +178,14 @@ class Config(BaseModel.Config):

         # Whether to explicitly init dummy state to test pruning backwards compat.
         init_dummy_state: bool = False
+        linear: layers.Linear.Config = layers.Linear.default_config()

     def __init__(self, cfg: Config, *, parent: Optional[Module]):
         super().__init__(cfg, parent=parent)
         cfg = self.config
         self._add_child(
             "fc",
-            layers.Linear.default_config().set(
+            cfg.linear.set(
                 input_dim=3,
                 output_dim=NUM_CLASSES,
                 bias=True,
@@ -921,6 +928,50 @@ def test_select_mesh_config(self):
         self.assertIsNone(cfg.mesh_shape)


+class SelectExtendedMeshConfigTest(test_utils.TestCase):
+    def test_select_mesh_config(self):
+        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
+        self.assertIs(cfg.mesh_shape, REQUIRED)
+
+        # When mesh_rules=None.
+        self.assertIsNone(cfg.mesh_rules)
+        select_mesh_config(cfg, mesh_selector="tpu-v4-128")
+        # cfg.mesh_shape remains unchanged.
+        self.assertIs(cfg.mesh_shape, REQUIRED)
+
+        # When no mesh rule matches the selector.
+        cfg.mesh_rules = (
+            (
+                "tpu-v4-64",
+                ChainConfigModifier.default_config().set(
+                    config_modifiers=[
+                        MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)),
+                        RematSpecModifier.default_config().set(
+                            remat_policies={
+                                "model.linear": RematSpec(
+                                    prevent_cse=True,
+                                    policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
+                                ),
+                            }
+                        ),
+                        GradientAccumulation.default_config().set(grad_acc_steps=4),
+                    ],
+                ),
+            ),
+        )
+        select_mesh_config(cfg, mesh_selector="tpu-v4-128")
+        # cfg.mesh_shape still remains unchanged.
+        self.assertIs(cfg.mesh_shape, REQUIRED)
+        # When there is a match.
+        select_mesh_config(cfg, mesh_selector="tpu-v4-64")
+        # cfg.mesh_shape is overridden.
+        self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1))
+        # Check if gradient accumulation is set up.
+        self.assertRegex(str(cfg.learner.forward_fn_transformation), "steps: 4")
+        # Check if remat policy is set up.
+        self.assertRegex(str(cfg.model.linear), "dots_saveable")
+
+
 class CompatibilityTest(test_utils.TestCase):
     def test_chex_serialization_compatibility(self):
         """Tests that a chex.dataclass that has been serialized as part of an AXLearn checkpoint