forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[dynamo] preserve some FX node metadata of GraphModules (pytorch#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work. This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times. The added unit test demonstrates the added functionality of this PR. ~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following: - ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~ - Construct a line number -> node index map in `GraphModule` as the source code is being generated. - pass a list of node metadata and the line number map to dynamo's bytecode analyzer - ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~ - When we create a new proxy, get the current instruction's line number, and get the node index using the line number map - index into the original node metadata ~using the counter variable's tracked value.~ ~Some things that should be addressed off the top of my head:~ - ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~ - ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~ - ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~ Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108) Pull Request resolved: pytorch#107067 Approved by: https://github.com/jansel
- Loading branch information
1 parent
7af792a
commit b904432
Showing
9 changed files
with
314 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import types | ||
|
||
from .utils import ExactWeakKeyDictionary | ||
|
||
|
||
class CodeContextDict: | ||
def __init__(self): | ||
self.code_context = ExactWeakKeyDictionary() | ||
|
||
def has_context(self, code: types.CodeType): | ||
return code in self.code_context | ||
|
||
def get_context(self, code: types.CodeType): | ||
ctx = self.code_context.get(code) | ||
if ctx is None: | ||
ctx = {} | ||
self.code_context[code] = ctx | ||
return ctx | ||
|
||
def pop_context(self, code: types.CodeType): | ||
ctx = self.get_context(code) | ||
self.code_context._remove_id(id(code)) | ||
return ctx | ||
|
||
def clear(self): | ||
self.code_context.clear() | ||
|
||
|
||
code_context = CodeContextDict() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.