forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctional_adam.py
191 lines (172 loc) · 6.85 KB
/
functional_adam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from typing import Dict, List, Optional, Tuple
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Adam Optimizer
# where we use the optimizer in a functional way.
# Instead of using the `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races on accumulating to the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdam:
    """TorchScript-compatible Adam optimizer that receives gradients explicitly.

    Gradients are handed to ``step`` / ``step_param`` as arguments instead of
    being read from ``param.grad``. Per-parameter state (``step``, ``exp_avg``,
    ``exp_avg_sq`` and, when ``amsgrad`` is set, ``max_exp_avg_sq``) is created
    lazily and stored in ``self.state`` keyed by the parameter tensor. The
    update itself is delegated to ``torch.optim._functional.adam``.
    """

    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        amsgrad: bool = False,
        maximize: bool = False,
        foreach: bool = False,
        fused: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        """Validate the hyper-parameters and store them on the instance.

        Args:
            params: Parameters updated by ``step``; stored as the single
                param group.
            lr: Passed to ``F.adam`` as ``lr``. Must be >= 0.
            betas: Pair passed to ``F.adam`` as ``beta1`` / ``beta2``. Each
                value must be in ``[0, 1)``.
            eps: Passed to ``F.adam`` as ``eps``. Must be >= 0.
            weight_decay: Passed to ``F.adam`` as ``weight_decay``. Must be >= 0.
            amsgrad: Passed to ``F.adam``; when true a ``max_exp_avg_sq`` state
                tensor is also kept for every parameter.
            maximize: Passed to ``F.adam`` as ``maximize``.
            foreach: Passed to ``F.adam`` as ``foreach``.
            fused: Passed to ``F.adam`` as ``fused``.
            _allow_empty_param_list: When true, an empty ``params`` list is
                accepted instead of raising.

        Raises:
            ValueError: If a hyper-parameter is out of range, or ``params`` is
                empty and ``_allow_empty_param_list`` is false.
        """
        # The ``not a <= b`` form is kept on purpose: it also rejects NaN.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }
        self.amsgrad = amsgrad
        self.maximize = maximize
        self.foreach = foreach
        self.fused = fused
        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow user to add additional
        # param group as it's not a common use case.
        self.param_group = {"params": params}

    def _get_or_init_state(self, param: Tensor) -> Dict[str, Tensor]:
        """Return the state dict of ``param``, creating it on first use."""
        if param not in self.state:
            self.state[param] = {}
            state = self.state[param]
            state["step"] = torch.tensor(0.0)
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(
                param, memory_format=torch.preserve_format
            )
            if self.amsgrad:
                # Maintains max of all exp. moving avg. of sq. grad. values
                state["max_exp_avg_sq"] = torch.zeros_like(
                    param, memory_format=torch.preserve_format
                )
        return self.state[param]

    def _apply_adam(
        self,
        params_with_grad: List[Tensor],
        grads: List[Tensor],
        exp_avgs: List[Tensor],
        exp_avg_sqs: List[Tensor],
        max_exp_avg_sqs: List[Tensor],
        state_steps: List[Tensor],
    ):
        """Call ``F.adam`` on the collected tensors with the stored settings."""
        with torch.no_grad():
            F.adam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=self.amsgrad,
                maximize=self.maximize,
                beta1=self.defaults["beta1"],
                beta2=self.defaults["beta2"],
                lr=self.defaults["lr"],
                weight_decay=self.defaults["weight_decay"],
                eps=self.defaults["eps"],
                foreach=self.foreach,
                fused=self.fused,
                grad_scale=None,
                found_inf=None,
            )

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """
        Similar to step, but operates on a single parameter and optionally a
        gradient tensor.

        The state of ``param`` is created and collected even when ``grad`` is
        None; in that case ``F.adam`` is called with an empty parameter list.
        """
        params_with_grad: List[Tensor] = []
        grads: List[Tensor] = []
        exp_avgs: List[Tensor] = []
        exp_avg_sqs: List[Tensor] = []
        max_exp_avg_sqs: List[Tensor] = []
        state_steps: List[Tensor] = []

        if grad is not None:
            params_with_grad.append(param)
            grads.append(grad)

        state = self._get_or_init_state(param)
        exp_avgs.append(state["exp_avg"])
        exp_avg_sqs.append(state["exp_avg_sq"])
        if self.amsgrad:
            max_exp_avg_sqs.append(state["max_exp_avg_sq"])
        state_steps.append(state["step"])

        self._apply_adam(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )

    def step(self, gradients: List[Optional[Tensor]]):
        """Update every parameter of the param group that has a gradient.

        Args:
            gradients: One entry per parameter, in the same order as the
                ``params`` list given to the constructor. Parameters whose
                entry is None are skipped and get no state created.

        Raises:
            ValueError: If ``gradients`` and the parameter list differ in length.
        """
        params = self.param_group["params"]
        params_with_grad: List[Tensor] = []
        grads: List[Tensor] = []
        exp_avgs: List[Tensor] = []
        exp_avg_sqs: List[Tensor] = []
        max_exp_avg_sqs: List[Tensor] = []
        state_steps: List[Tensor] = []

        if len(params) != len(gradients):
            # Trailing space after "!" keeps the two sentences from running together.
            raise ValueError(
                "the gradients passed in does not equal to the size of the parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )

        for param, gradient in zip(params, gradients):
            if gradient is not None:
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                state = self._get_or_init_state(param)
                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                if self.amsgrad:
                    max_exp_avg_sqs.append(state["max_exp_avg_sq"])
                state_steps.append(state["step"])

        self._apply_adam(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )