
Commit

Move Constraint class to torch.export() to avoid circular dependenc…
gmagogsfm authored and pytorchmergebot committed Aug 24, 2023
1 parent 7bab98f commit f8119f8
Showing 5 changed files with 122 additions and 113 deletions.
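The substance of the change: `Constraint` and its base `ConstraintTarget` move out of `torch/_dynamo/eval_frame.py` into `torch/export/__init__.py` (the base renamed to `_ConstraintTarget`), so that `torch._export` no longer has to reach into `torch._dynamo` for the class. A minimal sketch of what this means for importing code, with the old path shown only for contrast:

# Before this commit, the class lived in dynamo's eval_frame module,
# so importing it from torch._export pulled in all of dynamo:
#   from torch._dynamo.eval_frame import Constraint
# After this commit, the canonical import is the public package:
from torch.export import Constraint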
1 change: 1 addition & 0 deletions docs/source/export.rst
@@ -11,6 +11,7 @@ torch.export
.. autofunction:: dynamic_dim
.. autofunction:: constrain_as_size
.. autofunction:: constrain_as_value
.. autoclass:: Constraint

.. toctree::
:glob:
111 changes: 2 additions & 109 deletions torch/_dynamo/eval_frame.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import dataclasses
import dis
import functools
import inspect
@@ -13,7 +12,6 @@
import traceback
import types
import warnings
import weakref
from collections import namedtuple
from enum import Enum
from os.path import dirname, join
@@ -37,6 +35,7 @@
import torch.utils.checkpoint
from torch import _guards
from torch._subclasses import fake_tensor
from torch.export import Constraint
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -77,11 +76,7 @@

import sympy

from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
StrictMinMaxConstraint,
)
from torch.utils._sympy.value_ranges import ValueRanges
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError


# See https://github.com/python/typing/pull/240
@@ -704,108 +699,6 @@ def guard_export_print(guards):
return inner


@dataclasses.dataclass
class ConstraintTarget:
"""
This represents input tensor dimensions. Don't create this
class directly; instead, use :func:`torch._export.dynamic_dim`.
"""

w_tensor: weakref.ReferenceType[torch.Tensor]
# TODO: We don't need t_id; we can get it off of w_tensor
t_id: int
dim: int


@dataclasses.dataclass
class Constraint(ConstraintTarget):
"""
This represents constraints on input tensor dimensions, e.g., requiring
them to be fully polymorphic or within some range. Don't create this
class directly; instead, use :func:`torch._export.dynamic_dim`.
"""

# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: StrictMinMaxConstraint
# Represent that `constraint_range` is shared with another ConstraintTarget, which
# typically arises because of a specified equality with another dynamic dimension.
shared: Optional[ConstraintTarget] = None

def _clone_with_range(self, lower=2, upper=sympy.oo):
constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
warn_only=False,
)
return Constraint(
self.w_tensor, self.t_id, self.dim, constraint_range, self.shared
)

def __ge__(self, lower):
return self._clone_with_range(lower=lower)

def __gt__(self, lower):
return self._clone_with_range(lower=lower + 1)

def __le__(self, upper):
return self._clone_with_range(upper=upper)

def __lt__(self, upper):
return self._clone_with_range(upper=upper - 1)

def __bool__(self):
# NOTE(avik): We do not support compound expressions like a <= x <= b.
# This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
# and moreover, enforces that any overload of __bool__ must return True or False.
# FWIW, sympy also raises TypeError in this case.
raise TypeError(
"Cannot determine truth value of Constraint. "
"If you are trying to combine Constraints with logical connectives, "
"you can specify them separately instead."
)

@property
def serializable_spec(self):
# We need a serialization compatible format of the constraint so that it
# can be saved in the graph module w/o breaking the module serialization.
# The saved constraints will be used directly for the post-exporting pass
# that converts constraints to runtime assertion. The saved constraints
# will not be saved in the serialized module.
# TODO: A better way is needed. Currently we use 't_id' to map the constraint,
# which is not reliable
return {
"t_id": self.t_id,
"dim": self.dim,
"min": self.constraint_range.vr.lower,
"max": self.constraint_range.vr.upper,
"shared": (
None
if self.shared is None
else {
"t_id": self.shared.t_id,
"dim": self.shared.dim,
}
),
}

def __eq__(self, other):
if not isinstance(other, Constraint):
raise TypeError(
"A dynamic dim can be specified equal only to another dynamic dim. "
f"Equality with {type(other)} is not supported."
)
constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & other.constraint_range.vr,
warn_only=False,
)
return Constraint(
self.w_tensor,
self.t_id,
self.dim,
constraint_range,
shared=ConstraintTarget(other.w_tensor, other.t_id, other.dim),
)


class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
def __init__(
self,
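The class body removed above reappears essentially verbatim in `torch/export/__init__.py` below, with `ConstraintTarget` renamed to `_ConstraintTarget`. As a minimal usage sketch of the operator overloads, assuming the `dynamic_dim` factory referenced in the docstrings (it returns a fully polymorphic `Constraint` for one tensor dimension):

import torch
from torch.export import dynamic_dim

x = torch.randn(4, 8)
y = torch.randn(4, 3)

# Each comparison clones the constraint with a tightened range
# (see _clone_with_range):
constraints = [
    dynamic_dim(x, 0) >= 2,                  # lower bound via __ge__
    dynamic_dim(x, 0) <= 64,                 # upper bound via __le__
    dynamic_dim(x, 0) == dynamic_dim(y, 0),  # __eq__ records a shared target
]

# Chained comparisons fail by design: Python desugars `2 <= d <= 64` into
# `bool(2 <= d) and (d <= 64)`, and __bool__ raises TypeError, so the two
# bounds must be specified separately, as above.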
2 changes: 1 addition & 1 deletion torch/_export/__init__.py
@@ -20,7 +20,7 @@
import torch.utils._pytree as pytree
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.eval_frame import Constraint
from torch.export import Constraint
from torch._dynamo.exc import UserError, UserErrorType
from torch._dynamo.source import ConstantSource
from torch._export.exported_program import ModuleCallEntry, ModuleCallSignature
2 changes: 1 addition & 1 deletion torch/_export/db/case.py
@@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
from torch._dynamo.eval_frame import Constraint
from torch.export import Constraint

_TAGS: Dict[str, Dict[str, Any]] = {
"torch": {
119 changes: 117 additions & 2 deletions torch/export/__init__.py
@@ -1,16 +1,132 @@
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Tuple

import sympy

import torch

from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
"Constraint",
"constrain_as_size",
"constrain_as_value",
"dynamic_dim",
"export",
]


@dataclasses.dataclass
class _ConstraintTarget:
"""
This represents input tensor dimensions. Don't create this
class directly; instead, use :func:`torch.export.dynamic_dim`.
"""

w_tensor: Any # weakref to torch.Tensor
# TODO: We don't need t_id; we can get it off of w_tensor
t_id: int
dim: int


# TODO(ycao): Disable constructor of Constraint so that it can only be constructed
# with dynamic_dim
@dataclasses.dataclass
class Constraint(_ConstraintTarget):
"""
.. warning::
Do not construct `Constraint` directly; use :func:`torch.export.dynamic_dim` instead.
This represents constraints on input tensor dimensions, e.g., requiring
them to be fully polymorphic or within some range.
"""

# NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
constraint_range: StrictMinMaxConstraint
# Represent that `constraint_range` is shared with another _ConstraintTarget, which
# typically arises because of a specified equality with another dynamic dimension.
shared: Optional[_ConstraintTarget] = None

def _clone_with_range(self, lower=2, upper=sympy.oo):
from torch.utils._sympy.value_ranges import ValueRanges

constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
warn_only=False,
)
return Constraint(
self.w_tensor, self.t_id, self.dim, constraint_range, self.shared
)

def __ge__(self, lower):
return self._clone_with_range(lower=lower)

def __gt__(self, lower):
return self._clone_with_range(lower=lower + 1)

def __le__(self, upper):
return self._clone_with_range(upper=upper)

def __lt__(self, upper):
return self._clone_with_range(upper=upper - 1)

def __bool__(self):
# NOTE(avik): We do not support compound expressions like a <= x <= b.
# This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
# and moreover, enforces that any overload of __bool__ must return True or False.
# FWIW, sympy also raises TypeError in this case.
raise TypeError(
"Cannot determine truth value of Constraint. "
"If you are trying to combine Constraint's with logical connectives, "
"you can specify them separately instead."
)

@property
def serializable_spec(self):
# We need a serialization compatible format of the constraint so that it
# can be saved in the graph module w/o breaking the module serialization.
# The saved constraints will be used directly for the post-exporting pass
# that converts constraints to runtime assertion. The saved constraints
# will not be saved in the serialized module.
# TODO: A better way is needed. Currently we use 't_id' to map the constraint,
# which is not reliable
return {
"t_id": self.t_id,
"dim": self.dim,
"min": self.constraint_range.vr.lower,
"max": self.constraint_range.vr.upper,
"shared": (
None
if self.shared is None
else {
"t_id": self.shared.t_id,
"dim": self.shared.dim,
}
),
}

def __eq__(self, other):
if not isinstance(other, Constraint):
raise TypeError(
"A dynamic dim can be specified equal only to another dynamic dim. "
f"Equality with {type(other)} is not supported."
)
constraint_range = StrictMinMaxConstraint(
vr=self.constraint_range.vr & other.constraint_range.vr,
warn_only=False,
)
return Constraint(
self.w_tensor,
self.t_id,
self.dim,
constraint_range,
shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
)


def constrain_as_value(symbol, min: Optional[int] = None, max: Optional[int] = None):
"""
Hint `export()` about the constraint of an intermediate scalar value so that subsequent
@@ -31,7 +147,6 @@ def fn(x):
else:
return x * 2
`export()` would give the following error::
torch._dynamo.exc.UserError: Consider annotating your code using
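To make the truncated docstring above concrete, here is a minimal sketch of the kind of annotation `constrain_as_value` expects; the body of `fn` is an assumed example, since this view elides most of the docstring:

import torch
from torch.export import constrain_as_value

def fn(x):
    n = x.max().item()  # data-dependent scalar
    # Tell export() the runtime range of n instead of letting tracing
    # fail on the data-dependent branch below.
    constrain_as_value(n, min=0, max=100)
    if n <= 5:
        return x + 1
    else:
        return x * 2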
@@ -246,7 +361,7 @@ def export(
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
*,
constraints: Optional[List["torch._dynamo.eval_frame.Constraint"]] = None,
constraints: Optional[List[Constraint]] = None,
) -> "torch._export.exported_program.ExportedProgram": # type: ignore[name-defined]
"""
`export()` is a one-shot process for capturing a computation graph from
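Tying the pieces together, a minimal end-to-end sketch of the API as re-exported by this commit; the module `M` is an assumed example, and `export()` is taken to accept a callable plus example inputs, as its signature above suggests:

import torch
from torch.export import dynamic_dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

x = torch.randn(4, 8)
# Constraints now type-check against torch.export.Constraint directly,
# rather than against "torch._dynamo.eval_frame.Constraint" as a string.
ep = export(M(), (x,), constraints=[dynamic_dim(x, 0) >= 2,
                                    dynamic_dim(x, 0) <= 64])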
