enforce equalities (pytorch#108429)
Sometimes one might want to impose equalities that are not required by guards: e.g., requiring square images even though rectangular images would suffice.

Curiously, we never checked that the concrete values in the example shapes actually satisfy such equality constraints. So, e.g., you could multiply two tensors of shapes MxK and KxN, specify that M and N must be equal, and then pass examples where they are not.
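The missing check is simple in principle. A minimal pure-Python sketch (all names here are hypothetical, not PyTorch's API) of validating user-specified dimension equalities against concrete example shapes:

```python
# Sketch of the validation this PR adds: before tracing, check that the
# concrete example shapes satisfy every user-specified dimension equality.
# Names are hypothetical, not PyTorch's actual API.

def check_equalities(shapes, equalities):
    """shapes: dict tensor_name -> tuple of concrete sizes.
    equalities: list of ((name, dim), (name, dim)) pairs that must match."""
    for (t1, d1), (t2, d2) in equalities:
        v1, v2 = shapes[t1][d1], shapes[t2][d2]
        if v1 != v2:
            raise ValueError(
                f"{t1}.size()[{d1}] = {v1} is not equal to "
                f"{t2}.size()[{d2}] = {v2}"
            )

# M x K @ K x N with a constraint that M == N: these examples violate it.
shapes = {"x": (3, 5), "y": (5, 4)}
try:
    check_equalities(shapes, [(("x", 0), ("y", 1))])
except ValueError as e:
    print(e)  # x.size()[0] = 3 is not equal to y.size()[1] = 4
```

The real implementation (in the `symbolic_shapes.py` hunk below) does the analogous check symbolically, mapping each dimension source to its symbol and evaluating `sympy.Eq` against the recorded concrete values.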

Relatedly, the symbolic shape dimensions for inputs in the exported graph were not forced to be equal.

However, runtime assertions still fire because they take into account all equality constraints. This would result in the strange situation where export would succeed but the exported program with the same example inputs would fail.
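That pre-fix failure mode can be shown schematically (hypothetical names, not the real export pipeline): export traces without validating the equalities, but the exported artifact carries runtime assertions that do enforce them.

```python
# Schematic of the pre-fix behavior, with hypothetical names: export-time
# validation ignored equality constraints, while the exported program's
# runtime assertions enforced them. The same example inputs that exported
# fine would then fail at call time.

def export_pre_fix(shapes, equalities):
    # Bug: equality constraints were never checked against `shapes` here.
    def exported_program(shapes):
        # Runtime assertions DID take all equality constraints into account.
        for (t1, d1), (t2, d2) in equalities:
            assert shapes[t1][d1] == shapes[t2][d2], "runtime assertion failed"
        return "ran"
    return exported_program

# Export "succeeds" despite the example inputs violating x[0] == y[1] ...
prog = export_pre_fix({"x": (3, 5), "y": (5, 4)}, [(("x", 0), ("y", 1))])
try:
    # ... but calling the exported program with those same inputs fails.
    prog({"x": (3, 5), "y": (5, 4)})
except AssertionError as e:
    print(e)  # runtime assertion failed
```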

This PR fixes these issues.

Differential Revision: [D48910918](https://our.internmc.facebook.com/intern/diff/D48910918/)

Pull Request resolved: pytorch#108429
Approved by: https://github.com/zhxchen17
avikchaudhuri authored and pytorchmergebot committed Sep 7, 2023
1 parent 247c603 commit c55cb29
Showing 2 changed files with 59 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/dynamo/test_export.py
@@ -2378,6 +2378,44 @@ def foo(x, y):
            foo, (a, {"k": b}), constraints=[dynamic_dim(a, 0), dynamic_dim(b, 0)]
        )

    def test_enforce_equalities(self):
        def bar(x, y):
            return torch.matmul(x, y)

        def specify_constraints(x, y):
            return [
                dynamic_dim(x, 0) == dynamic_dim(y, 0),
                dynamic_dim(x, 1) == dynamic_dim(x, 2),
                dynamic_dim(x, 2) == dynamic_dim(y, 1),
                dynamic_dim(y, 1) == dynamic_dim(y, 2),
            ]

        x = torch.randn(10, 3, 3)
        y = torch.randn(10, 3, 4)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            ".*y.*size.*1.* = 3 is not equal to .*y.*size.*2.* = 4",
        ):
            torch._export.export(
                bar,
                (x, y),
                constraints=specify_constraints(x, y),
            )
        y = torch.randn(10, 3, 3)
        ebar = torch._export.export(
            bar,
            (x, y),
            constraints=specify_constraints(x, y),
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in ebar.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
        )

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
21 changes: 21 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
@@ -2743,6 +2743,27 @@ def record_constraint_violation(warn_only, msg, hint=None):
        def is_dim(src):
            return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE

        if equalities_inputs:
            source_index = {}
            for i, src in enumerate(sources):
                source_index[src.name()] = i

            def get_symbol(tensor_dim_src):
                fake = placeholders[source_index[tensor_dim_src.base.name()]]
                symint = fake.shape[tensor_dim_src.idx]
                assert isinstance(symint, torch.SymInt)
                return symint.node.expr

            for src1, src2 in equalities_inputs.source_pairs:
                s1, s2 = get_symbol(src1), get_symbol(src2)
                concrete_val = self.evaluate_expr(sympy.Eq(s1, s2))
                if not concrete_val:
                    raise ConstraintViolationError(
                        f"{src1.name()} = {self.var_to_val[s1]}"
                        " is not equal to "
                        f"{src2.name()} = {self.var_to_val[s2]}"
                    )

        # How do we know what the value of s0 is? Fresh variables can only be
        # bound by inputs, so there MUST be some other input which binds the
        # variable. If there is no such input, this is an error in our
