forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubgraph_rewriter.py
339 lines (270 loc) · 13.1 KB
/
subgraph_rewriter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union
import torch
# Public names exported by a star-import of this module.
__all__ = ["Match", "replace_pattern", "replace_pattern_with_filters", "ReplacedPatterns"]
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """
    One place in the larger graph where a pattern was found.

    Attributes:
        anchor: The node from which the match was found.
        nodes_map: Maps each node of the pattern subgraph to the node it
            matched in the larger graph.
    """

    anchor: Node
    nodes_map: Dict[Node, Node]
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """
    Record of a single pattern occurrence that has been rewritten.

    Attributes:
        anchor: The node from which the match was found.
        nodes_map: Maps each node of the pattern subgraph to the node it
            matched in the larger graph.
        replacements: The nodes that were added into the graph in place of
            the matched subgraph.
    """

    anchor: Node
    nodes_map: Dict[Node, Node]
    replacements: List[Node]
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make every ``call_module``/``get_attr`` target in ``gm.graph`` resolvable
    on ``gm``, copying missing ones over from ``replacement``.

    This runs after subgraph rewriting. At that point ``gm.graph`` may contain
    nodes copied from ``replacement``'s graph, and those nodes can reference
    attributes that ``gm`` does not own yet.

    Submodules of ``gm`` that the graph no longer uses are deleted first.
    After that, attributes already present on ``gm`` take precedence, and
    missing ones are deep-copied from ``replacement``. A dotted target such as
    ``"sub.weight"`` is installed under the matching submodule path. Any
    missing intermediate segment is created as an empty ``torch.nn.Module``.

    Args:
        gm: The rewritten GraphModule; mutated in place.
        replacement: The module the replacement subgraph came from.

    Raises:
        RuntimeError: If a target exists on neither ``gm`` nor ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(mod: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve a (possibly dotted) target, returning None when any segment
        # of the path is missing. `get_submodule` raises AttributeError for a
        # missing intermediate module. Without catching it, a target that `gm`
        # owns but `replacement` lacks (e.g. "sub.lin") would abort the rewrite.
        module_path, _, attr_name = target.rpartition(".")
        try:
            owner: torch.nn.Module = mod.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(owner, attr_name, None)

    for node in gm.graph.nodes:
        if node.op != "call_module" and node.op != "get_attr":
            continue

        # CASE 1: This target already exists as an attribute in our
        # result GraphModule. Whether or not it exists in
        # `replacement`, the existing attribute takes precedence.
        if try_get_attr(gm, node.target) is not None:
            continue

        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 3: The target doesn't exist as an attribute in `gm`
        # or `replacement`.
        if replacement_attr is None:
            raise RuntimeError(
                f'Attempted to create a "{node.op}" node during subgraph rewriting '
                f"with target {node.target}, but the referenced attribute does not "
                "exist in the replacement GraphModule"
            )

        # CASE 2: The target exists as an attribute in `replacement`
        # only, so we need to copy it over.
        new_attr = copy.deepcopy(replacement_attr)
        if isinstance(replacement_attr, torch.nn.Module):
            # add_submodule installs empty modules for missing path segments.
            gm.add_submodule(node.target, new_attr)
        else:
            # setattr(gm, "a.b", ...) would create one attribute literally named
            # "a.b", and nn.Module rejects dotted Parameter names outright. So
            # walk to the owning submodule (creating it if absent) and set the
            # leaf attribute there.
            module_path, _, attr_name = node.target.rpartition(".")
            try:
                owner = gm.get_submodule(module_path)
            except AttributeError:
                gm.add_submodule(module_path, torch.nn.Module())
                owner = gm.get_submodule(module_path)
            setattr(owner, attr_name, new_attr)

    gm.graph.lint()
@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule]
) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``). ``gm`` is modified in place: its graph is
    rewritten and the module is recompiled.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this (exact node names may differ). The
    ``.sum()`` calls are gone because they were part of ``pattern`` and
    were removed together with ``torch.cat``:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            max_1 = torch.max(stack_1)
            add_1 = x + max_1
            stack_2 = torch.stack([w1, w2])
            max_2 = torch.max(stack_2)
            add_2 = add_1 + max_2
            return add_2
    """
    match_and_replacements = _replace_pattern(gm, pattern, replacement)
    # Narrow the experimental ReplacedPatterns records to the stable Match API.
    return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
# Experimental API: marked as not backward compatible below.
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,  # type: ignore[name-defined]
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Variant of ``replace_pattern`` with extra control over which matches are
    rewritten. See ``replace_pattern`` for the general matching and
    replacement semantics.

    Args:
        ``match_filters``: Optional list of predicates, each called as
            ``match_filter(match, original_graph, pattern_graph)``, where
            ``match`` is an ``InternalMatch`` (defined in matcher_utils.py).
            A match is replaced only if every predicate returns ``True``.
            ``None`` applies no filtering.
        ``ignore_literals``: Forwarded unchanged to ``SubgraphMatcher``.
            Presumably, when ``True``, literal arguments in ``pattern`` may
            match any value. Confirm in matcher_utils.py.

    Returns:
        List[ReplacedPatterns]: One entry per replaced match, holding the
        anchor, the pattern-to-graph ``nodes_map`` and the nodes inserted
        by the replacement.
    """
    return _replace_pattern(
        gm,
        pattern,
        replacement,
        match_filters=match_filters,
        ignore_literals=ignore_literals,
    )
def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,  # type: ignore[name-defined]
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Shared implementation of ``replace_pattern`` and ``replace_pattern_with_filters``.

    The steps are:

    1. Find non-overlapping occurrences of ``pattern`` in ``gm.graph``.
    2. Keep only those accepted by every function in ``match_filters``.
    3. For each kept match, splice a copy of ``replacement``'s graph in
       front of the match's first user, reroute the match's outputs to the
       copy, and erase the matched nodes.
    4. Recompile ``gm``.

    ``gm`` is mutated in place.

    Args:
        gm: GraphModule whose graph is rewritten.
        pattern: Subgraph to search for. It is traced with ``symbolic_trace``
            unless it is already a GraphModule.
        replacement: Subgraph to substitute, traced the same way. It must have
            as many placeholders as each match has placeholder nodes (asserted).
        match_filters: Predicates called as
            ``(match, original_graph, pattern_graph) -> bool``. A match is kept
            only if all of them return True. ``None`` means no filtering.
        ignore_literals: Forwarded to ``SubgraphMatcher``.

    Returns:
        One ``ReplacedPatterns`` per replaced match, in processing order.
    """
    # Imported locally; presumably to avoid an import cycle with torch.fx.passes (verify).
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    if match_filters is None:
        match_filters = []

    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph: Graph = gm.graph

    if isinstance(pattern, GraphModule):
        pattern_graph = pattern.graph
    else:
        pattern_graph = symbolic_trace(pattern).graph

    if isinstance(replacement, GraphModule):
        replacement_graph = replacement.graph
    else:
        replacement_graph = symbolic_trace(replacement).graph

    # NOTE(review): match_output/match_placeholder=False presumably let the
    # pattern's output and placeholders bind to arbitrary nodes instead of the
    # graph's own output/inputs (cf. replace_pattern's docstring). Confirm in
    # matcher_utils.
    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    # Filter out matches that don't match the filter
    _matches = [
        m for m in _matches
        if all(match_filter(m, original_graph, pattern_graph)
               for match_filter in match_filters)
    ]

    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

    # As we progressively replace nodes, we'll need to keep track of how the match results should change.
    # Key: a returning node of an already-processed match (erased by now).
    # Value: the copied replacement node that took its place (filled in below).
    match_changed_node: Dict[Node, Node] = {}

    match_and_replacements: List[ReplacedPatterns] = []
    for match in _matches:

        # Build connecting between replacement graph's input and original graph input producer node

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: Dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            if isinstance(gn, Node):
                # If an earlier match already replaced `gn`, feed its
                # replacement into this match instead of the erased original.
                val_map[rn] = match_changed_node.get(gn, gn)
                if gn != val_map[rn]:
                    # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn.
                    # This also keeps that node out of `replacement_nodes` below,
                    # which excludes anything in placeholder_nodes.
                    gn_ind = match.placeholder_nodes.index(gn)
                    match.placeholder_nodes[gn_ind] = match_changed_node[gn]
                    # Reverse lookup: the first pattern node that was matched to `gn`.
                    map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
                    match.nodes_map[map_key] = match_changed_node[gn]
            else:
                # Non-Node binding (e.g. a literal) is passed through as-is.
                val_map[rn] = gn

        # Copy the replacement graph over
        user_nodes: Set[Node] = set()
        for n in match.returning_nodes:
            for user in n.users:
                user_nodes.add(user)
        assert user_nodes, "The returning_nodes should have at least one user node"

        if len(user_nodes) == 1:
            first_user_node = list(user_nodes)[0]
        else:
            # If there are multiple user nodes, we need to find the first user node
            # in the current execution order of the `original_graph`.
            # Users of graph nodes live in the same graph, so this loop is
            # expected to always assign `first_user_node`.
            for n in original_graph.nodes:
                if n in user_nodes:
                    first_user_node = n
                    break

        # Inserting before the earliest consumer places the copied nodes ahead
        # of every use of the match's outputs.
        with original_graph.inserting_before(first_user_node):
            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

        # A single returned Node is normalized to a 1-tuple so it zips with returning_nodes.
        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes, )

        # Get a list of nodes that have been replaced into the graph.
        # This relies on graph_copy recording each copied node in `val_map`;
        # the placeholder entries seeded above are filtered out here.
        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
        assert len(match.returning_nodes) == len(copied_returning_nodes)
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
            gn.replace_all_uses_with(copied_node)
            # Remember the substitution so later matches that use `gn` as an input get rewired.
            match_changed_node[gn] = copied_node

        # Remove the original nodes.
        # Walk the pattern in reverse so nodes are erased consumer-first.
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                gn = match.nodes_map[node]
                gm.graph.erase_node(gn)

        match_and_replacements.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes
            )
        )

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return match_and_replacements