Expose torch.export.dynamic_dim() API (pytorch#107635)
gmagogsfm authored and pytorchmergebot committed Aug 22, 2023
1 parent 515aa99 commit 137d96a
Showing 4 changed files with 80 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/source/export.rst
@@ -8,6 +8,7 @@ torch.export

.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: dynamic_dim

.. toctree::
:glob:
35 changes: 1 addition & 34 deletions torch/_export/__init__.py
@@ -59,40 +59,7 @@
)
from .wrappers import _wrap_submodules

# Note - [On Export Dynamic Dimension UX]
#
# After a lot of discussion, we have settled on a dynamic marking API
# for export that meets the following constraints:
# 1) Stateless
# 2) Safe for numerous .export calls within a single process
# 3) Simple to use
# 4) Can be extended to constraints easily
#
# While the underlying API is still torch._dynamo.mark_dynamic, we offer a higher
# level API that meets the constraints above.
#
# This API produces an object that is meant to be passed into torch._dynamo.export
# constraints field. See docs on torch._dynamo.export for more details.
#
# Note - The output type and structure here is NOT BC and NOT A CONTRACT, we reserve
# the right to change the output here at any time, and will do so as we extend the API.
#
# result = torch._dynamo.export(
# my_model,
# constraints=[
# # if you do only dynamic_dim, this is sugar for
# # -Inf <= dynamic_dim(blah, 0) <= Inf; we don’t otherwise
# # permit direct int->bool conversion
# dynamic_dim(blah, 0),
# # operator overloading because it makes it clear whether
#         # the range is inclusive or exclusive
# 0 <= dynamic_dim(blah, 1) <= 100,
# # NB: But we actually truncate ranges to be >= 2, because of
# # 0/1 specialization
# ]
# )(
# *sixtyfour_tensors,
# )
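
For reference, the same pattern can now be written against the public API this commit exposes (`torch.export.export` / `torch.export.dynamic_dim`). A minimal sketch, assuming a hypothetical toy model `my_model`:

import torch
from torch.export import dynamic_dim, export

def my_model(x):
    # hypothetical toy model; any traceable callable works
    return x * 2

x = torch.rand(8, 32)
constraints = [
    # unbounded dynamic first dimension (the bare dynamic_dim sugar described above)
    dynamic_dim(x, 0),
    # bounded second dimension via operator overloading (inclusive upper bound);
    # lower bounds below 2 are truncated because of 0/1 specialization
    dynamic_dim(x, 1) <= 100,
]
ep = export(my_model, (x,), constraints=constraints)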

def dynamic_dim(t: torch.Tensor, index: int):
if not isinstance(t, torch.Tensor):
raise UserError(
82 changes: 77 additions & 5 deletions torch/export/__init__.py
@@ -4,10 +4,82 @@


__all__ = [
"dynamic_dim",
"export",
]


def dynamic_dim(t: torch.Tensor, index: int):
"""
`dynamic_dim` constructs a `Constraint` object that describes the dynamism of
dimension `index` of tensor `t`. `Constraint` objects should be passed to the
`constraints` argument of `export()`.
Specifically, `dynamic_dim` can be used to express the following types of dynamism.
- Size of a dimension is dynamic and unbounded::
t0 = torch.rand(2, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can have a dynamic size rather than always being static size 2
constraints = [dynamic_dim(t0, 0)]
ep = export(fn, (t0, t1), constraints=constraints)
- Size of a dimension is dynamic with a lower bound::
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can have a dynamic size with a lower bound of 5 (inclusive)
# Second dimension of t1 can have a dynamic size with a lower bound of 2 (exclusive)
constraints = [
dynamic_dim(t0, 0) >= 5,
dynamic_dim(t1, 1) > 2,
]
ep = export(fn, (t0, t1), constraints=constraints)
- Size of a dimension is dynamic with an upper bound::
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# First dimension of t0 can have a dynamic size with an upper bound of 16 (inclusive)
# Second dimension of t1 can have a dynamic size with an upper bound of 8 (exclusive)
constraints = [
dynamic_dim(t0, 0) <= 16,
dynamic_dim(t1, 1) < 8,
]
ep = export(fn, (t0, t1), constraints=constraints)
- Size of a dimension is dynamic and always equal to the size of another dynamic dimension::
t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)
# Sizes of the second dimension of t0 and the first dimension of t1 are always equal
constraints = [
dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
]
ep = export(fn, (t0, t1), constraints=constraints)
- Mix and match all types above as long as they do not express conflicting requirements
Args:
t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
index (int): Index of dynamic dimension
Returns:
A `Constraint` object that describes shape dynamism. It can be passed to `export()` so
that `export()` does not assume a static size for the specified tensor, i.e. keeps it
dynamic as a symbolic size rather than specializing it to the size of the example
tracing input.
"""
from torch._export import dynamic_dim

return dynamic_dim(t, index)
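
The docstring snippets above all reference an undefined `fn`. A self-contained sketch that combines the unbounded and equality cases, using a hypothetical toy function whose shape requirement matches the equality constraint:

import torch
from torch.export import dynamic_dim, export

def fn(t0, t1):
    # torch.mm requires t0.shape[1] == t1.shape[0], which is exactly the
    # equality constraint declared below
    return torch.mm(t0, t1)

t0 = torch.rand(10, 3)
t1 = torch.rand(3, 4)

constraints = [
    dynamic_dim(t0, 0),                        # unbounded dynamic leading dim
    dynamic_dim(t0, 1) == dynamic_dim(t1, 0),  # shared inner dims stay equal
]
ep = export(fn, (t0, t1), constraints=constraints)

The resulting exported program should then accept inputs whose constrained dimensions differ from the tracing sizes of 10 and 3, while shapes that violate the constraints should be rejected at runtime.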


def export(
f: Callable,
args: Tuple[Any],
@@ -97,7 +169,7 @@ def fn(x):
Note:
If you want to preserve dynamic branching behavior based on value or
shape of torch.Tensor in the generated graph, you will need to use
`torch._export.dynamic_dim` to make a dimension of the input tensor dynamic
`torch.export.dynamic_dim` to make a dimension of the input tensor dynamic
and rewrite the source code using control flow operations like
`torch.ops.higher_order.cond`.
@@ -114,7 +186,7 @@ def fn(x):
Because static shape use cases are more dominant, `export()` chooses to
assume shapes are all static by default unless there are explicit user
instructions that say otherwise. Specifically, users can use
`torch._export.dynamic_dim` to give a hint to `export()` about dynamism
`torch.export.dynamic_dim` to give a hint to `export()` about dynamism
and range of an input tensor dimension.
2. Dynamic Control Flow
@@ -142,7 +214,7 @@ def fn(x):
- Assumptions on static shapes of input tensors are automatically validated without additional effort.
- Assumptions on dynamic shape of input tensors require explicit `Input Constraint`
constructed with `torch._export.dynamic_dim` APIs
constructed with `torch.export.dynamic_dim` APIs
- Assumptions on range of intermediate values require explicit `Inline Constraint`,
constructed using the `constrain_as_size` and `constrain_as_value` APIs.
@@ -194,9 +266,9 @@ def specify_constraints(x):
constraints: An optional list of constraints on the dynamic arguments
that specify their possible range of shapes. By default, shapes of
input torch.Tensors are assumed to be static. If an input torch.Tensor
is expected to have dynamic shapes, please use `torch._export.dynamic_dim()`
is expected to have dynamic shapes, please use `torch.export.dynamic_dim()`
to define `Constraint` objects that specify the dynamics and the possible
range of shapes. See torch._export.dynamic_dim() docstring for examples on
range of shapes. See torch.export.dynamic_dim() docstring for examples on
how to use it.
Returns:
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -147,6 +147,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
torch.export.dynamic_dim,
torch.export.export,
torch.eye,
torch.fft.fftfreq,
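
Since `get_ignored_functions()` simply returns a set of callables, the effect of this last hunk can be sanity-checked directly; a small illustrative sketch:

import torch
from torch.overrides import get_ignored_functions

# torch.export.dynamic_dim now sits in the ignore set, next to torch.export.export,
# so the __torch_function__ override machinery skips it.
assert torch.export.dynamic_dim in get_ignored_functions()
assert torch.export.export in get_ignored_functions()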
