Skip to content

Commit

Permalink
Add torch.export.register_dataclass API (pytorch#109152)
Browse files Browse the repository at this point in the history
`register_dataclass` allows dataclass to be used as valid input/output types of torch.export.export

Pull Request resolved: pytorch#109152
Approved by: https://github.com/ydwu4
  • Loading branch information
gmagogsfm authored and pytorchmergebot committed Sep 13, 2023
1 parent 375d2ca commit a09539f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ API Reference
.. autofunction:: constrain_as_value
.. autofunction:: save
.. autofunction:: load
.. autofunction:: register_dataclass
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

Expand Down
43 changes: 43 additions & 0 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager

from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
ToDumpableContextFn,
UnflattenFunc,
)


__all__ = [
"ArgumentKind",
Expand All @@ -33,6 +40,7 @@
"dynamic_dim",
"export",
"load",
"register_dataclass",
"save",
]

Expand Down Expand Up @@ -1006,6 +1014,7 @@ def specify_constraints(x):
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
- (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
``OrderedDict`` containing all above types.
Expand Down Expand Up @@ -1129,3 +1138,37 @@ def load(
return load(
f, extra_files=extra_files, expected_opset_version=expected_opset_version
)


def register_dataclass(typ: Any) -> None:
"""
Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
Args:
typ: the dataclass type to register
Example::
@dataclass
class InputDataClass:
feature: torch.Tensor
bias: int
class OutputDataClass:
res: torch.Tensor
torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)
def fn(o: InputDataClass) -> torch.Tensor:
res = res=o.feature + o.bias
return OutputDataClass(res=res)
ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
"""

from torch._export.utils import register_dataclass_as_pytree_node

return register_dataclass_as_pytree_node(typ)
1 change: 1 addition & 0 deletions torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.export.dynamic_dim,
torch.export.export,
torch.export.load,
torch.export.register_dataclass,
torch.export.save,
torch.eye,
torch.fft.fftfreq,
Expand Down

0 comments on commit a09539f

Please sign in to comment.