optests improvements based on torchvision usage on nms (pytorch#108929)
- Update cross-ref FakeMode test to use ShapeEnv.  Dynamic ops can now
  return an unbacked SymInt.  We always accept this as equal to whatever
  the real value was.
- Relax test so it works on all classes, not just unittest.TestCase
- Properly wrap the original method, so things like
  pytest.mark.parametrize are carried over
- Support dynamic shapes by default for make_fx `tracing_mode="fake"` without symbolifying everything else (a usage sketch follows the changed-files summary below)

Fixes pytorch#108927

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#108929
Approved by: https://github.com/zou3519
ezyang authored and pytorchmergebot committed Sep 13, 2023
1 parent bfa8429 commit 55f956f
Showing 11 changed files with 51 additions and 60 deletions.
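As a concrete illustration of the headline change, here is a minimal sketch mirroring the updated test in test/test_custom_ops.py: tracing a data-dependent custom op under tracing_mode="fake" now succeeds, with the unknown output size represented by an unbacked SymInt. It assumes the _torch_testing test ops from torch.testing._internal.custom_op_db have been registered by importing that module.

import torch
import torch.testing._internal.custom_op_db  # noqa: F401 -- registers the _torch_testing test ops
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # numpy_nonzero has a data-dependent output shape
    return torch.ops._torch_testing.numpy_nonzero(x)

x = torch.randn(5, 5)
# Previously this raised DynamicOutputShapeException; fake tracing now
# allocates an unbacked SymInt for the data-dependent dimension instead.
gm = make_fx(f, tracing_mode="fake")(x)
print(gm.graph)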
9 changes: 3 additions & 6 deletions test/test_custom_ops.py
@@ -346,15 +346,13 @@ def test_operator_compile_check_op(self, device, dtype, op):
for sample_input in op.sample_inputs(
device, dtype, requires_grad=op.supports_autograd
):
dynamic_only = op.name in ("NumpyNMSCustomOp", "NumpyNonzeroCustomOp")
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
operator_compile_check(
op.op,
args,
kwargs,
supports_autograd=op.supports_autograd,
dynamic_only=dynamic_only,
fullgraph=False, # Dynamo graph breaks on CustomOp today
)

@@ -1463,10 +1461,9 @@ def f(x):
return torch.ops._torch_testing.numpy_nonzero(x)

x = torch.randn(5, 5)
with self.assertRaises(
torch._subclasses.fake_tensor.DynamicOutputShapeException
):
make_fx(f, tracing_mode="fake")(x)
# We've updated to attempt to use unbacked symints even for fake
# tracing
make_fx(f, tracing_mode="fake")(x)

def test_symints(self):
def f(x):
7 changes: 1 addition & 6 deletions test/test_fake_tensor.py
@@ -782,16 +782,11 @@ def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
class FakeTensorOpInfoTest(TestCase):
@ops(custom_op_db, dtypes=OpDTypes.any_one)
def test_fake(self, device, dtype, op):
data_dependent_outputs = {
'NumpyNMSCustomOp',
'NumpyNonzeroCustomOp',
}

sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args
kwargs = sample_input.kwargs
optests.fake_check(op, args, kwargs, op.name in data_dependent_outputs)
optests.fake_check(op, args, kwargs)


class FakeTensorConverterTest(TestCase):
10 changes: 2 additions & 8 deletions test/test_proxy_tensor.py
@@ -516,10 +516,7 @@ def f():
def test_f():
make_fx(f, tracing_mode=self.tracing_mode)()

if self.tracing_mode == "fake":
self.assertRaises(DataDependentOutputException, test_f)
else:
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

def test_constant_random(self):
def f():
@@ -530,10 +527,7 @@ def f():
def test_f():
make_fx(f, tracing_mode=self.tracing_mode)()

if self.tracing_mode == "fake":
self.assertRaises(DataDependentOutputException, test_f)
else:
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

def test_decomposition_interpreter(self):
def fn(x):
17 changes: 14 additions & 3 deletions torch/_prims_common/__init__.py
@@ -77,11 +77,16 @@
TensorOrNumberLikeType = Union[TensorLikeType, NumberType]


def same_shape(a: ShapeType, b: ShapeType) -> bool:
def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
if len(a) != len(b):
return False

for x, y in zip(a, b):
if allow_rhs_unbacked:
# TODO: We should check that the symbols are consistent
# with each other
if isinstance(y, torch.SymInt):
continue
if x != y:
return False

@@ -90,7 +95,13 @@ def same_shape(a: ShapeType, b: ShapeType) -> bool:

# TODO: look at using torch.testing.assert_close instead with an option
# to just compare metadata
def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=False):
def compare_tensor_meta(
a: TensorLikeType,
b: TensorLikeType,
check_strides=False,
*,
allow_rhs_unbacked=False,
):
"""
Checks that two tensor likes have the same shape,
dtype and device.
Expand All @@ -101,7 +112,7 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=Fals
assert isinstance(a, TensorLike)
assert isinstance(b, TensorLike)

if not same_shape(a.shape, b.shape):
if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
msg = f"Shapes {a.shape} and {b.shape} are not equal!"
raise AssertionError(msg)

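A minimal sketch of what allow_rhs_unbacked buys; the unbacked SymInt below is created directly from a ShapeEnv purely for illustration.

import torch
from torch._prims_common import same_shape
from torch.fx.experimental.symbolic_shapes import ShapeEnv

u0 = ShapeEnv().create_unbacked_symint()  # a size with no backing hint

# Concrete shapes still compare exactly.
assert same_shape((3, 4), (3, 4))
assert not same_shape((3, 4), (3, 5))

# The real output had shape (5, 4); the fake kernel returned (u0, 4).
# With allow_rhs_unbacked=True the unbacked dimension is accepted
# without trying to guard on its value.
assert same_shape((5, 4), (u0, 4), allow_rhs_unbacked=True)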
9 changes: 8 additions & 1 deletion torch/_subclasses/fake_tensor.py
@@ -1234,10 +1234,15 @@ def __init__(
allow_fallback_kernels=True,
allow_non_fake_inputs=False,
shape_env=None,
static_shapes=None,
):
log.debug("create_mode 0x%x", id(self))
self.allow_fallback_kernels = allow_fallback_kernels
self.fake_tensor_converter = FakeTensorConverter()
if static_shapes is not None:
self.static_shapes = static_shapes
else:
self.static_shapes = shape_env is None

import torch._functorch.config

@@ -1758,7 +1763,7 @@ def from_tensor(
self,
tensor,
*,
static_shapes=False,
static_shapes=None,
ignore_subclass=False,
source: Optional[Source] = None,
dynamic_dims: Optional[DimList[DimDynamic]] = None,
@@ -1768,6 +1773,8 @@
memoized_only=False,
):
shape_env = self.shape_env
if static_shapes is None:
static_shapes = self.static_shapes
if static_shapes:
assert (
dynamic_dims is None
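The net effect of the new static_shapes argument, sketched below under these assumptions: when it is not passed, the mode defaults to static shapes exactly when no ShapeEnv was supplied, and from_tensor() inherits that default unless a per-call value is given.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

assert FakeTensorMode().static_shapes                          # no ShapeEnv -> static
assert not FakeTensorMode(shape_env=ShapeEnv()).static_shapes  # ShapeEnv -> not static

# A ShapeEnv is available for unbacked SymInts, but input tensors stay static:
mode = FakeTensorMode(shape_env=ShapeEnv(), static_shapes=True)
real = torch.randn(3, 4)
with mode:
    fake = mode.from_tensor(real)          # picks up static_shapes=True from the mode
    assert tuple(fake.shape) == (3, 4)     # concrete sizes, no symbols allocated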
14 changes: 11 additions & 3 deletions torch/_subclasses/fake_utils.py
@@ -1,3 +1,4 @@
import functools
import warnings
from typing import Callable, Union

@@ -9,6 +10,7 @@
tree_flatten_only,
UnsupportedFakeTensorException,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

@@ -79,9 +81,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
and torch.Tag.data_dependent_output not in func.tags
):
try:
with FakeTensorMode() as fake_mode:
# TODO: enable_python_dispatcher() here
with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
fake_args, fake_kwargs = pytree.tree_map_only(
torch.Tensor, fake_mode.from_tensor, (args, kwargs)
torch.Tensor,
functools.partial(fake_mode.from_tensor, static_shapes=True),
(args, kwargs),
)
with warnings.catch_warnings():
fake_r = func(*fake_args, **fake_kwargs)
@@ -143,7 +148,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

try:
torch._prims.utils.compare_tensor_meta(
r_out, fake_out, check_strides=self.check_strides
r_out,
fake_out,
check_strides=self.check_strides,
allow_rhs_unbacked=True,
)
except Exception as e:
error_message = (
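How the cross-ref check is exercised, as a sketch: under CrossRefFakeMode every intercepted call runs both for real and under a ShapeEnv-backed FakeTensorMode (inputs converted with static_shapes=True), and the resulting metadata is compared with allow_rhs_unbacked=True, so data-dependent fake kernels that return unbacked SymInt sizes no longer trip the comparison.

import torch
from torch._subclasses import CrossRefFakeMode

x = torch.randn(8, 8)
with CrossRefFakeMode():
    # Each op below is re-run on fake tensors and its output metadata
    # (shape/dtype/device, optionally strides) is checked against the
    # real result.
    y = torch.mm(x, x).relu()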
5 changes: 4 additions & 1 deletion torch/fx/experimental/proxy_tensor.py
@@ -760,7 +760,10 @@ def wrapped(*args):
if fake_tensor_mode is None:
fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=True,
allow_non_fake_inputs=_allow_non_fake_inputs)
allow_non_fake_inputs=_allow_non_fake_inputs,
shape_env=ShapeEnv(),
static_shapes=True,
)
elif tracing_mode == "symbolic":
import torch._dynamo
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
5 changes: 2 additions & 3 deletions torch/testing/_internal/custom_op_db.py
@@ -162,7 +162,7 @@ def numpy_nonzero_impl(x):
def numpy_nonzero_abstract(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
shape = [x.dim(), i0]
shape = [i0, x.dim()]
result = x.new_empty(shape, dtype=torch.long)
return result

@@ -337,7 +337,6 @@ def numpy_nms_impl(boxes, scores, iou_threshold):
inds = np.where(ovr <= iou_threshold)[0]
order = order[inds + 1]

result = np.stack(keep)
result = torch.tensor(np.stack(keep), device=device)
# Needed for data-dependent condition :(
assert result.size(0) >= 2
@@ -352,7 +351,7 @@ def numpy_nms_abstract(boxes, scores, iou_threshold):

ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
result = boxes.new_empty([i0, 4])
result = boxes.new_empty([i0], dtype=torch.int64)
return result

def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
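The corrected abstract (fake) implementations above follow one pattern worth spelling out: for a data-dependent output, the fake kernel asks the custom-op context for an unbacked SymInt and uses it as the unknown size. A hedged sketch of that pattern follows; it is meant to be registered as an op's abstract impl, and get_ctx() is only valid while a fake kernel is actually being run.

import torch
import torch._custom_op.impl

def my_nms_abstract(boxes, scores, iou_threshold):
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    assert scores.ndim == 1
    assert boxes.shape[0] == scores.shape[0]
    ctx = torch._custom_op.impl.get_ctx()
    i0 = ctx.create_unbacked_symint()                # number of kept boxes, unknown at trace time
    return boxes.new_empty([i0], dtype=torch.int64)  # indices of the kept boxes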
10 changes: 2 additions & 8 deletions torch/testing/_internal/optests/compile_check.py
@@ -10,7 +10,6 @@ def operator_compile_check(
args,
kwargs=None,
*,
dynamic_only=False,
supports_autograd=True,
fullgraph=True,
):
Expand All @@ -21,10 +20,6 @@ def operator_compile_check(
and returns a Tensor or a Tuple of Tensors.
args (Tuple): args to the operator
kwargs (dict, optional): kwargs to the operator
dynamic_only (bool, optional): If the operator only works with dynamic
shapes. This can happen if it returns Tensors whose shape are
dependent on the data on the input Tensors. If True, we skip
tests related to torch.compile with static shapes.
supports_autograd (bool, optional): If the operator does not support
autograd. If False, we will skip autograd-related tests.
fullgraph (bool, optional): If we expect torch.compile to not graph
@@ -48,9 +43,8 @@ def run_static_or_dynamic_tests(dynamic):
check_compile(func, args, kwargs, fullgraph=fullgraph, backend='inductor', dynamic=dynamic)

schema_check(func, args, kwargs)
fake_check(func, args, kwargs, dynamic_only)
if not dynamic_only:
run_static_or_dynamic_tests(dynamic=False)
fake_check(func, args, kwargs)
run_static_or_dynamic_tests(dynamic=False)
run_static_or_dynamic_tests(dynamic=True)


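With dynamic_only gone, a call site looks like the sketch below; both the static- and dynamic-shape torch.compile runs are always exercised now. This assumes operator_compile_check is imported from its defining module and that the _torch_testing test ops are registered via custom_op_db, as in the test above.

import torch
import torch.testing._internal.custom_op_db  # noqa: F401 -- registers the _torch_testing test ops
from torch.testing._internal.optests.compile_check import operator_compile_check

def op(x):
    return torch.ops._torch_testing.numpy_nonzero(x)

operator_compile_check(
    op,
    (torch.randn(5, 5),),
    {},
    supports_autograd=False,
    fullgraph=False,  # Dynamo still graph breaks on CustomOp (see the test comment above)
)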
17 changes: 2 additions & 15 deletions torch/testing/_internal/optests/fake_tensor.py
@@ -1,23 +1,10 @@
import torch._subclasses
from torch._subclasses.fake_tensor import DynamicOutputShapeException


def is_builtin(op):
return op.namespace in ('aten', 'prims', 'prim')


def fake_check(op, args, kwargs, dynamic_only):
def fake_check(op, args, kwargs):
with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin):
try:
op(*args, **kwargs)
except DynamicOutputShapeException:
if not dynamic_only:
raise
return
if dynamic_only:
raise AssertionError(
f"fake_check({op}, ..., dynamic_only={dynamic_only}): "
f"dynamic_only means that the operator is expected to have "
f"data-dependent output shape. We have not detected that this is "
f"the case. Please check that your operator's FakeTensor "
f"implementation is actually data dependent")
op(*args, **kwargs)
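fake_check is now a thin wrapper: run the op once under CrossRefFakeMode and let unbacked SymInts handle data-dependent output shapes. A usage sketch, with the same test-op registration assumption as above:

import torch
import torch.testing._internal.custom_op_db  # noqa: F401 -- registers the _torch_testing test ops
from torch.testing._internal.optests.fake_tensor import fake_check

fake_check(torch.ops._torch_testing.numpy_nonzero, (torch.randn(5, 5),), {})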
8 changes: 2 additions & 6 deletions torch/testing/_internal/optests/generate_tests.py
@@ -4,7 +4,6 @@
import json
import os
import tempfile
import unittest

import torch

@@ -40,7 +39,7 @@ def safe_autograd_registration_check(op, args, kwargs):

def safe_fake_check(op, args, kwargs):
args, kwargs = deepcopy_tensors((args, kwargs))
return fake_check(op, args, kwargs, dynamic_only=False)
return fake_check(op, args, kwargs)


def safe_aot_autograd_check(op, args, kwargs, dynamic):
@@ -120,10 +119,6 @@ def generate_opcheck_tests(
failures_dict_path: See ``validate_failures_dict_structure`` for more details
test_utils: a list of test_utils to generate. Example: ["test_schema", "test_faketensor"]
"""
if not issubclass(testcase, unittest.TestCase):
raise ValueError(
f"Expected testcase to be subclass of unittest.TestCase, got {type(testcase)}"
)
test_methods = [
m
for m in dir(testcase)
@@ -139,6 +134,7 @@ def construct_method(attr, prefix, tester):
method = getattr(testcase, attr)
new_method_name = prefix + "__" + attr

@functools.wraps(method)
def new_method(*args, **kwargs):
with OpCheckMode(
namespaces,
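Why @functools.wraps matters here, sketched with a hypothetical test class (assuming pytest is available): parametrize marks live in the original function's __dict__ (as pytestmark), and wraps copies __dict__, so the generated test_schema__*/test_faketensor__* variants keep them. The class also no longer has to inherit from unittest.TestCase.

import functools
import pytest

class TestMyOps:  # plain class, no unittest.TestCase base required anymore
    @pytest.mark.parametrize("n", [2, 3])
    def test_square(self, n):
        assert n * n == n ** 2

original = TestMyOps.test_square

@functools.wraps(original)  # copies __name__, __doc__ and __dict__ (incl. pytestmark)
def test_faketensor__test_square(*args, **kwargs):
    return original(*args, **kwargs)

assert test_faketensor__test_square.pytestmark == original.pytestmark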
