checkpoint_activation.py (forked from pytorch/pytorch)
94 lines (79 loc) · 3.26 KB
from contextlib import contextmanager, nullcontext
from typing import Any, Tuple

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    _checkpoint_without_reentrant_generator,
    _DEFAULT_DETERMINISM_MODE,
)

from .contract import contract


@contextmanager
def _no_hook(module: nn.Module):
    r"""
    Disable hooks installed by checkpoint to avoid unintentional recursion
    during backward recomputation.
    """
    orig_enable_hook = checkpoint.state(module).enable_hook
    checkpoint.state(module).enable_hook = False
    try:
        yield
    finally:
        checkpoint.state(module).enable_hook = orig_enable_hook
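
# Illustrative sketch only (not part of the original file): during backward, the
# non-reentrant checkpoint machinery re-runs the module's forward to recompute
# activations. That recomputation happens under the recompute context returned by
# context_fns below, i.e. inside _no_hook(module), so the forward pre-/post-hooks
# registered by checkpoint() do not fire a second time and recurse. Conceptually:
#
#     with _no_hook(module):
#         recomputed_outputs = module(*saved_inputs)  # hypothetical recompute call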


@contract()
def checkpoint(module: nn.Module) -> nn.Module:
    r"""
    This is a composable activation checkpointing API. Unlike functional
    activation checkpointing APIs, this one does not require changing model
    source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs,
    this one does not modify model structure or fully-qualified names either.
    Under the hood, it registers activation checkpointing logic as pre- and
    post-forward hooks. Hence, this API can be easily applied to any model or
    sub-modules in the model.

    Args:
        module (nn.Module): the target model or sub-module to apply activation
            checkpointing.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> model = MyModel()
        >>> checkpoint(model.l1)  # apply activation checkpointing only to l1
        >>> model(torch.zeros(2, 10)).sum().backward()
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint")

    def forward_pre_hook(module: nn.Module, inputs: Tuple[Any, ...]) -> None:
        if checkpoint.state(module).enable_hook:

            def context_fns():
                # (forward context, recompute context): recomputation during
                # backward runs under _no_hook so these hooks do not re-enter.
                return nullcontext(), _no_hook(module)

            checkpoint.state(
                module
            )._ac_generator = _checkpoint_without_reentrant_generator(
                module, True, context_fns, _DEFAULT_DETERMINISM_MODE, False, *inputs
            )
            # Advance the generator to its yield point to set up checkpointing
            # before the module's forward runs.
            next(checkpoint.state(module)._ac_generator)

    def forward_hook(module: nn.Module, inputs: Tuple[Any, ...], output: Any) -> Any:
        if checkpoint.state(module).enable_hook:
            try:
                # Drive the generator to completion now that forward is done;
                # it is expected to be exhausted after one more step.
                next(checkpoint.state(module)._ac_generator)
            except StopIteration:
                pass
            else:
                raise RuntimeError(
                    "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!"
                )

        # Ensure that we no longer hold on to the generator. always_call=True helps ensure we
        # clear this even in the case of exception in fwd pass.
        checkpoint.state(module)._ac_generator = None

    checkpoint.state(module).enable_hook = True
    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook, prepend=True, always_call=True)
    return module
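

# Usage sketch (not part of the original file; the toy model and shapes are made
# up for illustration): apply checkpoint() to individual sub-modules and check the
# docstring's claims -- no wrapper module is inserted, parameter fully-qualified
# names stay the same, and backward runs with recomputed activations.
if __name__ == "__main__":

    class ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.l1 = nn.Linear(10, 10)
            self.l2 = nn.Linear(10, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.l2(self.l1(x))

    model = ToyModel()
    before = list(model.state_dict().keys())

    # Checkpoint each layer independently; the modules themselves are returned
    # unchanged, with hooks registered on them.
    checkpoint(model.l1)
    checkpoint(model.l2)

    after = list(model.state_dict().keys())
    assert before == after  # fully-qualified names are unchanged

    # Forward/backward run as usual; activations of l1 and l2 are recomputed
    # during backward instead of being kept alive.
    model(torch.zeros(2, 10)).sum().backward()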