Skip to content

Commit

Permalink
Add test for ShapeEnv state when not recording. (pytorch#109945)
Browse files Browse the repository at this point in the history
This PR adds a test for checking `ShapeEnv` state when it's built with
`should_record_events=False`.

Pull Request resolved: pytorch#109945
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#109904, pytorch#109944
  • Loading branch information
ysiraichi authored and pytorchmergebot committed Sep 26, 2023
1 parent 2ac7e52 commit 26e8cc0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 7 deletions.
38 changes: 35 additions & 3 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7315,11 +7315,43 @@ def fn(it):
self.assertEqual(list(eager), list(compiled))
self.assertEqual(counter.frame_count, 1)

def test_shape_env_no_recording(self):
main = ShapeEnv(should_record_events=False)

# The main ShapeEnv should have no event recorded.
self.assertEqual(len(main.events), 0)

# Call create_symbolic_sizes_strides_storage_offset on both of them.
r = main.create_symbolic_sizes_strides_storage_offset(
torch.randn(3, 2), ConstantSource("x")
)

# Create a guard: size[0] == 3 (call evaluate_expr)
# - +1 guard entry
# - +1 replacement entry
size = r[0]
bool(size[0] == 3)

# The main ShapeEnv should remain with no event recorded.
self.assertEqual(len(main.events), 0)

if torch.fx.experimental.validator.translation_validation_enabled():
from torch.fx.experimental.symbolic_shapes import (
CURRENT_NODE_KEY,
SHAPEENV_EVENT_KEY,
)

# Check that we don't store any recording metadata on nodes
# from the symbolic shape FX graph.
for n in main.graph.nodes:
self.assertFalse(SHAPEENV_EVENT_KEY in n.meta)
self.assertFalse(CURRENT_NODE_KEY in n.meta)

def _replay_and_check(self, shape_env: ShapeEnv):
replayed = replay_shape_env_events(shape_env.events)
shape_env.check_equal(replayed)
if shape_env.should_record_events:
replayed = replay_shape_env_events(shape_env.events)
shape_env.check_equal(replayed)

@onlyIfTranslationValidation
def test_shape_env_equal_empty(self):
main, other = ShapeEnv(), ShapeEnv()
main.check_equal(other)
Expand Down
6 changes: 4 additions & 2 deletions torch/fx/experimental/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,10 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value)
env2_vars = vars(env2).copy()

for v in non_state_variable_names:
env1_vars.pop(v)
env2_vars.pop(v)
if v in env1_vars:
env1_vars.pop(v)
if v in env2_vars:
env2_vars.pop(v)

# Function for transforming the mismatched values into string.
# Needed, since dict and set entries order might not be the same every time.
Expand Down
6 changes: 4 additions & 2 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2604,8 +2604,10 @@ def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:

def add_fx_node_metadata(self, node: torch.fx.Node) -> None:
from torch._dynamo.utils import get_current_node
node.meta[SHAPEENV_EVENT_KEY] = self.last_event_index()
node.meta[CURRENT_NODE_KEY] = get_current_node()

if self.should_record_events:
node.meta[SHAPEENV_EVENT_KEY] = self.last_event_index()
node.meta[CURRENT_NODE_KEY] = get_current_node()

def _suppress_guards_tls(self):
return getattr(TLS, "suppress_guards", False)
Expand Down

0 comments on commit 26e8cc0

Please sign in to comment.