Skip to content

Commit

Permalink
Disable Constraint constructor (pytorch#107918)
Browse files Browse the repository at this point in the history
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed Aug 26, 2023
1 parent f877d0a commit 27afb1c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
8 changes: 8 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import map
from torch import Tensor
from torch.export import Constraint
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import (
Expand Down Expand Up @@ -999,6 +1000,13 @@ def f(x):
):
_ = export(f, (torch.tensor(6),))

def test_constraint_directly_construct(self):
with self.assertRaisesRegex(
TypeError,
"torch.export.Constraint has no public constructor. Please use torch.export.dynamic_dim"
):
_ = Constraint()


if __name__ == '__main__':
run_tests()
4 changes: 2 additions & 2 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.utils._pytree as pytree
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch.export import Constraint
from torch.export import Constraint, _create_constraint
from torch._dynamo.exc import UserError, UserErrorType
from torch._dynamo.source import ConstantSource
from torch._export.exported_program import ModuleCallEntry, ModuleCallSignature
Expand Down Expand Up @@ -79,7 +79,7 @@ def dynamic_dim(t: torch.Tensor, index: int):
f" but got {index}, which is out of bounds for the given tensor."
)

return Constraint(
return _create_constraint(
weakref.ref(t),
id(t),
index,
Expand Down
28 changes: 23 additions & 5 deletions torch/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import io
import pathlib
import typing
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -581,10 +582,27 @@ class directly; instead, use :func:`torch.export.dynamic_dim`.
dim: int


# TODO(ycao): Disable constructor of Constraint so that it can only be constructed
# with dynamic_dim
class _ConstraintFactory(type):
"""
Metaclass that ensures a private constructor for Constraint
"""

def __call__(cls, *args, **kwargs):
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
f"Please use torch.export.dynamic_dim() to create one"
)

def _create(cls, w_tensor, t_id, dim, constraint_range, shared=None):
return super().__call__(w_tensor, t_id, dim, constraint_range, shared)


def _create_constraint(w_tensor, t_id, dim, constraint_range, shared=None):
return Constraint._create(w_tensor, t_id, dim, constraint_range, shared)


@dataclasses.dataclass
class Constraint(_ConstraintTarget):
class Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
"""
.. warning::
Expand All @@ -608,7 +626,7 @@ def _clone_with_range(self, lower=2, upper=sympy.oo):
vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
warn_only=False,
)
return Constraint(
return _create_constraint(
self.w_tensor, self.t_id, self.dim, constraint_range, self.shared
)

Expand Down Expand Up @@ -669,7 +687,7 @@ def __eq__(self, other):
vr=self.constraint_range.vr & other.constraint_range.vr,
warn_only=False,
)
return Constraint(
return _create_constraint(
self.w_tensor,
self.t_id,
self.dim,
Expand Down

0 comments on commit 27afb1c

Please sign in to comment.