forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinterpreter.py
505 lines (427 loc) · 20.9 KB
/
interpreter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
from .graph_module import GraphModule
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm
# Public names exported by ``from <this module> import *``.
__all__ = ['Interpreter', 'Transformer']
@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
        assert isinstance(module, GraphModule)
        self.module = module
        self.submodules = dict(self.module.named_modules())
        # Maps each already-executed Node to its concrete value.
        self.env : Dict[Node, Any] = {}
        # Label shown in the progress bar description.
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        # When True, `run` annotates exceptions with the failing node and its stack trace.
        self.extra_traceback = True

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            # user node -> nodes whose values may be freed right after `user` runs
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.module.graph.nodes):
                # The lambdas are invoked immediately by map_arg, so capturing the
                # loop variable `node` here is safe (no late-binding issue).
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter. The dict is used as-is (not copied), so with
                ``garbage_collect_values`` enabled, entries are deleted from it during execution.
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module (``None`` if the
            graph has no ``output`` node).
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.module.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)

        # Use the progress bar as a context manager so it is closed deterministically,
        # on normal return, on early return at the output node, and on exceptions.
        with tqdm(total=len(self.module.graph.nodes),
                  desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}",
                  initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) as pbar:
            for node in self.module.graph.nodes:
                pbar.update(1)
                if node in self.env:
                    # Short circuit if we have this value. This could
                    # be used, for example, for partial evaluation
                    # where the caller has pre-populated `env` with
                    # values for a subset of the program.
                    continue

                try:
                    self.env[node] = self.run_node(node)
                except Exception as e:
                    if self.extra_traceback:
                        msg = f"While executing {node.format_node()}"
                        msg = f'{e.args[0]}\n\n{msg}' if e.args else msg
                        msg += f"\nOriginal traceback:\n{node.stack_trace}"
                        e.args = (msg,) + e.args[1:]
                        # KeyError's str() shows the repr of its argument, which would
                        # escape the newlines above; re-raise as RuntimeError instead.
                        if isinstance(e, KeyError):
                            raise RuntimeError(*e.args) from e
                    raise

                if self.garbage_collect_values:
                    # Free every value whose last consumer was this node.
                    for to_delete in self.user_to_last_uses.get(node, []):
                        del self.env[to_delete]

                if node.op == 'output':
                    output_val = self.env[node]
                    return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result. This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter. This ensures that input tensors are promptly deallocated.

        Placeholders are bound to ``args_list`` entries in graph order. If
        ``args_list`` runs out before every placeholder is bound, the remaining
        placeholders are not pre-bound and are handled by :meth:`placeholder`
        during :meth:`run`. That method uses the placeholder's default value if
        it has one and raises a ``RuntimeError`` otherwise, so a bare
        ``StopIteration`` is no longer leaked.
        """
        args_iter = iter(args_list)
        env = {}
        exhausted = object()  # sentinel: distinguishes "no more args" from a None argument
        for n in self.module.graph.nodes:
            if n.op != "placeholder":
                continue
            value = next(args_iter, exhausted)
            if value is exhausted:
                break
            env[n] = value
        # Drop the caller's references so inputs can be freed as soon as they are consumed.
        args_list.clear()
        return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        """Context manager that sets ``node`` as the current FX traceback meta while active."""
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            # Dispatch on the opcode name, e.g. 'call_function' -> self.call_function.
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation. If
                non-empty, ``args[0]`` is used as the default value when no
                runtime argument remains.
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.

        Raises:
            RuntimeError: If no argument remains and no default is available.
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                if len(args) > 0:
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the module invocation
        """
        # Retrieve executed args and kwargs values from the environment

        # Execute the method and return the result
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Return:
            Any: The value of the attribute.

        Raises:
            RuntimeError: If any component of ``target`` does not exist. The
                message names the dotted prefix up to and including the missing
                component.
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Include the missing atom itself; `[:i]` would omit it (and give ""
                # when the very first component is missing).
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Return:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.

        Return:
            Argument: ``args`` with every ``Node`` replaced by its value from ``self.env``.

        Raises:
            RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)
@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        # The output graph is built from scratch but reuses the source graph's codegen.
        self.new_graph = Graph()
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            """Tracer that records into ``graph`` and treats every submodule as a leaf."""
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph
                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]

            def is_leaf_module(self, _, __) -> bool:
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation. If
                non-empty, the first element becomes the new placeholder's
                default value.
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: A proxy wrapping the newly created placeholder node.
        """
        assert isinstance(target, str)
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: The proxy created by ``self.tracer`` for the new node.
        """
        assert isinstance(target, str)
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node. In ``Transformer``, this is
        overridden to route the call through ``self.tracer.call_module`` so
        that the leaf-module policy of ``self.tracer`` is respected.
        ``TransformerTracer.is_leaf_module`` always returns ``True``, so the
        submodule is presumably recorded as a single ``call_module`` node
        rather than traced through. Confirm this against ``Tracer.call_module``.

        Args:
            target (Target): Fully-qualified name of the submodule.
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: Whatever ``self.tracer.call_module`` returns.
        """
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node. In ``Transformer``, this is
        overridden to record a ``call_function`` node via
        ``self.tracer.create_proxy`` instead of calling ``target``.

        Args:
            target (Target): The function to record.
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The proxy created by ``self.tracer``.
        """
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.

        The interpreter is run symbolically, with IO processing disabled and
        node meta preserved. If the run yields a result, every ``Proxy`` in it
        is replaced by its underlying node, and the result becomes the output
        of ``self.new_graph``.
        """
        with fx_traceback.preserve_node_meta():
            result = super().run(enable_io_processing=False)
        if result is not None:
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                return a.node if isinstance(a, Proxy) else a
            self.new_graph.output(map_aggregate(result, strip_proxy))
        return GraphModule(self.module, self.new_graph)