Back out "Revert D49107540: [pytorch][PR] split by tag" (pytorch#109332)
Summary:
Original commit changeset: 6391a068640b

Original Phabricator Diff: D49107540

Test Plan: same as D49107540

Differential Revision: D49297522

Pull Request resolved: pytorch#109332
Approved by: https://github.com/842974287
Wenting Wang authored and pytorchmergebot committed Sep 16, 2023
1 parent 7bce7f5 commit 393fe93
Showing 4 changed files with 160 additions and 19 deletions.
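For context, the sketch below walks through the split-by-tags flow that this diff re-lands, including the new return_fqn_mapping flag. It is a minimal illustration only: the module, tag names, and variable names are assumptions loosely modeled on the new test below, not code from this diff.

import torch
from torch.fx.passes.split_utils import split_by_tags

class Example(torch.nn.Module):  # illustrative module, not part of this diff
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 3)
        self.linear2 = torch.nn.Linear(3, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.linear1(x))

gm = torch.fx.symbolic_trace(Example())

# Every node carries a tag; split_by_tags groups nodes into one submodule per tag,
# honoring the order of the tags list.
for node in gm.graph.nodes:
    node.tag = "red" if "linear1" in node.name else "blue"

# With return_fqn_mapping=True (the flag restored by this change), the call also
# returns a map from original FQNs to FQNs in the split module,
# e.g. {"linear1": "red.linear1", "linear2": "blue.linear2"}.
split_gm, fqn_map = split_by_tags(gm, ["red", "blue"], return_fqn_mapping=True)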
115 changes: 115 additions & 0 deletions test/fx/test_fx_split.py
@@ -1,5 +1,8 @@
# Owner(s): ["module: fx"]

from collections import defaultdict
from typing import List, Tuple, Dict

import torch
from torch.fx.passes.split_utils import split_by_tags

@@ -30,3 +33,115 @@ def forward(self, x, y):
if n.op != "output":
self.assertIn("name", n.meta)
self.assertEqual(n.meta["name"], n.name)


class TestSplitByTags(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(2, 3)
self.linear2 = torch.nn.Linear(4, 5)
self.linear3 = torch.nn.Linear(6, 7)
self.linear4 = torch.nn.Linear(8, 6)

def forward(
self,
x1: torch.Tensor,
x2: torch.Tensor,
x3: torch.Tensor,
) -> torch.Tensor:
v1 = self.linear1(x1)
v2 = self.linear2(x2)
v3 = self.linear3(x3)
v4 = torch.cat([v1, v2, v3])
return self.linear4(v4)

@staticmethod
def trace_and_tag(
module: torch.nn.Module, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
"""
Test a simple gm consisting of tagged nodes (only call_module nodes are shown here):
linear1 - tag: "red"
linear2 - tag: "blue"
linear3, linear4 - tag: "green"
At the beginning we have:
gm:
linear1
linear2
linear3
linear4
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
linear1
blue:
linear2
green:
linear3
linear4
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)

# Tag every node and build a dictionary mapping each tag to its call_module node names.
for node in gm.graph.nodes:
if "linear1" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "linear2" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node

def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitByTags.TestModule()
gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
split_gm, orig_to_split_fqn_mapping = split_by_tags(
gm, tags, return_fqn_mapping=True
)
# Ensure split_gm has (and only has) the expected submodules, in order:
# red, blue, green
for idx, (name, _) in enumerate(split_gm.named_children()):
if idx < len(tags):
self.assertTrue(
name == tags[idx],
f"split_gm has an incorrect submodule named {name}",
)

# Ensure each submodule has the expected call_module node(s), in order.
# For example, the submodule split_gm.red has (and only has) linear1;
# split_gm.green has (and only has) linear3 and linear4, in that order.
sub_graph_idx = 0
for sub_name, sub_graph_module in split_gm.named_children():
node_idx = 0
for node in sub_graph_module.graph.nodes:
if node.op != "call_module":
continue
self.assertTrue(
node.name == tag_node[sub_name][node_idx],
# pyre-fixme[61]: `name` is undefined, or not always defined.
f"{sub_name} has incorrectly include {node.name}",
)
node_idx += 1
sub_graph_idx += 1

self.assertEqual(
orig_to_split_fqn_mapping,
{
"linear1": "red.linear1",
"linear2": "blue.linear2",
"linear3": "green.linear3",
"linear4": "green.linear4",
},
f"{orig_to_split_fqn_mapping=}",
)
36 changes: 25 additions & 11 deletions torch/fx/passes/split_utils.py
@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union

import torch.fx
from torch.fx._compatibility import compatibility
@@ -11,6 +11,7 @@

__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]


@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
for layer in name.split("."):
@@ -57,11 +58,13 @@ class Component:


@compatibility(is_backward_compatible=False)
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
def split_by_tags(
gm: torch.fx.GraphModule, tags: List[str], return_fqn_mapping: bool = False
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, with tags = ["a", "b", "c"], the function will create
the initial submodules in the order of "a_0", "b_1", "c_2".
the initial submodules in the order of "a", "b", "c".
To set a tag:
gm.graph.nodes[idx].tag = "mytag"
@@ -88,26 +91,31 @@ def forward(self, in1, in2):
Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
ro_0:
ro:
def forward(self, in1):
self = self.root
linear1 = self.linear1(in1)
return linear1
main_1:
main:
def forward(self, in2, linear1):
self = self.root
linear2 = self.linear2(in2)
cat_1 = torch.cat([linear1, linear2])
linear3 = self.linear3(cat_1)
return linear3
main_0:
main:
def forward(self, in1, in2):
self = self.root
ro_0 = self.ro_0(in1)
main_1 = self.main_1(in2, ro_0)
return main_1
Returns:
split_gm: the torch.fx GraphModule produced by the split
orig_to_split_fqn_mapping: a map from each call_module/get_attr target's
original FQN to its FQN after the split.
"""

def flatten(x: torch.fx.node.Argument) -> NodeList:
@@ -210,9 +218,7 @@ def remap_func(x):
comp.orig_inputs.append(x)
placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
placeholder.meta = copy.copy(x.meta)
comp.input_placeholders.append(
placeholder
)
comp.input_placeholders.append(placeholder)
used_in_main[x] = None

return comp.input_placeholders[comp.orig_inputs.index(x)]
@@ -243,6 +249,7 @@ def remap_func(x):
node_to_component[n].orig_outputs.append(n)

# Now we create a graphmodule for each component.
orig_to_split_fqn_mapping: Dict[str, str] = {}
for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

@@ -252,7 +259,10 @@ def remap_func(x):
# ((output_0, output_1, ...)).
comp.graph.output(outs[0] if len(outs) == 1 else outs)

comp.gm = lift_subgraph_as_module(gm, comp.graph)
comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
gm, subgraph=comp.graph, comp_name=comp.name
)
orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

# Create a call_module node in main graph.
main_node = main_g.call_module(
@@ -277,4 +287,8 @@ def remap_func(x):
if x.op == "get_attr":
setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]

return torch.fx.GraphModule(main_root, main_g)
result_gm = torch.fx.GraphModule(main_root, main_g)
if return_fqn_mapping:
return result_gm, orig_to_split_fqn_mapping

return result_gm
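Continuing the illustrative sketch from the top of this page (gm, split_gm, and fqn_map are the assumed names introduced there), one plausible use of the returned mapping is re-keying an original state_dict so it loads into the split module. This is a hedged sketch of a caller-side pattern, not code from this diff:

orig_state = gm.state_dict()  # keys such as "linear1.weight"
split_state = {}
for key, value in orig_state.items():
    module_fqn, _, param_name = key.rpartition(".")
    mapped = fqn_map.get(module_fqn, module_fqn)  # e.g. "linear1" -> "red.linear1"
    split_state[f"{mapped}.{param_name}" if mapped else param_name] = value
split_gm.load_state_dict(split_state)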
24 changes: 18 additions & 6 deletions torch/fx/passes/utils/common.py
@@ -1,12 +1,15 @@
from torch.nn import Module
from typing import Dict, Tuple

from torch.fx.graph_module import GraphModule
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph

from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.fx._compatibility import compatibility
from torch.nn import Module


__all__ = ['HolderModule', 'lift_subgraph_as_module', 'compare_graphs']
__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]


@compatibility(is_backward_compatible=False)
class HolderModule(Module):
@@ -22,7 +25,12 @@ def __init__(self, d):


@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule:
def lift_subgraph_as_module(
gm: GraphModule,
subgraph: Graph,
comp_name: str = "",
class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
"""
Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.
@@ -31,6 +39,8 @@ def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str =
subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
comp_name (str): name for the new component
class_name (str): name for the submodule
"""
@@ -42,6 +52,7 @@ def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str =
# make "weight" a attribute of "conv" HolderModule and point to conv.weight in
# the original module.
submodule = HolderModule({})
orig_to_split_fqn_mapping: Dict[str, str] = {}
for n in subgraph.nodes:
if n.op not in ("call_module", "get_attr"):
continue
@@ -62,10 +73,11 @@ def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str =
leaf_node_name = target_name_parts[-1]
leaf_node = getattr(orig_gm, leaf_node_name)

orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
# Relies on custom __setattr__ magic.
setattr(curr, leaf_node_name, leaf_node)

return GraphModule(submodule, subgraph, class_name)
return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping


@compatibility(is_backward_compatible=False)
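Because of the signature change above, lift_subgraph_as_module now returns a (GraphModule, mapping) tuple, with comp_name used as the prefix in that mapping; the fuser_utils change below is the in-tree caller update. A minimal hedged sketch, assuming gm is a parent GraphModule and subgraph is an fx Graph whose call_module/get_attr nodes were copied from it:

from torch.fx.passes.utils.common import lift_subgraph_as_module

lifted_gm, fqn_map = lift_subgraph_as_module(gm, subgraph, comp_name="red")
# fqn_map maps every lifted call_module/get_attr target,
# e.g. "linear1" -> "red.linear1".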
4 changes: 2 additions & 2 deletions torch/fx/passes/utils/fuser_utils.py
@@ -173,8 +173,8 @@ def remap_inputs(x):

# lint to ensure correctness
subgraph.lint()

fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name)
fused_gm: GraphModule
fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)

# sub_gm's input nodes in the original module
original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())
