forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
optests improvements based on torchvision usage on nms (pytorch#108929)
- Update cross-ref FakeMode test to use ShapeEnv. Dynamic ops can now return an unbacked SymInt. We always accept this as equal to whatever the real value was. - Relax test so it works on all classes, not just unittest.TestCase - Properly wrap the original method, so things like pytree.mark.parametrize are carried over - Support dynamic shapes by default for make_fx `tracing_mode="fake"` without symbolifying everything else Fixes pytorch#108927 Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#108929 Approved by: https://github.com/zou3519
- Loading branch information
1 parent
bfa8429
commit 55f956f
Showing
11 changed files
with
51 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,10 @@ | ||
import torch._subclasses | ||
from torch._subclasses.fake_tensor import DynamicOutputShapeException | ||
|
||
|
||
def is_builtin(op): | ||
return op.namespace in ('aten', 'prims', 'prim') | ||
|
||
|
||
def fake_check(op, args, kwargs, dynamic_only): | ||
def fake_check(op, args, kwargs): | ||
with torch._subclasses.CrossRefFakeMode(ignore_op_fn=is_builtin): | ||
try: | ||
op(*args, **kwargs) | ||
except DynamicOutputShapeException: | ||
if not dynamic_only: | ||
raise | ||
return | ||
if dynamic_only: | ||
raise AssertionError( | ||
f"fake_check({op}, ..., dynamic_only={dynamic_only}): " | ||
f"dynamic_only means that the operator is expected to have " | ||
f"data-dependent output shape. We have not detected that this is " | ||
f"the case. Please check that your operator's FakeTensor " | ||
f"implementation is actually data dependent") | ||
op(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters