forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfunctional.py
173 lines (144 loc) · 7.47 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import torch
from torch.library import Library
from torch._ops import OpOverload
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseTy, BaseType
from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
from .autograd import autograd_not_implemented
import torch.utils._pytree as pytree
import weakref
def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Register a functional variant of ``mutable_op`` on ``lib``.

    Besides defining the new functional operator, this links it to
    ``mutable_op`` for functionalization. A ``Functionalize`` kernel is
    installed on ``mutable_op``, and that kernel calls the functional variant.
    Every registration is performed on the ``lib`` passed in.

    Arguments:
        lib (Library): A torch.library.Library object in the same namespace
            as ``mutable_op``. It is used to register the new functional op
            and the functionalization kernel for ``mutable_op``. If you do
            not have one handy, construct it with
            ``torch.library.Library(ns, 'FRAGMENT')``.
        new_op_name (str): Name of the functional operator, without a
            namespace. The new op becomes accessible as
            ``torch.ops.{lib.ns}.{new_op_name}``.
        mutable_op (OpOverload): The mutable custom operator. You may need
            to append ``.default`` to it, e.g. ``torch.ops.aten.abs_.default``.
    """
    validate(mutable_op)

    # Define the functional op. Its CompositeExplicitAutograd kernel clones the
    # inputs that mutable_op writes to and runs mutable_op on the clones.
    lib.define(functional_schema(new_op_name, mutable_op))
    lib.impl(new_op_name, construct_functional_impl(mutable_op), 'CompositeExplicitAutograd')

    ops_namespace = getattr(torch.ops, lib.ns)
    functional_op = getattr(ops_namespace, new_op_name).default

    # There is no easy way to generate an autograd kernel, so we use
    # autograd_not_implemented. This also prevents the user from registering
    # an autograd formula themselves. That should not matter as long as the
    # user does not call the functional op directly in their program, but we
    # may need to revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    # The kernel receives a weak proxy to mutable_op rather than a strong
    # reference. NOTE(review): presumably this avoids keeping mutable_op alive
    # through its own kernel; confirm.
    functionalize_kernel = construct_functionalization_kernel(
        weakref.proxy(mutable_op), functional_op)
    lib.impl(mutable_op, functionalize_kernel, 'Functionalize')
def construct_functional_impl(mutable_op):
    """Build the CompositeExplicitAutograd kernel for the functional variant.

    The returned callable never mutates its inputs. It works in three steps:

    - Every argument that ``mutable_op``'s schema marks as written to is cloned.
    - ``mutable_op`` is run on the clones.
    - The clones are returned after ``mutable_op``'s own outputs.

    Args:
        mutable_op: The mutable operator (an ``OpOverload`` or a proxy to one).

    Returns:
        A function ``functional_impl(*args)``. It returns ``mutable_op``'s
        outputs followed by the mutated clones. When that combined output has
        exactly one element, it is returned bare rather than as a 1-tuple.
    """
    def functional_impl(*args):
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone()
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)

        result = mutable_op(*new_args)
        if result is None:
            outputs = tuple(extra_rets)
        elif isinstance(result, tuple):
            outputs = (*result, *extra_rets)
        else:
            outputs = (result, *extra_rets)

        # A schema with a single return expects that value itself, not a
        # 1-tuple. This happens e.g. for an op that returns nothing and
        # mutates exactly one input. The dispatcher cannot convert a tuple
        # into a single Tensor return.
        if len(outputs) == 1:
            return outputs[0]
        return outputs
    return functional_impl
def construct_functionalization_kernel(mutable_op, functional_op):
def kernel(*args):
# There's nothing to be functionalized!
# We can still end up here because DispatchKey::Functionalize is a mode key
if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
return mutable_op(*args)
# NB: This differs from the codegen -- codegen handles cases where there
# are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
# This only really matters for XLA (mixed CPU-XLA tensors) and
# running functionalization without the PT2 stack (which guarantees to us that
# all tensors are FunctionalTensorWrapper).
if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper")
unwrapped_args = []
for arg in args:
if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
torch._sync(arg)
unwrapped = torch._from_functional_tensor(arg)
unwrapped_args.append(unwrapped)
else:
unwrapped_args.append(arg)
with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
output = functional_op(*unwrapped_args)
num_actual_output = len(mutable_op._schema.returns)
actual_output = pytree.tree_map(
torch._to_functional_tensor, output[:num_actual_output])
new_values_to_propagate = output[num_actual_output:]
inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
if is_write]
assert len(new_values_to_propagate) == len(inputs_to_replace)
for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
torch._C._propagate_xla_data(arg, new_value)
torch._C._replace_(arg, new_value)
torch._C._commit_update(arg)
torch._sync(arg)
if len(actual_output) == 1:
return actual_output[0]
elif len(actual_output) == 0:
return None
return actual_output
return kernel
def validate(mutable_op: OpOverload):
    """Check that ``mutable_op`` is an operator register_functional_op supports.

    Raises:
        TypeError: If ``mutable_op`` is not an ``OpOverload``.
        RuntimeError: If the schema kind is not ``SchemaKind.mutable``
            (i.e. the op is functional, inplace, or out=).
        NotImplementedError: If any return carries an alias annotation
            (returns a mutated or aliased value). Also raised if any
            tensor-like argument is something other than a plain ``Tensor``
            (e.g. Optional or List of tensors).
    """
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops, each
    # with its own conventions:
    # - inplace (first input modified in-place and returned as only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we could support all three, but for simplicity only the last
    # one is supported right now.
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")

    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity.
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")

    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity.
        if arg.type.is_tensor_like() and arg.type != BaseType(BaseTy.Tensor):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op accepts Optional or List of tensors. "
                "Please file an issue.")
def functional_schema(new_op_name, op: OpOverload):
    """Return the schema string for the functional variant of ``op``.

    ``op``'s schema is re-parsed with torchgen and reduced to its functional
    signature via ``FunctionSchema.signature()``. It is then renamed to
    ``new_op_name`` and rendered back to a string.
    """
    functional_sig = FunctionSchema.parse(str(op._schema)).signature()
    renamed = functional_sig.with_name(OperatorName.parse(new_op_name))
    return str(renamed)
def mutable_args(op: OpOverload):
    """Return one flag per schema argument of ``op``.

    A flag is True when that argument's alias annotation marks it as written
    to, and False otherwise (including when there is no annotation).
    """
    flags = []
    for schema_arg in op._schema.arguments:
        info = schema_arg.alias_info
        flags.append(info is not None and info.is_write)
    return tuple(flags)