Gradient Accumulation in Axlearn #465
Conversation
Force-pushed from ade2229 to 3394936.
Could you explain why this is needed? We usually find it more efficient to use either a larger mesh or a smaller batch size.
@@ -444,6 +444,153 @@ def _mask_tree(tree: dict, *, keep: dict) -> dict:
    )


class MetricsAccumulationOp(NamedTuple):
Axlearn already has metric accumulation classes that are used by evalers. Could those be reused here instead of defining new classes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These metric accumulation classes are stateless so that they can be used as the carry of jax.lax.scan, unlike the ones in the evaler. I can make the class structure similar, though.
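A minimal sketch of that point, assuming a stateless NamedTuple whose values are threaded through jax.lax.scan as the carry; the names below (RunningSum, accumulate, scan_step) are illustrative and are not the classes added in this PR:

import jax
import jax.numpy as jnp
from typing import NamedTuple

class RunningSum(NamedTuple):
    # Stateless container: all state lives in the array values, so an
    # instance is a pytree that can be carried through jax.lax.scan.
    total: jnp.ndarray
    count: jnp.ndarray

def accumulate(acc: RunningSum, value: jnp.ndarray) -> RunningSum:
    return RunningSum(total=acc.total + value, count=acc.count + 1)

def scan_step(carry, microbatch_metric):
    # No per-step outputs are needed; only the carry is updated.
    return accumulate(carry, microbatch_metric), ()

init = RunningSum(total=jnp.zeros(()), count=jnp.zeros((), dtype=jnp.int32))
per_microbatch_loss = jnp.array([0.9, 0.8, 0.7, 0.6])  # e.g. one loss per microbatch
final, _ = jax.lax.scan(scan_step, init, per_microbatch_loss)
mean_loss = final.total / final.count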
# tuple of key-value pairs specifying custom aggregation and normalization
# for a specific metric
metrics_accumulation_key_ops: Sequence[Dict[str, Optional[MetricsAccumulationOp]]] = []
gradient_dtype: Optional[jnp.dtype] = jnp.bfloat16
Does the existing learner class use this? If not, we should try to be consistent with its API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please be more specific? Is the concern about the naming of the members?
Returns:
    ForwardBackwardOutputs: pytree containing gradients and metrics.
"""
I wonder if, instead of having a separate learner for microbatching, it would be more flexible to have a generic way of wrapping a ForwardFn so that it uses jax.lax.map to run the microbatches. Beyond avoiding the need to add a new learner, it would also allow for other microbatching uses outside of the learner, e.g. inference or second-order optimizers.
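A rough sketch of this suggestion, assuming a ForwardFn that maps (params, batch) to (loss, aux); the wrapper name and the batch-splitting convention are illustrative, not axlearn's actual API:

import jax
import jax.numpy as jnp

def microbatched(forward_fn, num_microbatches: int):
    # Wraps a forward function so it runs once per microbatch via jax.lax.map.
    def wrapped(params, batch):
        # Reshape each leaf from [batch, ...] to [num_microbatches, microbatch_size, ...].
        split = lambda x: x.reshape((num_microbatches, -1) + x.shape[1:])
        microbatches = jax.tree_util.tree_map(split, batch)
        # lax.map applies forward_fn to each slice along the leading axis.
        losses, aux = jax.lax.map(lambda mb: forward_fn(params, mb), microbatches)
        return losses.mean(), aux
    return wrapped

The same wrapper could then be composed with jax.grad for training or called directly for inference, which is the flexibility argued for here.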
jax.lax.map gives no guarantee of sequential execution of the microbatches, which is the key property of gradient accumulation.
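By contrast, a generic gradient-accumulation loop with jax.lax.scan threads the accumulated gradients through the carry, which forces the microbatches to run one after another. This is a minimal sketch under the same assumed batch-splitting convention as above, not the PR's implementation:

import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, batch, num_microbatches: int):
    # Split each leaf of the batch into [num_microbatches, microbatch_size, ...].
    microbatches = jax.tree_util.tree_map(
        lambda x: x.reshape((num_microbatches, -1) + x.shape[1:]), batch
    )
    zero = jax.tree_util.tree_map(jnp.zeros_like, params)

    def step(grad_acc, microbatch):
        # The carry (grad_acc) creates a data dependency between steps,
        # so scan evaluates the microbatches sequentially.
        grads = jax.grad(loss_fn)(params, microbatch)
        return jax.tree_util.tree_map(jnp.add, grad_acc, grads), ()

    grad_sum, _ = jax.lax.scan(step, zero, microbatches)
    return jax.tree_util.tree_map(lambda g: g / num_microbatches, grad_sum)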
I’m out of office this week. I have left some preliminary comments for now.
Force-pushed from 3394936 to 2430830.
Closing since gradient accumulation functionality has been implemented via #614.
Gradient accumulation allows training with higher batch sizes without scaling out.
Added a new learner type:
learner.klass: 'axlearn.common.learner.AccumulatedLearner'
At a high level, the optimization splits each batch into microbatches, accumulates the per-microbatch gradients and metrics, and then applies a single optimizer update.
Configuration changes: set microbatches in the learner (see the sketch below).
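An illustrative flat-config override based on the two settings named above; the learner class path is quoted from the description, while the exact key name for the microbatch setting and its value are assumptions:

learner_overrides = {
    # New learner type, as quoted in the PR description above.
    "learner.klass": "axlearn.common.learner.AccumulatedLearner",
    # Number of microbatches to accumulate over (hypothetical key name and value).
    "learner.microbatches": 4,
}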