Support customized mesh rules to support different HWs #696
axlearn/common/trainer_config_modifier.py
@@ -0,0 +1,150 @@
# Copyright © 2023 Apple Inc.

"""Defines trainer config modifiers, which will be used in model definitions."""

from typing import Dict, List, Optional, Union

from axlearn.common import config
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
    REQUIRED,
    ConfigModifier,
    ConfigOr,
    Required,
    config_class,
    maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.metrics import MetricAccumulator
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import HybridMeshShape, MeshShape


class GradientAccumulation(ConfigModifier):
    """Accumulate gradients for grad_acc_steps steps."""

    @config_class
    class Config(ConfigModifier.Config):
        grad_acc_steps: Required[int] = REQUIRED
        metric_accumulator: Required[MetricAccumulator.Config] = MetricAccumulator.default_config()

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._grad_acc_steps = cfg.grad_acc_steps
        self._metric_accumulator = cfg.metric_accumulator

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite forward_fn_transformation to accumulate gradients for grad_acc_steps steps.

        Note that this does not affect the global batch size or the number of logical training
        steps. The optimizer step is applied once after grad_acc_steps forward and backward
        passes on mini-batches:

            global_bs = mini_bs * grad_acc_steps
            train_steps = mini_steps / grad_acc_steps

        Args:
            cfg: The trainer config to be modified.

        Returns:
            The modified trainer config.
        """
        cfg.learner.forward_fn_transformation = config.config_for_function(
            with_minibatch_steps
        ).set(
            steps=self._grad_acc_steps,
            metric_accumulator=self._metric_accumulator,
        )
        return cfg
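
# Example usage (an illustrative sketch; `my_model_cfg` is a placeholder for any model config,
# e.g. the DummyModel config used in the unit tests below):
#
#     trainer_cfg = SpmdTrainer.default_config().set(model=my_model_cfg)
#     modifier = GradientAccumulation.default_config().set(grad_acc_steps=4).instantiate()
#     trainer_cfg = modifier(trainer_cfg)
#     # Four mini-batch forward/backward passes now run per optimizer step:
#     assert trainer_cfg.learner.forward_fn_transformation.steps == 4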


class RematSpecModifier(ConfigModifier):
    """Update the remat policies for specified modules."""

    @config_class
    class Config(ConfigModifier.Config):
        remat_policies: Optional[Dict[str, RematSpec]] = None

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._remat_policies = cfg.remat_policies

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Update the remat policy for the specified modules.

        Args:
            cfg: The trainer config to be modified.

        Raises:
            ValueError: If a target module is not found in the config, or if it does not have
                a remat_spec attribute.

        Returns:
            The modified trainer config.
        """
        if self._remat_policies is None:
            return cfg

        for module_name, remat_spec in self._remat_policies.items():
            # Module names are assumed to be in dotted x.y.z format,
            # e.g. model.decoder.transformer.layer.
            target_modules = module_name.split(".")
            curr_module = cfg
            for target_module in target_modules:
                if not hasattr(curr_module, target_module):
                    raise ValueError(f"{target_module} is not found in {curr_module}.")
                curr_module = getattr(curr_module, target_module)
            # The target module is assumed to have a remat_spec attribute.
            if not hasattr(curr_module, "remat_spec"):
                raise ValueError(f"{curr_module} does not have a remat_spec attribute.")
            curr_module.remat_spec = remat_spec
        return cfg
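
# Example usage (an illustrative sketch, mirroring the unit tests below; "model.linear" refers
# to the DummyModel used there and stands in for a real path such as
# model.decoder.transformer.layer; `trainer_cfg` is as in the sketch above):
#
#     modifier = (
#         RematSpecModifier.default_config()
#         .set(
#             remat_policies={
#                 "model.linear": RematSpec(
#                     prevent_cse=True,
#                     policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
#                 ),
#             }
#         )
#         .instantiate()
#     )
#     trainer_cfg = modifier(trainer_cfg)  # Sets trainer_cfg.model.linear.remat_spec.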


class MeshShapeModifier(ConfigModifier):
    """Update the mesh_shape for the trainer config."""

    @config_class
    class Config(ConfigModifier.Config):
        mesh_shape: Optional[Union[MeshShape, HybridMeshShape]] = None

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._mesh_shape = cfg.mesh_shape

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Overwrite the mesh shape.

        Args:
            cfg: The trainer config to be modified.

        Returns:
            The modified trainer config.
        """
        cfg.mesh_shape = self._mesh_shape
        return cfg
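
# Example usage (an illustrative sketch; the mesh shape is arbitrary and hardware-dependent):
#
#     modifier = MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)).instantiate()
#     trainer_cfg = modifier(trainer_cfg)  # trainer_cfg.mesh_shape is now (4, 1, 8, 1).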


class ChainConfigModifier(ConfigModifier):
    """Chain multiple config modifiers together."""

    @config_class
    class Config(ConfigModifier.Config):
        config_modifiers: Required[List[ConfigOr[ConfigModifier]]] = REQUIRED

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        cfg = self.config
        self._config_modifiers = [
            maybe_instantiate(cfg_modifier) for cfg_modifier in cfg.config_modifiers
        ]

    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
        """Chain multiple config modifiers together.

        The config modifiers are applied in the order they are provided.

        Args:
            cfg: The trainer config to be modified.

        Returns:
            The modified trainer config.
        """
        for config_modifier_fn in self._config_modifiers:
            cfg = config_modifier_fn(cfg)
        return cfg
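
Taken together, a ChainConfigModifier can bundle several of the modifiers above so that a single object retargets a trainer config for a particular hardware setup. The snippet below is an illustrative sketch that mirrors the unit tests; the step count and mesh shape are arbitrary, and DummyModel is the toy model from axlearn.common.trainer_test.

    from axlearn.common.trainer import SpmdTrainer
    from axlearn.common.trainer_config_modifier import (
        ChainConfigModifier,
        GradientAccumulation,
        MeshShapeModifier,
    )
    from axlearn.common.trainer_test import DummyModel

    trainer_cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
    modifier = (
        ChainConfigModifier.default_config()
        .set(
            config_modifiers=[
                GradientAccumulation.default_config().set(grad_acc_steps=4),
                MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)),
            ]
        )
        .instantiate()
    )
    # Modifiers are applied in order, each returning the updated trainer config.
    trainer_cfg = modifier(trainer_cfg)
    assert trainer_cfg.mesh_shape == (4, 1, 8, 1)
    assert trainer_cfg.learner.forward_fn_transformation.steps == 4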
axlearn/common/trainer_config_modifier_test.py
@@ -0,0 +1,93 @@
# Copyright © 2023 Apple Inc.

"""Test various ConfigModifier classes in trainer_config_modifier.py."""

import jax
from absl.testing import absltest

from axlearn.common import test_utils
from axlearn.common.base_layer import RematSpec
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.trainer_config_modifier import (
    ChainConfigModifier,
    GradientAccumulation,
    MeshShapeModifier,
    RematSpecModifier,
)
from axlearn.common.trainer_test import DummyModel


class GradientAccumulationTest(test_utils.TestCase):
    def test_gradient_accumulation_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = GradientAccumulation.default_config().set(grad_acc_steps=4).instantiate()
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4)


class RematSpecModifierTest(test_utils.TestCase):
    def test_remat_policy_override(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            RematSpecModifier.default_config()
            .set(
                remat_policies={
                    "model.linear": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                }
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertRegex(str(cfg.model.linear), "dots_saveable")
        cfg_modifier = (
            RematSpecModifier.default_config()
            .set(
                remat_policies={
                    "model.linear": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                    "model.unknown": RematSpec(
                        prevent_cse=True,
                        policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable,
                    ),
                }
            )
            .instantiate()
        )
        # An unknown module path should raise a ValueError.
        with self.assertRaisesRegex(ValueError, "unknown is not found in.*"):
            _ = cfg_modifier(cfg)


class MeshShapeModifierTest(test_utils.TestCase):
    def test_mesh_shape_update(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)).instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1))


class ChainConfigModifierTest(test_utils.TestCase):
    def test_chain_config_modifier(self):
        cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config())
        cfg_modifier = (
            ChainConfigModifier.default_config()
            .set(
                config_modifiers=[
                    GradientAccumulation.default_config().set(grad_acc_steps=4),
                    MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)),
                ]
            )
            .instantiate()
        )
        cfg = cfg_modifier(cfg)
        self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1))
        self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4)


if __name__ == "__main__":
    absltest.main()