forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlibrary.py
282 lines (239 loc) · 12.1 KB
/
library.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
from ._ops import OpOverload
from typing import Any, Optional, Set, List
import traceback
import torch
import weakref
# Names exported by ``from torch.library import *``.
__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "get_ctx",
]
# Global registry of every (namespace, operator, dispatch key) combination that
# already has a kernel registered from Python. Each entry is the string
# "namespace/op_name/dispatch_key". Consulting it lets us refuse a second library
# that tries to override exactly the same kernel, so no library ends up calling
# into a kernel it did not intend to call.
_impls: Set[str] = set()

# Namespaces that cannot be used for "DEF" or "FRAGMENT" libraries.
# "prim" is reserved by the TorchScript interpreter.
_reserved_namespaces = ["prim"]
def fallthrough_kernel():
    """
    Sentinel kernel: hand it to ``Library.impl`` to register a fallthrough.

    It is only meant to be passed around, never invoked; calling it directly
    always raises ``NotImplementedError``.
    """
    message = "fallthrough_kernel() should never be called."
    raise NotImplementedError(message)
class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch keyname if they only want to register
    kernels corresponding to only one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: one of "DEF", "IMPL" or "FRAGMENT" (required; there is no default)
        dispatch_key: PyTorch dispatch key (default: ""). It is the key that
            :meth:`impl` uses when no dispatch key is passed to it.

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds, or if ``ns``
            is a reserved namespace and ``kind`` is "DEF" or "FRAGMENT".
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")
        if ns in _reserved_namespaces and kind in ("DEF", "FRAGMENT"):
            raise ValueError(f"{ns} is a reserved namespace. Please try creating a library with another name.")
        # Take the oldest of the three most recent stack frames (two levels above
        # this __init__) and pass its location to the C++ library.
        # NOTE(review): presumably used so dispatcher diagnostics can point at
        # where the library was created — confirm.
        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        # Keys ("ns/op_name/dispatch_key") this library has added to the global _impls set.
        self._op_impls: Set[str] = set()
        # Registrations whose lifetime is tied to this library
        # (e.g. abstract impls registered through impl_abstract(..., lib=self)).
        self._registration_handles: List["torch._library.utils.RegistrationHandle"] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to setup the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). finalizers help the
        # situation because it lets us capture references and keeps them alive
        weakref.finalize(self, _del_library, _impls, self._op_impls, self._registration_handles)

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                inferred from the schema (default behavior) or not ("CONSERVATIVE").

        Returns:
            name of the operator as inferred from the schema.

        Raises:
            RuntimeError: if ``alias_analysis`` is not "", "FROM_SCHEMA" or "CONSERVATIVE".

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LIBRARY)
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        assert self.m is not None
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with. If both are empty, the kernel is
                registered for "CompositeImplicitAutograd".

        Raises:
            TypeError: if ``fn`` is not callable.
            RuntimeError: if ``op_name`` is neither a str nor an OpOverload, if a kernel for the same
                operator / dispatch key / namespace was already registered from Python, or if a "Meta"
                kernel is registered for an operator that has a CompositeImplicitAutograd kernel.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        # Operator name with any "ns::" qualifier stripped; used for the registry key.
        op_basename = name.split("::")[-1]
        key = self.ns + "/" + op_basename + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                f"This is not allowed since there's already a kernel registered from python overriding "
                f"{op_basename}'s behavior for {dispatch_key} dispatch key and {self.ns} namespace.")

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        assert self.m is not None
        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

        _impls.add(key)
        self._op_impls.add(key)

    def _destroy(self):
        """Tear this library down eagerly instead of waiting for garbage collection.

        Drops the reference to the C++ library object, destroys every registration
        handle tied to this library, and releases this library's keys from the global
        ``_impls`` set so the same kernels can be registered again afterwards.
        """
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        # Without this, the stale keys would stay in _impls (blocking re-registration)
        # until the finalizer eventually runs. Mutate in place: the finalizer holds
        # a reference to this very set object.
        _impls.difference_update(self._op_impls)
        # The finalizer captured this same _op_impls set; emptying it makes its later
        # cleanup a no-op, so it cannot remove keys a newer library has registered since.
        self._op_impls.clear()
def _del_library(captured_impls, op_impls, registration_handles):
    """Finalizer callback run when a ``Library`` object is garbage collected.

    Removes, in place, every key of ``op_impls`` from the shared
    ``captured_impls`` set, then destroys each entry of ``registration_handles``
    in order.
    """
    captured_impls.difference_update(op_impls)
    for registration in registration_handles:
        registration.destroy()
# Decorator counterpart of ``Library.impl`` for registering Python functions as
# kernels of library ops.
# Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""):
    """Return a decorator that registers the decorated function on ``lib``.

    The decorated function is passed to ``lib.impl(name, func, dispatch_key)``
    and then handed back unchanged, so it stays usable as a plain function.
    """
    def register(func):
        lib.impl(name, func, dispatch_key)
        return func

    return register
def define(lib, schema, alias_analysis=""):
    """Return a decorator that defines an operator on ``lib`` and implements it.

    When applied, it calls ``lib.define(schema, alias_analysis)`` and registers
    the decorated function under the operator name that call returns, using
    ``lib.impl`` with no explicit dispatch key. The function is returned as-is.
    """
    def decorator(fn):
        op_name = lib.define(schema, alias_analysis)
        lib.impl(op_name, fn)
        return fn

    return decorator
def impl_abstract(name, func=None, *, lib=None, _stacklevel=1):
    r"""Register an abstract implementation for this operator.

    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.

    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit

    Args:
        name: qualified operator name, e.g. ``"mylibrary::custom_linear"``.
        func: the abstract implementation. If omitted, a decorator is returned.
        lib: optional :class:`Library`. If given, the registration handle is
            stored on it, so the registration is destroyed together with ``lib``.
        _stacklevel: internal; forwarded (plus one) to
            ``torch._library.utils.get_source`` to record where the
            registration was made.

    Returns:
        ``func`` itself, or a decorator that registers and returns its argument
        when ``func`` is None.

    Examples::
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> lib = torch.library.Library("mylibrary", "FRAGMENT")
        >>> lib.define("mylibrary::custom_linear(Tensor x, Tensor weight, Tensor bias) -> Tensor")
        >>>
        >>> @torch.library.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> lib = torch.library.Library("mylibrary", "FRAGMENT")
        >>> lib.define("mylibrary::custom_nonzero(Tensor x) -> Tensor")
        >>>
        >>> @torch.library.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     # One row per nonzero element, one column per input dimension,
        >>>     # matching the (nnz, x.dim()) result of the CPU kernel below.
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
        >>> def custom_nonzero_cpu(x):
        >>>     x_np = x.numpy()
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)
    """
    # Computed when impl_abstract itself is called (not when the decorator is
    # later applied), with _stacklevel + 1 to account for this frame.
    source = torch._library.utils.get_source(_stacklevel + 1)

    def inner(func):
        entry = torch._library.simple_registry.singleton.find(name)
        handle = entry.abstract_impl.register(func, source)
        if lib is not None:
            # Library._destroy and its finalizer destroy these handles, which
            # ties this registration's lifetime to ``lib``.
            lib._registration_handles.append(handle)
        return func

    if func is None:
        return inner
    return inner(func)
# NOTE [ctx inside the fake implementation]
# Operators whose output shape depends on input data need an unbacked symint
# for that size. Their fake implementation obtains one by fetching the current
# ctx and calling its methods to construct the new unbacked symint.
#
# This works because the global_ctx_getter function is set every time a fake
# implementation is invoked.
def get_ctx() -> "torch._library.abstract_impl.AbstractImplCtx":
    """Return the AbstractImplCtx object that is current right now.

    Only call ``get_ctx()`` from inside an abstract impl; it simply invokes
    ``torch._library.abstract_impl.global_ctx_getter`` and returns its result.
    """
    ctx_getter = torch._library.abstract_impl.global_ctx_getter
    return ctx_getter()