Skip to content

Commit

Permalink
[export] Don't save example_inputs for now. (pytorch#107978)
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#107978
Approved by: https://github.com/angelayi
  • Loading branch information
zhxchen17 authored and pytorchmergebot committed Aug 26, 2023
1 parent d4a9963 commit 162109f
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 22 deletions.
7 changes: 0 additions & 7 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,6 @@ def check_graph(self, fn, inputs, constraints=None, _check_meta=True) -> None:
else:
self.assertEqual(orig, loaded)

self.assertEqual(len(ep.original_traced_arguments), len(deserialized_ep.original_traced_arguments))
for arg1, arg2 in zip(ep.original_traced_arguments, deserialized_ep.original_traced_arguments):
if isinstance(arg1, torch.Tensor) and isinstance(arg2, torch.Tensor):
self.assertTrue(torch.allclose(arg1, arg2))
else:
self.assertEqual(type(arg1), type(arg2))

def _check_graph_nodes(gm1, gm2, _check_meta=True):
# TODO: The _check_meta flag bypasses checking for
# source_fn/nn_module_stack as there is an issue with
Expand Down
2 changes: 1 addition & 1 deletion torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def to_str_dict(sig_component: Dict[Any, Any]):
range_constraints,
equality_constraints,
[ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
args,
(args, {}),
)

exported_program = exported_program._transform(
Expand Down
2 changes: 1 addition & 1 deletion torch/_export/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,4 @@ class ExportedProgram:
range_constraints: Dict[str, RangeConstraint]
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
schema_version: int
original_traced_arguments: str
example_inputs: Optional[Tuple[List[bytes], Dict[str, bytes]]]
10 changes: 2 additions & 8 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,9 +809,6 @@ def serialize(self, exported_program: ep.ExportedProgram) -> Tuple[ExportedProgr
)
serialized_range_constraints = serialize_range_constraints(exported_program.range_constraints)
serialized_equality_constraints = serialize_equality_constraints(exported_program.equality_constraints)
serialized_original_arguments = base64.b64encode(
serialize_torch_artifact(exported_program.original_traced_arguments)
).decode('utf-8')

return (
ExportedProgram(
Expand All @@ -820,7 +817,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> Tuple[ExportedProgr
range_constraints=serialized_range_constraints,
equality_constraints=serialized_equality_constraints,
schema_version=SCHEMA_VERSION,
original_traced_arguments=serialized_original_arguments,
example_inputs=None,
),
serialize_torch_artifact(exported_program.state_dict),
)
Expand Down Expand Up @@ -1367,9 +1364,6 @@ def deserialize(

state_dict = deserialize_torch_artifact(serialized_state_dict)
equality_constraints = deserialize_equality_constraints(serialized_exported_program.equality_constraints)
original_traced_arguments = deserialize_torch_artifact(
base64.b64decode(serialized_exported_program.original_traced_arguments)
)

exported_program = ep.ExportedProgram(
graph_module,
Expand All @@ -1380,7 +1374,7 @@ def deserialize(
range_constraints,
equality_constraints,
module_call_graph,
original_traced_arguments, # type: ignore[arg-type]
None, # type: ignore[arg-type]
)
return upgrader.upgrade(exported_program)

Expand Down
10 changes: 5 additions & 5 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(
range_constraints: Dict[sympy.Symbol, Any],
equality_constraints: List[Tuple[Any, Any]],
module_call_graph: List[ModuleCallEntry],
original_traced_arguments: Tuple[Any, ...] = (),
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
):
from torch._export.exported_program import (
_create_graph_module_for_export,
Expand All @@ -264,7 +264,7 @@ def __init__(
Tuple[InputDim, InputDim]
] = equality_constraints
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._original_traced_arguments = original_traced_arguments
self._example_inputs = example_inputs

@property
@compatibility(is_backward_compatible=False)
Expand Down Expand Up @@ -308,8 +308,8 @@ def module_call_graph(self):

@property
@compatibility(is_backward_compatible=False)
def original_traced_arguments(self):
return self._original_traced_arguments
def example_inputs(self):
return self._example_inputs

def __call__(self, *args: Any, **kwargs: Any) -> Any:
import torch._export.error as error
Expand Down Expand Up @@ -531,7 +531,7 @@ def _generate_new_graph_signature(old_ep, new_gm):
_get_updated_range_constraints(transformed_gm),
copy.deepcopy(self.equality_constraints),
copy.deepcopy(self._module_call_graph),
self.original_traced_arguments,
self.example_inputs,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
Expand Down

0 comments on commit 162109f

Please sign in to comment.