Skip to content

Commit

Permalink
Handle constant SymBool in unary and binary operations (pytorch#109169)
Browse files Browse the repository at this point in the history
In this PR:
- When Constant SymNode are detected in unary/binary ops demote them to plain int/bool before proceeding. Sometimes this means doing a unary op with a Constant SymNode would result in a plain bool.
- Introduce an is_symbolic method, only available from Python. We need this because isinstance(x, SymInt) is no longer sufficient to check whether a given int/SymInt is symbolic or not. See later PR in the stack to see how this is used.
Pull Request resolved: pytorch#109169
Approved by: https://github.com/ezyang
  • Loading branch information
soulitzer authored and pytorchmergebot committed Sep 20, 2023
1 parent 8597d37 commit 5252fcb
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 6 deletions.
6 changes: 6 additions & 0 deletions c10/core/ConstantSymNodeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ class C10_API ConstantSymNodeImpl : public SymNodeImpl {
return c10::nullopt;
}
}
bool is_constant() override {
return true;
}
bool is_symbolic() override {
return false;
}

private:
c10::variant<int64_t, bool> value_;
Expand Down
4 changes: 4 additions & 0 deletions c10/core/SingletonSymNodeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class C10_API SingletonSymNodeImpl : public SymNodeImpl {
return val_;
}

bool is_symbolic() override {
return false;
}

#define DEFINE_BINARY_NOT_SUPPORTED(name) \
c10::SymNode name(const c10::SymNode& other) override { \
TORCH_CHECK(false, #name " not supported by SingletonSymNode"); \
Expand Down
6 changes: 6 additions & 0 deletions c10/core/SymNodeImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual c10::optional<int64_t> maybe_as_int() {
return c10::nullopt;
}
virtual bool is_constant() {
return false;
}
virtual bool is_symbolic() {
return true;
}
std::ostream& operator<<(std::ostream& os) {
os << str();
return os;
Expand Down
65 changes: 65 additions & 0 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
sym_sqrt,
SymNode,
to_node,
is_symbolic,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
Expand Down Expand Up @@ -714,6 +715,70 @@ def test_method(self, fn, first_type, second_type):

self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)

def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val))

def test_non_symbolic_symnode(self):
j1 = torch._C._get_singleton_int(1)
j2 = torch._C._get_singleton_int(1)
j3 = torch._C._get_singleton_int(3)

self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)

with self.assertRaisesRegex(RuntimeError, "add not supported by SingletonSymNode"):
j1 + 3

self.assertFalse(j1 == 3)
self.assertFalse(3 >= j2)

self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)
self.assertIs(j1 == j3, False)
self.assertIs(j1 != j3, True)
self.assertIs(j1 != j2, False)

x = self.get_constant_bool(True)
#
# Unary
#
# op(constant SymBool)
self.assertIs(x.__sym_not__(), False)

#
# Binary
#
# op(constant SymBool, bool)
# op(constant SymBool, constant SymBool)
# op(bool, constant SymBool)
self.assertIs(operator.and_(x, True), True)
self.assertIs(operator.and_(x, x), True)
self.assertIs(operator.and_(True, x), True)

# op(symbolic SymBool, constant Symbool)
# op(constant SymBool, symbolic Symbool)
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
b = create_symint(shape_env, 2)
c = a == b # symbolic SymBool
d = self.get_constant_bool(True)
e = operator.and_(c, d)
f = operator.and_(d, c)
self.assertTrue(is_symbolic(e))
self.assertTrue(is_symbolic(f))
self.assertIs(e.node.guard_bool("", 0), True)
self.assertIs(f.node.guard_bool("", 0), True)

# Comparing sizes
sz1 = torch.Size([j1, j1, j1])
sz2 = torch.Size([j1, j1, j1])
self.assertIs(sz1 == sz2, True)

sz1 = torch.Size([3, j1, 4])
sz2 = torch.Size([3, j2, 4])
self.assertIs(sz1 == sz2, True)
self.assertIs(sz1 != sz2, False)

instantiate_parametrized_tests(TestSymNumberMagicMethods)

class TestFloorDiv(TestCase):
Expand Down
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,8 @@ def _set_python_dispatcher(dispatcher: object) -> None: ...

def _get_singleton_int(id: _int) -> SymInt: ...

def _get_constant_bool_symnode(val: _bool) -> Any: ...

class _TorchDispatchModeKey(Enum):
${torch_dispatch_mode_key_hints}

Expand Down
2 changes: 1 addition & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __sym_not__(self) -> "SymBool":
raise AssertionError("type stub not overridden")

def __repr__(self):
return self.node.str()
return str(self.node)

def sym_not(a):
r""" SymInt-aware utility for logical negation.
Expand Down
17 changes: 14 additions & 3 deletions torch/csrc/jit/python/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,9 +1253,20 @@ void initJITBindings(PyObject* module) {
.def(
"__str__",
[](c10::SymNode a) { return a->str(); })
.def("__repr__", [](c10::SymNode a) {
return a->str();
});
.def(
"__repr__",
[](c10::SymNode a) { return a->str(); })
.def(
"is_constant",
[](const c10::SymNode& node){
return node->is_constant();
})
.def(
"is_symbolic",
[](const c10::SymNode& node) {
return node->is_symbolic();
});

// clang-format on

// NOLINTNEXTLINE(bugprone-unused-raii)
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/utils/python_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,11 @@ void initDispatchBindings(PyObject* module) {
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(data)));
});

m.def("_get_constant_bool_symnode", [](int64_t data) {
return c10::SymNode(
c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
});

using c10::impl::TorchDispatchModeKey;
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
Expand Down
50 changes: 48 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,19 @@ def expect_true(self, file, line):
def bool_(self):
return self.guard_bool("", 0)

def is_symbolic(self):
return True

def is_singleton_int(self):
return False

def is_constant(self):
return False

def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
if isinstance(val, (int, float, bool)):
return False
return val.node.is_symbolic()

# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
Expand Down Expand Up @@ -1501,20 +1514,53 @@ def _make_user_magic(method, user_type):
else:
method_attr = method

def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
if isinstance(x, (int, float, bool)):
return x
if isinstance(x, SymBool):
return x.node.guard_bool("", 0)
raise AssertionError("expect to be called with constant SymBools")

def is_constant(x):
if isinstance(x, (int, float, bool)):
return True
if isinstance(x, (SymInt, SymFloat, SymBool)):
return x.node.is_constant()
return False

# Before and after performing the operation, check if any operands are constant.
# If so, extract out the constant values first. If `self` itself is a
# constant, then "redispatch" by calling back into the operator. Sometimes
# this means that operations involving SymBool return plain bools.
# Alternatively, we could also rewrap into constant Symbool (i.e. by
# implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
# today for no particular reason.
def unary_magic_impl(self):
if is_constant(self):
return (method_to_operator(method))(get_constant(self))
return wrap_node(getattr(self.node, method_attr)())

def binary_magic_impl(self, other):
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
other = get_constant(other)
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
return wrap_node(getattr(self.node, method_attr)(other_node))
ret = wrap_node(getattr(self.node, method_attr)(other_node))
return get_constant(ret) if is_constant(ret) else ret

def rbinary_magic_impl(self, other):
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
other = get_constant(other)
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
return wrap_node(getattr(other_node, method_attr)(self.node))
ret = wrap_node(getattr(other_node, method_attr)(self.node))
return get_constant(ret) if is_constant(ret) else ret

if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
Expand Down

0 comments on commit 5252fcb

Please sign in to comment.