[dynamo] preserve some FX node metadata of GraphModules (pytorch#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and then want to transform the graph back into an aten-level graph. Without preserved metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`, which is helpful when debugging graphs that have passed through dynamo several times.
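
As a minimal usage sketch of that benefit (adapted from the unit tests added in this PR; the input shape is arbitrary):

```python
import torch

def fn(x):
    return torch.sin(x)

# export once, then re-export the resulting GraphModule
gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
gm2, _ = torch._dynamo.export(gm)(torch.randn(3, 3))

# with this change, the re-exported module still prints the original source line
assert "return torch.sin(x)" in gm2.print_readable(print_output=False)
```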

The added unit tests demonstrate the new functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following (a sketch of the lookup appears after this list):
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- Pass a list of node metadata and the line number map to dynamo's bytecode analyzer.
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds to by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number and look up the node index in the line number map.
- Index into the original node metadata with that node index ~using the counter variable's tracked value~.
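
A rough sketch of the lookup described above (condensed from the `output_graph.py` change in this commit; the function name and signature are illustrative, not actual dynamo internals):

```python
def lookup_orig_node_meta(orig_gm_meta, orig_gm_lineno_map, orig_gm_firstlineno, cur_lineno):
    # translate the current instruction's absolute line number into an offset
    # within the GraphModule's generated source, then into a node index
    node_idx = orig_gm_lineno_map.get(cur_lineno - orig_gm_firstlineno, None)
    if node_idx is None:
        return None
    # the caller copies stack_trace / nn_module_stack / source_fn from this dict
    return orig_gm_meta[node_idx]
```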

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 has shortcomings. For node names, we only have access to the new node names, not the old ones. Using line numbers is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: pytorch#107067
Approved by: https://github.com/jansel
williamwen42 authored and pytorchmergebot committed Sep 15, 2023
1 parent 7af792a commit b904432
Showing 9 changed files with 314 additions and 28 deletions.
125 changes: 125 additions & 0 deletions test/dynamo/test_export.py
@@ -3,6 +3,7 @@
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_export_persist_assert)
"""
import copy
import functools
import inspect
import math
@@ -3954,6 +3955,130 @@ def forward(self, x, y):
if node.op == "call_function":
self.assertIn("nn_module_stack", node.meta)

def test_preserve_fx_node_metadata(self):
class Module1(torch.nn.Module):
def forward(self, x):
return torch.sin(x)

class Module2(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod1 = Module1()

def forward(self, x):
x = torch.cos(x)
x = self.mod1(x)
x = torch.relu(x)
return x

def fn(x):
return torch.abs(x)

mod = Module2()
inp = torch.randn(3, 3)

gm, _ = torch._dynamo.export(mod)(inp)

# replace relu with fn
gm_edit = copy.deepcopy(gm)
for nd in gm_edit.graph.nodes:
if nd.target == torch.relu:
nd.target = fn
nd.meta.clear()
break
gm_edit.recompile()

gm2, _ = torch._dynamo.export(gm_edit)(inp)

# check for source code
gm_code = gm.print_readable(print_output=False)
gm_edit_code = gm_edit.print_readable(print_output=False)
gm2_code = gm2.print_readable(print_output=False)
for code in (gm_code, gm_edit_code, gm2_code):
self.assertIn("x = torch.cos(x)", code)
self.assertIn("return torch.sin(x)", code)
self.assertIn("x = torch.relu(x)", gm_code)
self.assertNotIn("x = torch.relu(x)", gm_edit_code)
self.assertNotIn("x = torch.relu(x)", gm2_code)
self.assertIn("return torch.abs(x)", gm2_code)

# check for other metadata
for op in (torch.sin, torch.cos):
nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes))
nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes))
self.assertTrue(
("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta)
)
if "nn_module_stack" in nd1.meta:
self.assertEqual(
nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
)
self.assertEqual(nd1.meta["source_fn"], nd2.meta["source_fn"])
self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])

def test_preserve_fx_node_metadata_recompile(self):
def fn(x):
return torch.sin(x)

gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
do_export = torch._dynamo.export(gm)
torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3))
gm1, _ = do_export(torch.randn(3, 3))
gm2, _ = do_export(torch.randn(5, 3))

self.assertIn("return torch.sin(x)", gm1.print_readable(print_output=False))
self.assertIn("return torch.sin(x)", gm2.print_readable(print_output=False))

def test_preserve_fx_node_metadata_inline(self):
def f1(x):
return torch.sin(x)

gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3))

def f2(x):
x = torch.cos(x)
return gm(x)

gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))

self.assertIn("return torch.sin(x)", gm2.print_readable(print_output=False))

def test_preserve_fx_node_metadata_graph_break(self):
def fn(x):
x = torch.sin(x)
x = torch.abs(x)
return torch.cos(x)

def bad_fn(x):
torch._dynamo.graph_break()
return x

gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))

# replace abs with graph break
gm_edit = copy.deepcopy(gm)
for nd in gm_edit.graph.nodes:
if nd.target == torch.abs:
nd.target = bad_fn
nd.meta.clear()
break
gm_edit.recompile()

expected = [
"x = torch.sin(x)",
"return torch.cos(x)",
]

def test_backend(gm: torch.fx.GraphModule, example_inputs):
self.assertTrue(expected)
self.assertIn(expected[0], gm.print_readable(print_output=False))
expected.pop(0)
return gm.forward

torch._dynamo.reset()
opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
opt_gm_edit(torch.randn(3, 3))


common_utils.instantiate_parametrized_tests(ExportTests)

21 changes: 21 additions & 0 deletions test/test_fx.py
@@ -620,6 +620,27 @@ def forward(self, a, b):
self.assertTrue(node.stack_trace is not None)
assert 'test_fx.py' in node.stack_trace

def test_lineno_map(self):
class M(torch.nn.Module):
def forward(self, a, b):
a = torch.sin(a)
b = torch.cos(b)
return a + b

tracer = torch.fx.Tracer()
graph = tracer.trace(M())
gm = GraphModule(tracer.root, graph)
expected = {1: 2, 2: 3, 3: 4, 4: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

# test custom codegen
def transform_code(code):
return ["print('hello!')\n", *code]
gm.graph.on_generate_code(lambda _: transform_code)
gm.recompile()
expected = {2: 2, 3: 3, 4: 4, 5: 5}
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

def test_graph_unique_names_manual(self):
graph : torch.fx.Graph = torch.fx.Graph()
a : torch.fx.Node = graph.create_node('placeholder', 'x')
2 changes: 2 additions & 0 deletions torch/_dynamo/__init__.py
@@ -1,6 +1,7 @@
import torch
from . import allowed_functions, convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, register_backend
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
allow_in_graph,
@@ -68,3 +69,4 @@ def reset() -> None:
_reset_guarded_backend_cache()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()
code_context.clear()
29 changes: 29 additions & 0 deletions torch/_dynamo/code_context.py
@@ -0,0 +1,29 @@
import types

from .utils import ExactWeakKeyDictionary


class CodeContextDict:
def __init__(self):
self.code_context = ExactWeakKeyDictionary()

def has_context(self, code: types.CodeType):
return code in self.code_context

def get_context(self, code: types.CodeType):
ctx = self.code_context.get(code)
if ctx is None:
ctx = {}
self.code_context[code] = ctx
return ctx

def pop_context(self, code: types.CodeType):
ctx = self.get_context(code)
self.code_context._remove_id(id(code))
return ctx

def clear(self):
self.code_context.clear()


code_context = CodeContextDict()
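
For context, a small sketch of how `code_context` is used by the `eval_frame.py` and `output_graph.py` changes in this commit (the traced module here is illustrative):

```python
import torch
from torch._dynamo.code_context import code_context

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

gm = torch.fx.symbolic_trace(M())

# eval_frame.py stashes the original GraphModule keyed by its forward's code object
code_context.get_context(gm.forward.__code__)["orig_graphmodule"] = gm

# output_graph.py later recovers it while tracing that same code object
assert code_context.get_context(gm.forward.__code__).get("orig_graphmodule") is gm
```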
9 changes: 9 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -60,6 +60,7 @@
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)

from . import config, convert_frame, external_utils, skipfiles, utils
from .code_context import code_context
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
from .mutation_guard import install_generation_tagging_init
from .types import DynamoCallback
@@ -337,6 +338,14 @@ def get_compiler_config():
return self.compiler_config

fn = innermost_fn(fn)

# add context containing GraphModule to any GraphModule forward functions
if isinstance(fn, torch.fx.GraphModule):
# Assume that the underlying node metadata of `fn`,
# a GraphModule instance, accurately represents
# all instances of type(fn).
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = fn

# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):
mod = fn
91 changes: 69 additions & 22 deletions torch/_dynamo/output_graph.py
@@ -50,6 +50,7 @@
Instruction,
unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import enter_new_scope
from .exc import (
@@ -1215,6 +1216,11 @@ def __init__(self, output_graph, parent=None, export_root=False):
self.lifted_freevars = collections.OrderedDict()
self.prev_inst = None

self._cur_code = None
self._orig_gm_meta = None
self._orig_gm_lineno_map = None
self._orig_gm_firstlineno = None

def create_proxy(
self,
kind,
@@ -1292,31 +1298,72 @@ def get_trace_call_log_str():
trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
self.prev_inst = cur_inst

nn_module_stack = tx.nn_module_stack
if nn_module_stack:
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

if kind in {"call_function", "call_method"}:
rv.node.meta["source_fn"] = (rv.node.name, target)
elif kind == "call_module":
if self.parent is not None:
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
# For modules we store the class
rv.node.meta["source_fn"] = (
rv.node.name,
rv.node.meta["nn_module_stack"][target][1],
# update reference to original meta if we're tracing a new code object
if tx.f_code is not self._cur_code:
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
"orig_graphmodule", None
)
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
self._orig_gm_meta = [
nd.meta for nd in orig_graphmodule_maybe.graph.nodes
]
self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
self._orig_gm_firstlineno = (
orig_graphmodule_maybe.forward.__code__.co_firstlineno
)
else:
self._orig_gm_meta = None
self._orig_gm_lineno_map = None
self._orig_gm_firstlineno = None

frame_summaries: List[traceback.FrameSummary] = []
while tx:
frame_summaries.append(tx.frame_summary())
tx = getattr(tx, "parent", None)
# Reverse the frame_summaries, such that the innermost frame is at the last
frame_summaries.reverse()
# preserve original meta if it is available
if (
self._orig_gm_meta
and self._orig_gm_lineno_map
and self._orig_gm_firstlineno
):
lineno = tx.current_instruction.starts_line
node_idx = None
if lineno is not None:
node_idx = self._orig_gm_lineno_map.get(
lineno - self._orig_gm_firstlineno, None
)
if node_idx is not None:
meta = self._orig_gm_meta[node_idx]
if "stack_trace" in meta:
rv.node.meta["stack_trace"] = meta["stack_trace"]
if "nn_module_stack" in meta and "source_fn" in meta:
rv.node.meta["nn_module_stack"] = meta["nn_module_stack"]
rv.node.meta["source_fn"] = meta["source_fn"]

if "nn_module_stack" not in rv.node.meta:
nn_module_stack = tx.nn_module_stack
if nn_module_stack:
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

if "source_fn" not in rv.node.meta:
if kind in {"call_function", "call_method"}:
rv.node.meta["source_fn"] = (rv.node.name, target)
elif kind == "call_module":
if self.parent is not None:
unimplemented("Invoking an nn.Module inside HigherOrderOperator")
# For modules we store the class
rv.node.meta["source_fn"] = (
rv.node.name,
rv.node.meta["nn_module_stack"][target][1],
)

# official from_list stub doesn't have new-style type
msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type]
rv.node.stack_trace = "".join(msgs)
if "stack_trace" not in rv.node.meta:
frame_summaries: List[traceback.FrameSummary] = []
while tx:
frame_summaries.append(tx.frame_summary())
tx = getattr(tx, "parent", None)
# Reverse the frame_summaries, such that the innermost frame is at the last
frame_summaries.reverse()

# official from_list stub doesn't have new-style type
msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type]
rv.node.stack_trace = "".join(msgs)

return rv

(diffs for the remaining 3 changed files not shown)
