forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctional_sgd.py
152 lines (136 loc) · 5.35 KB
/
functional_sgd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from typing import Dict, List, Optional
import torch
import torch.optim._functional as F
from torch import Tensor
# Deliberately empty: nothing in this module is exported by `from ... import *`;
# the optimizer below is an internal helper (see the NOTE on the class).
__all__: List[str] = []
# Define a TorchScript-compatible functional SGD optimizer,
# where we use this optimizer in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races on accumulating to the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalSGD:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-2,
momentum: float = 0.0,
dampening: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
maximize: bool = False,
foreach: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
"momentum": momentum,
"dampening": dampening,
"weight_decay": weight_decay,
}
self.nesterov = nesterov
self.maximize = maximize
self.foreach = foreach
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow user to add additional
# param group as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
"""Similar to self.step, but operates on a single parameter and
its gradient.
"""
# TODO: Once step_param interface is robust, refactor step to call
# step param on each param.
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]
dampening = self.defaults["dampening"]
lr = self.defaults["lr"]
params = [param]
momentum_buffer_list: List[Optional[Tensor]] = []
grads = []
has_sparse_grad = False
if grad is not None:
grads.append(grad)
if grad.is_sparse:
has_sparse_grad = True
if param not in self.state:
self.state[param] = {}
state = self.state[param]
if "momentum_buffer" not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state["momentum_buffer"])
with torch.no_grad():
F.sgd(
params,
grads,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=self.nesterov,
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
)
# update momentum_buffer in state
state = self.state[param]
momentum_buffer = momentum_buffer_list[0]
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
momentum_buffer_list: List[Optional[Tensor]] = []
lr = self.defaults["lr"]
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]
dampening = self.defaults["dampening"]
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_sparse_grad = False
for param, gradient in zip(params, gradients):
if gradient is not None:
params_with_grad.append(param)
grads.append(gradient)
if gradient.is_sparse:
has_sparse_grad = True
if param not in self.state:
self.state[param] = {}
state = self.state[param]
if "momentum_buffer" not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state["momentum_buffer"])
with torch.no_grad():
F.sgd(
params_with_grad,
grads,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=self.nesterov,
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
)
# update momentum_buffers in state
for i, p in enumerate(params_with_grad):
state = self.state[p]
momentum_buffer = momentum_buffer_list[i]
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer