forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathautograd.py
274 lines (236 loc) · 11.5 KB
/
autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason this indirection exists is so that we can swap out the
# autograd kernel (the PyTorch dispatcher doesn't actually allow us to
# do this). By default, we want the `autograd_not_implemented` behavior,
# but the user may later come and register an actual backward formula.
def autograd_kernel_indirection(custom_op):
    """Return the autograd kernel wrapper (`inner`) for `custom_op`.

    On every call, `inner` looks up which impls are currently registered
    on `custom_op` and then does one of three things:

    - an 'autograd' impl exists: forward the call to its `.func`;
    - only one of 'backward' / 'save_for_backward' exists: raise a
      RuntimeError naming the registration that is missing;
    - neither exists: forward the call to the kernel built by
      `autograd_not_implemented(custom_op)`.
    """
    # Built once, up front; only used when nothing autograd-related is registered.
    fallback_kernel = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            return custom_op._get_impl('autograd').func(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated once the user has supplied both
        # "backward" and "save_for_backward". Reaching this point with exactly
        # one of them registered means the user did something wrong.
        if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
            if custom_op._has_impl('backward'):
                found, missing = 'backward', 'save_for_backward'
            else:
                found, missing = 'save_for_backward', 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")

        return fallback_kernel(*args, **kwargs)

    return inner
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
def kernel(*args, **kwargs):
if torch.is_grad_enabled() and pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
with torch._C._AutoDispatchBelowAutograd():
return custom_op(*args, **kwargs)
return kernel
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark outputs flagged as non-differentiable on `ctx`.

    Args:
        ctx: object exposing `mark_non_differentiable(*tensors)`.
        output: a single output or a tuple of outputs.
        output_differentiability: None (do nothing), or one truthy/falsy
            flag per output.

    Raises:
        RuntimeError: if an output that is neither a Tensor nor a list is
            flagged as differentiable.
    """
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is None:
        return

    outputs = output if isinstance(output, tuple) else (output,)
    assert len(output_differentiability) == len(outputs)

    to_mark = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, outputs)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                to_mark.append(out)
        elif isinstance(out, list):
            # Every element of a non-differentiable list output is marked.
            if not differentiable:
                to_mark.extend(out)
        elif differentiable:
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot have be marked as differentiable in "
                f"output_differentiability.")

    if to_mark:
        ctx.mark_non_differentiable(*to_mark)
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
    """Build the autograd kernel (`apply`) from user-supplied pieces.

    The returned `apply(*args)` flattens `args`, generates a fresh
    `torch.autograd.Function` subclass for this call (named
    `custom_op._opname + '_customop'`), runs it on the flat inputs, and
    unflattens the result back into the structure `op_overload` returned.

    Args:
        schema: passed to `namedtuple_args` to name the operator's arguments.
        output_differentiability: forwarded to `mark_non_differentiable`.
        custom_op: supplies `_opname`; also passed to
            `validate_grad_inputs_dict` for validation and error reporting.
        op_overload: callable run in forward, below autograd.
        save_for_backward_fn: called as `fn(named_inputs, output)`; its
            result is stashed on the ctx for backward.
        backward_fn: called as `fn(inner_ctx, saved, *grads)`; its result
            is validated as a dict of gradients keyed by argument name.
    """
    def apply(*args):
        flat_args, in_spec = pytree.tree_flatten(args)
        # Filled in by `forward`; read by `backward` and by the tail of `apply`.
        out_spec = None

        def forward(ctx, *flat_inputs):
            nonlocal out_spec
            ctx.set_materialize_grads(True)
            op_args = pytree.tree_unflatten(list(flat_inputs), in_spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*op_args)

            # We use the info about args to give better error messages in backward
            arg_types = pytree.tree_map(lambda arg: type(arg), op_args)
            args_info = namedtuple_args(schema, arg_types)

            named_inputs = namedtuple_args(schema, op_args)
            to_save = save_for_backward_fn(named_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            grads = grads if isinstance(grads, tuple) else (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        cls_name = custom_op._opname + '_customop'
        generated_cls = gen_autograd_function(cls_name, forward, backward)

        flat_result = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_result), out_spec)

    return apply
def gen_autograd_function(name, forward, backward):
    """Create a `torch.autograd.Function` subclass called `name`.

    `forward` and `backward` are installed as static methods on the new
    class, so each receives `ctx` as its first positional argument.
    """
    namespace = dict(
        forward=staticmethod(forward),
        backward=staticmethod(backward),
    )
    return type(name, (torch.autograd.Function,), namespace)
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return a namedtuple class with one field per argument of `schema`.

    Field names are taken from `schema.arguments.flat_all`; the class is
    named `"<schema.name>_args"`. Results are memoized per `schema` by
    `functools.lru_cache`.
    """
    field_names = []
    for arg in schema.arguments.flat_all:
        field_names.append(arg.name)
    cls_name = f"{schema.name!s}_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(cls_name, field_names)  # type: ignore[misc]
def namedtuple_args(schema, args):
    """Pack the tuple `args` into the namedtuple class built for `schema`.

    `args` must be a tuple (checked with an assert); its elements are
    passed positionally to the class from `namedtuple_args_cls(schema)`.
    """
    assert isinstance(args, tuple)
    return namedtuple_args_cls(schema)(*args)
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check the dict returned by a user-defined backward function.

    Args:
        grad_inputs_dict: value returned by the backward function; must be
            a dict mapping argument names to gradients.
        forward_op: the custom op; provides `_schema` (to compute the
            expected keys) and `_get_impl('backward').location` (for error
            messages).
        args_info: namedtuple-like object holding, for each argument name,
            either the type of the forward argument or a list of types
            when the argument was a list.

    Raises:
        RuntimeError: if the value is not a dict, if its keys differ from
            the names of the tensor-like schema arguments, or if any
            gradient has the wrong type or length for its argument.
    """
    def error(what):
        # Always raises; every call site below relies on this.
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # The forward argument was a list: expect one gradient per element.
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if len(grad) != len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    # Report the type of the offending element (`info`),
                    # not the whole list of element types.
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    """Flatten a dict of gradients into a tuple ordered by `args_info`.

    Fields of `args_info` are visited in order. A field whose name is in
    `grad_inputs_dict` contributes that gradient; any other field
    contributes its info structure with every leaf replaced by None. The
    resulting list is flattened with pytree and returned as a tuple.
    """
    per_arg = [
        grad_inputs_dict[name] if name in grad_inputs_dict
        else pytree.tree_map(lambda _: None, info)
        for name, info in args_info._asdict().items()
    ]
    leaves, _ = pytree.tree_flatten(per_arg)
    return tuple(leaves)
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Stash the pytree `stuff` on `ctx` so `unpack_saved` can rebuild it.

    `stuff` is flattened. Tensor leaves go through `ctx.save_for_backward`;
    all other leaves are stored in `ctx.saved_non_tensors`. The flat
    position of every leaf is recorded in `ctx.tensor_idxs` /
    `ctx.non_tensor_idxs`, alongside `ctx.spec` and `ctx.num_elts`.
    """
    leaves, tree_spec = pytree.tree_flatten(stuff)

    tensors, tensor_idxs = [], []
    non_tensors, non_tensor_idxs = [], []
    # Single pass: split leaves by kind, remembering each flat position.
    for position, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensors.append(leaf)
            tensor_idxs.append(position)
        else:
            non_tensors.append(leaf)
            non_tensor_idxs.append(position)

    ctx.spec = tree_spec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree that `save_pytree_for_backward` stored on `ctx`.

    Tensor leaves are read from `ctx.saved_tensors`, the rest from
    `ctx.saved_non_tensors`. Each leaf is put back at its recorded flat
    position and the list is unflattened with `ctx.spec`.
    """
    leaves = [None] * ctx.num_elts
    for saved_tensor, position in zip(ctx.saved_tensors, ctx.tensor_idxs):
        leaves[position] = saved_tensor
    for saved_value, position in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
        leaves[position] = saved_value
    return pytree.tree_unflatten(leaves, ctx.spec)