Skip to content

Commit

Permalink
[inductor] visualize fused ops in svg graph (pytorch#107752)
Browse files Browse the repository at this point in the history
example usage
* `TORCH_COMPILE_DEBUG=1 INDUCTOR_ORIG_FX_SVG=1 INDUCTOR_POST_FUSION_SVG=1 python trig.py`: shows each original fx node's name, file, and code. See snapshot 2, where we have origin_0, origin_1, and origin_2
* trig.py can be found in P816304818

Implementation
* keep original fx graph in GraphLowering, ```self.orig_gm: torch.fx.GraphModule = gm.__copy__()```
* draw the original fx graph annotated with the origins of the post-fusion IR: ```V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)```. node.meta["buf_meta"] tracks buf_name

<img width="350" alt="Screenshot 2023-08-29 at 12 40 24 PM" src="https://github.com/pytorch/pytorch/assets/134637289/c4e197cb-ab3b-4a09-a584-c1356376accb">

Pull Request resolved: pytorch#107752
Approved by: https://github.com/mlazos
  • Loading branch information
weifengpy authored and pytorchmergebot committed Sep 21, 2023
1 parent f5b753b commit 772e104
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 27 deletions.
18 changes: 15 additions & 3 deletions torch/_functorch/partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,14 @@ def get_node_weight(node) -> int:
return fw_module, bw_module


def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
def draw_graph(
traced: torch.fx.GraphModule,
fname: str,
figname: str = "fx_graph",
clear_meta: bool = True,
prog: str = None,
parse_stack_trace: bool = False,
) -> None:
if clear_meta:
new_graph = copy.deepcopy(traced.graph)
traced = fx.GraphModule(traced, new_graph)
Expand All @@ -917,9 +924,14 @@ def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_grap
if not ext:
ext = ".svg"
print(f"Writing FX graph to file: {base}{ext}")
g = graph_drawer.FxGraphDrawer(traced, figname)
g = graph_drawer.FxGraphDrawer(traced, figname, parse_stack_trace=parse_stack_trace)
x = g.get_main_dot_graph()
getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}")
write_method = getattr(x, "write_" + ext.lstrip("."))
fname = f"{base}{ext}"
if prog is None:
write_method(fname)
else:
write_method(fname, prog=prog)


def draw_joint_graph(graph, joint_inputs, file_name="full_graph.png"):
Expand Down
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,9 @@ class trace:
# SVG figure showing post-fusion graph
graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"

# SVG figure showing fx with fusion
draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"

# Store cProfile (see snakeviz to view)
compile_profile = False

Expand Down
85 changes: 81 additions & 4 deletions torch/_inductor/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pstats
import shutil
import subprocess
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional
from unittest.mock import patch

from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
Expand All @@ -38,6 +38,10 @@

log = logging.getLogger(__name__)

# Alias for a list of scheduler nodes; element type is left as Any.
SchedulerNodeList = List[Any]
# Annotation attached to an original FX node: `name` is the buffer name the node
# is drawn under, `n_origin` is how many FX nodes share that buffer name.
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
# Graphviz `dot` command line with iteration limits (nslimit, nslimit1, maxiter).
# NOTE(review): presumably chosen to bound layout time on large graphs -- confirm.
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]


@functools.lru_cache(None)
def has_dot() -> bool:
Expand Down Expand Up @@ -174,6 +178,72 @@ def in_output(snode):
return graph


def update_orig_fx_node_name_to_buf_name(
    nodes: "Optional[SchedulerNodeList]",
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    """
    Record, for every origin FX node of the given scheduler nodes, the buffer
    name it should be drawn under.

    Args:
        nodes: scheduler nodes to walk; ``None`` is accepted and is a no-op.
        node_name_to_buf_name: mapping that is updated in place from origin
            FX node name to buffer name. Existing entries are never
            overwritten, so the first buffer seen for an origin wins.
        parent_buf_name: name of the enclosing fused node, if any. When set,
            it is recorded instead of each node's own name.
        n_origins: currently unused; kept so existing callers keep working.
    """
    if nodes is None:
        return
    for node in nodes:
        # Members of a fused node are all attributed to the outermost fused
        # node's name, so an inherited parent name takes precedence.
        target_buf_name = (
            node.get_name() if parent_buf_name is None else parent_buf_name
        )
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            # Node with several members (e.g. FusedSchedulerNode): recurse.
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                target_buf_name,
            )
            continue
        # Anything else is expected to be a leaf whose get_nodes() is [node].
        assert (
            children_nodes is not None
            and len(children_nodes) == 1
            and children_nodes[0] == node
        ), "expected a leaf scheduler node whose get_nodes() returns [node]"

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            # when buf1 and buf2 both have origin=node1
            # we draw node1 according to buf1
            node_name_to_buf_name.setdefault(origin.name, target_buf_name)


def get_node_name_to_buf_meta(
    node_name_to_buf_name: Dict[str, str]
) -> Dict[str, "BufMeta"]:
    """
    Build the per-FX-node ``BufMeta`` annotation from a node-name to
    buffer-name mapping.

    Args:
        node_name_to_buf_name: maps an original FX node name to a buffer name.

    Returns:
        A dict with the same keys, in the same order. Each value is
        ``BufMeta(buf_name, n)``, where ``n`` is the number of FX node names
        that map to that same ``buf_name``.
    """
    # Keys are unique, so counting values per buffer name gives the number of
    # distinct FX nodes attributed to each buffer.
    n_origin_by_buf = collections.Counter(node_name_to_buf_name.values())
    return {
        node_name: BufMeta(buf_name, n_origin_by_buf[buf_name])
        for node_name, buf_name in node_name_to_buf_name.items()
    }


def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule, snodes: "SchedulerNodeList"
) -> None:
    """
    Annotate the nodes of ``gm`` with the buffer each one is attributed to.

    For every node in ``gm.graph`` whose name appears among the origins of
    ``snodes``, sets ``node.meta["buf_meta"]`` to the corresponding ``BufMeta``.
    ``gm`` is modified in place; nodes without a matching origin are untouched.
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    # The mapping is filled in place and is never None; nothing to annotate
    # when no origin was found.
    if not node_name_to_buf_name:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        buf_meta = node_name_to_buf_meta.get(node.name)
        if buf_meta is not None:
            node.meta["buf_meta"] = buf_meta


@contextlib.contextmanager
def enable_aot_logging():
compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
Expand Down Expand Up @@ -354,9 +424,6 @@ def ignored(*args, **kwargs):
return ignored


SchedulerNodeList = List[Any]


class DebugFormatter:
def __init__(self, handler):
self.fopen = handler.fopen
Expand Down Expand Up @@ -391,6 +458,16 @@ def _write_ir(self, filename: str, nodes: SchedulerNodeList):
def graph_diagram(self, nodes: SchedulerNodeList):
    """Draw ``nodes`` with ``draw_buffers`` into ``graph_diagram.svg``; the path comes from ``self.filename``."""
    draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

def draw_orig_fx_graph(self, gm: torch.fx.GraphModule, nodes: SchedulerNodeList):
    """
    Annotate ``gm`` with buffer info from ``nodes`` and draw it to
    ``orig_fx_graph_diagram.svg`` (path resolved through ``self.filename``).
    """
    # Annotate first: the drawing keeps node meta (clear_meta=False).
    annotate_orig_fx_with_snodes(gm, nodes)
    svg_path = self.filename("orig_fx_graph_diagram.svg")
    draw_options = {
        "fname": svg_path,
        "clear_meta": False,
        "prog": GRAPHVIZ_COMMAND_SCALABLE,
        "parse_stack_trace": True,
    }
    draw_graph(gm, **draw_options)

def output_code(self, filename):
shutil.copy(filename, self.filename("output_code.py"))

Expand Down
2 changes: 2 additions & 0 deletions torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def __init__(
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs = False
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
self.init_backend_registration()

@staticmethod
Expand Down Expand Up @@ -922,6 +923,7 @@ def codegen(self):

self.scheduler = Scheduler(self.buffers)
assert self.scheduler is not None # mypy can't figure this out
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
self.scheduler.codegen()
assert self.wrapper_code is not None
return self.wrapper_code.generate()
Expand Down
48 changes: 32 additions & 16 deletions torch/fx/graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
from collections import defaultdict
from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
import torch.utils._pytree as pytree
Expand Down Expand Up @@ -283,6 +284,29 @@ class _PyTreeInfo(NamedTuple):
in_spec: pytree.TreeSpec
out_spec: Optional[pytree.TreeSpec]

# get File:lineno code from stack_trace
def _parse_stack_trace(stack_trace: Optional[str]):
    """
    Extract the last ``File "<file>", line <n>, in <name>`` frame from a
    formatted stack trace.

    Args:
        stack_trace: stack trace text, or ``None``.

    Returns:
        A ``ParsedStackTrace(file, lineno, code)`` namedtuple for the last
        matching frame header that is followed by another line. ``lineno`` is
        the matched digit string and ``code`` is that following line, stripped.
        Returns ``None`` when ``stack_trace`` is ``None`` or no such frame
        header is found.
    """
    if stack_trace is None:
        return None
    ParsedStackTrace = collections.namedtuple("ParsedStackTrace", ["file", "lineno", "code"])
    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
    lines = stack_trace.strip().split('\n')
    # stacktrace should have innermost frame last, so we
    # iterate backwards to find the first line that starts
    # with 'File '. Starting at len(lines) - 2 guarantees that
    # lines[idx + 1] (the code line) exists.
    for idx in range(len(lines) - 2, -1, -1):
        matches = pattern.match(lines[idx].strip())
        if matches:
            file = matches.group(1)
            lineno = matches.group(2)
            # next line should be the code
            code = lines[idx + 1].strip()
            return ParsedStackTrace(file, lineno, code)
    return None


@compatibility(is_backward_compatible=False)
class CodeGen:
def __init__(self):
Expand Down Expand Up @@ -465,28 +489,20 @@ def append_stacktrace_summary(node : Node):
useful for debugging.
"""
nonlocal prev_stacktrace
pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")

if node.op not in {'placeholder', 'output'}:
if node.stack_trace:
if node.stack_trace != prev_stacktrace:
prev_stacktrace = node.stack_trace

lines = node.stack_trace.strip().split('\n')
# stacktrace should have innermost frame last, so we
# iterate backwards to find the first line that starts
# with 'File '
summary_str = ""
for idx in range(len(lines) - 2, -1, -1):
line = lines[idx].strip()
matches = pattern.match(line)
if matches:
file = matches.group(1)
lineno = matches.group(2)
# next line should be the code
code = lines[idx + 1].strip()
summary_str = f'File: {file}:{lineno}, code: {code}'
break

parsed_stack_trace = _parse_stack_trace(node.stack_trace)

if parsed_stack_trace is not None:
lineno = parsed_stack_trace.lineno
code = parsed_stack_trace.code
summary_str = f'File: {parsed_stack_trace.file}:{lineno}, code: {code}'

body.append(f'\n# {summary_str}\n')
elif prev_stacktrace != "":
prev_stacktrace = ""
Expand Down
Loading

0 comments on commit 772e104

Please sign in to comment.