New export API with dynamic shape specifications instead of constraints (pytorch#108448)

In our experience, using `constraints` / `dynamic_dim` with the existing export API is (subjectively) clunky and (objectively) verbose in common cases.

This PR implements a new design for the export API that replaces `constraints` / `dynamic_dim` with a different way of specifying dynamic shapes, built on the following concepts:
* a constructor `Dim` for first-class named dynamic dimensions with ranges (similar to `functorch.dim`, and analogous to internal symbolic sizes)
* a mechanism that uses the above in `export` calls to associate inputs with their dynamic shape specifications (`dynamic_shapes`); see the sketch below
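
For a rough sense of the shape of the new API (a minimal sketch mirroring the tests added in this PR; tensor sizes are illustrative only):

```python
import torch
from torch.export import Dim, export

def foo(x, y):
    return torch.matmul(x, y)

inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))

# A first-class named dynamic dimension; a range may optionally be given,
# e.g. Dim("batch", min=8, max=64).
batch = Dim("batch")

# Associate each input (by parameter name) with a per-dimension spec;
# dimensions left unspecified remain static.
efoo = export(foo, inputs, dynamic_shapes={"x": {0: batch}, "y": {0: batch}})
```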

Design doc: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit#slide=id.p (Meta-only). Note that we implement only Option 1 in that doc. An older version of this PR also implemented Option 3, an alternative way of specifying dynamic shapes using tensor type annotations on the exported callable; we have moved that to future work for now.

See the docs for these new features in `torch.export`. The existing `torch.export.export` is modified to use the new API, `torch._export.export__RC__`, whenever `constraints=None`. We have not deprecated the existing API yet, but will do so in a follow-up.

Constraint violation errors arising from use of the new API now contain suggested fixes phrased in terms of the new API. We no longer need to report every specialization of a static dimension or suggest every constraint over dynamic dimensions to fix such errors; thanks to the redesign, the suggested fixes are much more concise, involving only changes to the definitions of the relevant `Dim`s.
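
For example (hypothetical, but modeled on the error-message tests added in this PR), applying a suggested fix amounts to editing only the `Dim` definitions before re-exporting:

```python
from torch.export import Dim

# Hypothetical suggested fixes reported by a constraint violation
# (the regexes in the tests below check for messages of roughly this form):
#   Suggested fixes:
#     batch = Dim('batch', max=15)
#     K = None # 3
#
# Applying them is just a matter of redefining the offending Dims:
batch = Dim("batch", max=15)  # a guard caps this dimension at 15
K = None                      # this dimension turned out to be static (3)
# ...then call export(...) again with the same dynamic_shapes structure.
```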

Differential Revision: [D48919204](https://our.internmc.facebook.com/intern/diff/D48919204/)

Pull Request resolved: pytorch#108448
Approved by: https://github.com/suo, https://github.com/gmagogsfm
avikchaudhuri authored and pytorchmergebot committed Sep 22, 2023
1 parent cd99cdc commit ebc7039
Showing 8 changed files with 645 additions and 67 deletions.
8 changes: 8 additions & 0 deletions docs/source/export.rst
@@ -357,6 +357,12 @@ Some additional things to note:
we see in the equality constraints the tuple specifying that ``arg5_1``
dimension 0 and ``arg6_1`` dimension 0 are equal.

(An experimental mechanism that is designed to eventually subsume the use of
:func:`torch.export.dynamic_dim` and ``constraints`` involves constructing
dynamic shape specifications with the :func:`torch.export.Dim` and
:func:`torch.export.dims` APIs and passing them into :func:`torch.export.export`
through the ``dynamic_shapes`` argument.)


Serialization
^^^^^^^^^^^^^
@@ -555,6 +561,8 @@ API Reference
.. autofunction:: save
.. autofunction:: load
.. autofunction:: register_dataclass
.. autofunction:: Dim
.. autofunction:: dims
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

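As an illustration of the `dynamic_shapes` mechanism documented above (a minimal sketch based on the tests added in this PR; `dims` is a convenience for creating several `Dim`s at once):

```python
import torch
from torch.export import dims, export

def foo(x, y):
    return torch.matmul(x, y)

inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))

# Create several named dynamic dimensions in one call.
batch, M, K, N = dims("batch", "M", "K", "N")

# Per-input specs can also be given positionally as tuples over dimensions.
efoo = export(
    foo,
    inputs,
    dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
)
```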
266 changes: 266 additions & 0 deletions test/export/test_export.py
@@ -349,6 +349,272 @@ def forward(self, x):
self.assertTrue("source_fn" in node.meta)
self.assertTrue("nn_module_stack" in node.meta)

    def test_export_api_with_dynamic_shapes(self):
        from torch.export import Dim, dims, export

        # pass dynamic shapes of inputs [args]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch = Dim("batch")
        efoo = export(foo, inputs, dynamic_shapes={k: {0: batch} for k in ["x", "y"]})
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [kwargs]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3),)
        kwinputs = {"y": torch.randn(10, 3, 4)}
        batch = Dim("batch")
        efoo = export(
            foo, inputs, kwinputs, dynamic_shapes={k: {0: batch} for k in ["x", "y"]}
        )
        self.assertEqual(efoo(*inputs, **kwinputs).shape, foo(*inputs, **kwinputs).shape)

        # pass dynamic shapes of inputs [partial, error]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3),)
        kwinputs = {"y": torch.randn(10, 3, 4)}
        batch = Dim("batch")
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Constraints violated \\(batch\\)!(.*\n)*.*"
                "batch was inferred to be a constant(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "batch = None # 10"
            ),
        ):
            export(foo, inputs, kwinputs, dynamic_shapes={"x": {0: batch}, "y": None})

        # pass dynamic shapes of inputs [module]
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return torch.matmul(x, y)

        foo = Foo()
        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch = Dim("batch")
        efoo = export(foo, inputs, dynamic_shapes={"x": {0: batch}, "y": {0: batch}})
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [bounds, mostly shared]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 3, 3), torch.randn(10, 3, 3))
        batch = Dim("batch", min=8, max=64)
        size = Dim("size")
        efoo = export(
            foo,
            inputs,
            dynamic_shapes={
                "x": (batch, size, size),
                "y": (batch, size, size),
            },
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in efoo.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"],
        )
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [multiple, mostly distinct]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch, M, K, N = dims("batch", "M", "K", "N")
        efoo = export(
            foo,
            inputs,
            dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in efoo.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"],
        )
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [dict]
        class Foo(torch.nn.Module):
            def forward(self, inputs):
                return torch.matmul(inputs["x"], inputs["y"])

        foo = Foo()
        inputs = ({"x": torch.randn(10, 2, 3), "y": torch.randn(10, 3, 4)},)
        batch = Dim("batch")
        efoo = export(
            foo, inputs, dynamic_shapes={"inputs": {k: {0: batch} for k in ["x", "y"]}}
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in efoo.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
        )
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [list]
        class Foo(torch.nn.Module):
            def forward(self, inputs):
                return torch.matmul(inputs[0], inputs[1])

        foo = Foo()
        inputs = ((torch.randn(10, 2, 3), torch.randn(10, 3, 4)),)
        batch = Dim("batch")
        efoo = export(
            foo, inputs, dynamic_shapes={"inputs": [{0: batch} for _ in range(2)]}
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in efoo.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
        )
        self.assertEqual(efoo(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [dataclass]
        @dataclass
        class DataClass:
            a: Tensor
            b: Tensor

        register_dataclass_as_pytree_node(DataClass)

        class Foo(torch.nn.Module):
            def forward(self, inputs):
                return torch.matmul(inputs.a, inputs.b)

        foo = Foo()
        inputs = (DataClass(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),)
        batch = Dim("batch")
        efoo = export(
            foo, inputs, dynamic_shapes={"inputs": DataClass(a={0: batch}, b={0: batch})}
        )
        self.assertEqual(
            [
                str(node.meta["val"].shape)
                for node in efoo.graph_module.graph.nodes
                if node.op == "placeholder"
            ],
            ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
        )

        # pass dynamic shapes of inputs [distinct, error]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch, M, K1, K2, N = dims("batch", "M", "K1", "K2", "N")
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Constraints violated \\(K2\\)!(.*\n)*.*"
                "K2.*and.*K1.*must always be equal(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "K2 = K1"
            ),
        ):
            export(
                foo,
                inputs,
                dynamic_shapes={"x": (batch, M, K1), "y": (batch, K2, N)},
            )

        # pass dynamic shapes of inputs [specialized, error]
        def foo(x, y):
            return torch.matmul(x, y)

        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch, M, K1, N = dims("batch", "M", "K1", "N")
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Constraints violated \\(K1\\)!(.*\n)*.*"
                "K1 was inferred to be a constant(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "K1 = None # 3"
            ),
        ):
            export(
                foo,
                inputs,
                dynamic_shapes={"x": (batch, M, K1), "y": (batch, None, N)},
            )

        # pass dynamic shapes of inputs [guards, error]
        def foo(x, y):
            if x.shape[0] < 16 and y.shape[1] % 3 == 0:
                return torch.matmul(x, y)
            else:
                return x + y

        inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4))
        batch, M, K, N = dims("batch", "M", "K", "N")
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Constraints violated \\(batch\\)!(.*\n)*.*"
                "Not all values of batch.*satisfy the generated guard(.*\n)*.*"
                "Specializations unexpectedly required \\(K\\)!(.*\n)*.*"
                "K.*specialized.*because the guards generated for it are too complex(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "batch = Dim\\('batch', max=15\\)(.*\n)*.*"
                "K = None # 3"
            ),
        ):
            export(
                foo,
                inputs,
                dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)},
            )

    def test_dynamic_shapes_spec_with_pytree(self):
        from torch.export import Dim, export
        from torch.utils._pytree import tree_map

        inputs = {
            "tensor": torch.randn(3),
            "dict_of_tensors": {k: torch.randn(3) for k in ["A", "B", "C", "D"]},
            "list_of_tensors": [torch.randn(3) for _ in range(4)],
        }

        batch = Dim("batch")
        # uniformly specify dynamic shapes for all inputs
        spec = tree_map(lambda x: {0: batch}, inputs)

        def foo(inputs):
            return (
                inputs["tensor"]
                + inputs["dict_of_tensors"]["A"]
                + inputs["list_of_tensors"][0]
            )

        ep = export(foo, (inputs,), dynamic_shapes={"inputs": spec})
        input_shapes = [
            str(node.meta["val"].shape)
            for node in ep.graph_module.graph.nodes
            if node.op == "placeholder"
        ]
        self.assertEqual(len(input_shapes), 9)
        self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes))

    def test_error_does_not_reference_eager_fallback(self):
        def fn_ddo(x):
4 changes: 2 additions & 2 deletions test/test_dynamic_shapes.py
@@ -1035,7 +1035,7 @@ def test_dim_constraints_reduce_congruences_simple(self):
        from torch.fx.experimental.symbolic_shapes import DimConstraints

        s = Symbol("s", positive=True, integer=True)
-        dim_constraints = DimConstraints({}, {}, set())
+        dim_constraints = DimConstraints({}, {}, set(), {})
        dim_constraints._congruences[s] = {
            (s / 2) % 2,
            (s / 2) % 8,
@@ -1130,7 +1130,7 @@ def test_dim_constraints_solve_full(self):
        }
        var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
        marked_dynamic = {s0, s1, s5, s6}
-        dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic)
+        dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic, {})
        dim_constraints.add_equality(src2, s0)
        dim_constraints.add_equality(src3, s0)
        dim_constraints.add_equality(src4, s0)
10 changes: 3 additions & 7 deletions torch/_dynamo/eval_frame.py
@@ -1225,14 +1225,10 @@ def result_capturing_wrapper(*graph_inputs):
    ):
        dim_constraints.solve()
        dim_constraints.remove_redundant_dynamic_results()
-        msg = dim_constraints.prettify_results(original_signature)
        forced_specializations = dim_constraints.forced_specializations()
-        if forced_specializations:
-            msg = (
-                "Some dynamic dimensions need to be specialized because "
-                "the constraints inferred for them are too complex to specify.\n"
-                f"{forced_specializations}\n{msg}"
-            )
+        msg = dim_constraints.prettify_results(
+            original_signature, constraint_violation_error, forced_specializations
+        )
        if constraint_violation_error:
            constraint_violation_error.args = (
                constraint_violation_error.args[0] + msg,