Skip to content

Commit

Permalink
[export] Change _generate_new_graph_signature (pytorch#108571)
Browse files Browse the repository at this point in the history
Summary:
Previously `_generate_new_graph_signature` had the assumption that all transformations were not in place. However, this is an incorrect assumption leading to mysterious failures when running passes doing in-place modifications.

This function is technically only needed in the case where the user output node or user input node name is changed. For example, if the user output node was "add" but a pass changes all the "add"s to "mul"s, then the output node will now be named "mul", which we have to update.

For cases where users change the number of user inputs/outputs, number of parameters/buffers, or the names of parameters/buffers it will require extra work on the user's side to update the graph signature, since there is no automatic way for us to detect where to put what.

Note: this doesn't actually change the names for the buffers_to_mutate part of the graph signature, but we're going to assume this is rare, because implementing auto-fixing for that is a little hard...

Test Plan: Running `buck test fbcode//mode/dev-nosan fbcode//executorch/backends/xnnpack/test:` on top of D48710125, https://www.internalfb.com/intern/testinfra/testrun/5066549776877081

Differential Revision: D48917505

Pull Request resolved: pytorch#108571
Approved by: https://github.com/zhxchen17
  • Loading branch information
angelayi authored and pytorchmergebot committed Sep 6, 2023
1 parent 089950b commit d856f3b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 121 deletions.
54 changes: 20 additions & 34 deletions test/export/test_pass_infra.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Owner(s): ["module: dynamo"]
from typing import List
import unittest
from typing import List

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from functorch.experimental import control_flow
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export
from torch._export.pass_base import _ExportPassBase
from torch._export.constraints import constrain_as_value
from functorch.experimental import control_flow
from torch._export.pass_base import _ExportPassBase
from torch.testing._internal.common_utils import run_tests, TestCase


@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
Expand Down Expand Up @@ -105,16 +105,14 @@ def __init__(self):

self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

self.register_buffer('my_buffer1', torch.tensor(3.0))
self.register_buffer('my_buffer2', torch.tensor(4.0))
self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))

def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0)

output = (
x1 + self.my_parameter
) * self.my_buffer1 + x2 * self.my_buffer2
return output

my_module = CustomModule()
Expand All @@ -123,36 +121,24 @@ def forward(self, x1, x2):
input_tensor1 = torch.tensor(5.0)
input_tensor2 = torch.tensor(6.0)

ep_before = export(my_module, (input_tensor1, input_tensor2))

# Dummy pass to modify input names and add new nodes to intentionally
# change output node names
class ModifyInputOutputPass(_ExportPassBase):

def placeholder(self, name, arg, meta):
new_name = name + "_modified"
return super().placeholder(new_name, arg, meta)

def call_operator(self, op, args, kwargs, meta):
ret = super().call_operator(op, args, kwargs, meta)
new_args = (ret,) + args[1:]
new_ret = super().call_operator(op, new_args, kwargs, meta)
return new_ret
ep_before = torch._export.export(my_module, (input_tensor1, input_tensor2))
from torch.fx.passes.infra.pass_base import PassResult

def modify_input_output_pass(gm):
for node in gm.graph.nodes:
if node.op == "call_function":
node.name = node.name + "_modified"
gm.recompile()
return PassResult(gm, True)

ep_after = ep_before._transform(ModifyInputOutputPass())
ep_after = ep_before._transform(modify_input_output_pass)
new_signature = ep_after.graph_signature

for inp in (
new_signature.user_inputs +
list(new_signature.inputs_to_parameters.keys()) +
list(new_signature.inputs_to_buffers.keys())
):
self.assertTrue("_modified" in inp)
for node_name in new_signature.user_outputs:
self.assertTrue("_modified" in node_name)

old_signature = ep_before.graph_signature
self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)
self.assertNotEqual(new_signature.buffers_to_mutate.keys(), old_signature.buffers_to_mutate.keys())


if __name__ == '__main__':
Expand Down
137 changes: 50 additions & 87 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,103 +463,66 @@ def get_shape_env(gm):
}
return range_constraints

def get_output_node_names(gm):
output_node = list(gm.graph.nodes)[-1]
assert output_node.op == "output"

return [str(arg) for arg in output_node.args[0]]

def get_input_node_names(gm):
return [node.name for node in gm.graph.nodes if node.op == "placeholder"]

def _generate_new_graph_signature(old_ep, new_gm):
def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
) -> ExportGraphSignature:
"""
Update graph_signature according to graph after transformation.
Transformations can lead to node name changes, which are used in
graph_signature to identify inputs and outputs. Therefore, after each
transformation, we need to update the graph_signature according to
new node names.
WARNING: This implementation makes a few assumptions
- The transformation doesn't change number of inputs/outputs
- Each input/output still has the same meaning.
- For inputs, that means that the inputs in transformed
graph map to the same lifted parameter/buffer or user
input as the input of the same position in the graph
before transformation.
- Similarly for outputs, each output should correspond to the
same mutated buffer or user output as the output value of
the same position in the graph before transformation.
It is difficult to programatically validate these assumptions, but they
should hold true most of the time as inputs/outputs of the graph rarely
need to be changed.
Update the graph signature's user_input/user_outputs.
"""
old_signature = old_ep.graph_signature
old_gm = old_ep.graph_module

old_graph_input_node_names = get_input_node_names(old_gm)
new_graph_input_node_names = get_input_node_names(new_gm)
assert len(old_graph_input_node_names) == len(
new_graph_input_node_names
), f"""
Number of input nodes changed from {len(old_graph_input_node_names)}
to {len(new_graph_input_node_names)} after transformation. This
transformation is currently not supported.
"""

old_graph_output_node_names = get_output_node_names(old_gm)
new_graph_output_node_names = get_output_node_names(new_gm)
assert len(old_graph_output_node_names) == len(
new_graph_output_node_names
), f"""
Number of output values changed from {len(old_graph_output_node_names)}
to {len(new_graph_output_node_names)} after transformation. This
transformation is currently not supported.
"""

node_names_mapping = dict(
zip(
old_graph_input_node_names + old_graph_output_node_names,
new_graph_input_node_names + new_graph_output_node_names,
)
new_graph_inputs = [
node.name for node in new_gm.graph.nodes if node.op == "placeholder"
]
num_inputs = (
len(old_signature.parameters)
+ len(old_signature.buffers)
+ len(old_signature.user_inputs)
)

new_signature = copy.deepcopy(old_signature)
new_signature.user_inputs = [
node_names_mapping[old_user_input]
for old_user_input in old_signature.user_inputs
]
new_signature.user_outputs = [
node_names_mapping[old_user_output]
for old_user_output in old_signature.user_outputs
assert len(new_graph_inputs) == num_inputs, (
f"Number of input nodes changed from {len(new_graph_inputs)} "
f"to {num_inputs} after transformation. This transformation "
"is currently not supported."
)
new_parameter_inputs = new_graph_inputs[: len(old_signature.parameters)]
num_param_buffers = len(old_signature.buffers) + len(
old_signature.parameters
)
new_buffer_inputs = new_graph_inputs[
len(old_signature.parameters) : num_param_buffers
]
new_signature.inputs_to_parameters = {
node_names_mapping[old_input_name]: old_signature.inputs_to_parameters[
old_input_name
]
for old_input_name in old_signature.inputs_to_parameters.keys()
}
new_signature.inputs_to_buffers = {
node_names_mapping[old_input_name]: old_signature.inputs_to_buffers[
old_input_name
]
for old_input_name in old_signature.inputs_to_buffers.keys()
}
new_signature.buffers_to_mutate = {
node_names_mapping[old_output_name]: old_signature.buffers_to_mutate[
old_output_name
]
for old_output_name in old_signature.buffers_to_mutate.keys()
}
return new_signature
new_user_inputs = new_graph_inputs[num_param_buffers:]

new_graph_signature = _generate_new_graph_signature(self, transformed_gm)
output_node = list(new_gm.graph.nodes)[-1]
assert output_node.op == "output"
new_graph_outputs = [arg.name for arg in output_node.args[0]]

assert len(new_graph_outputs) == len(old_signature.buffers_to_mutate) + len(
old_signature.user_outputs
), (
f"Number of output nodes changed from {len(new_graph_outputs)} "
f"to {len(old_signature.buffers_to_mutate) + len(old_signature.user_outputs)} "
"after transformation. This transformation is currently not supported."
)
new_user_outputs = new_graph_outputs[len(old_signature.buffers_to_mutate) :]

new_signature = ExportGraphSignature(
copy.deepcopy(old_signature.parameters),
copy.deepcopy(old_signature.buffers),
new_user_inputs,
new_user_outputs,
copy.deepcopy(old_signature.inputs_to_parameters),
copy.deepcopy(old_signature.inputs_to_buffers),
copy.deepcopy(old_signature.buffers_to_mutate),
copy.deepcopy(old_signature.backward_signature),
copy.deepcopy(old_signature.assertion_dep_token),
)
return new_signature

transformed_ep = ExportedProgram(
transformed_gm,
transformed_gm.graph,
new_graph_signature,
_get_updated_graph_signature(self.graph_signature, transformed_gm),
copy.deepcopy(self.call_spec),
self.state_dict,
_get_updated_range_constraints(transformed_gm),
Expand Down

0 comments on commit d856f3b

Please sign in to comment.