[dynamo] Don't put nn module guards on torch inbuilt nn modules (pytorch#110230)

This is one way to fix pytorch#110048.

Looking for feedback.

Pull Request resolved: pytorch#110230
Approved by: https://github.com/ezyang
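
For context on what counts as an "inbuilt" nn module here: the check added below uses is_allowed(val.__class__) from torch._dynamo.allowed_functions, which is how Dynamo classifies the torch-provided module classes (e.g. torch.nn.Linear) that it does not trace into. A minimal sketch of that classification, using a hypothetical helper rather than Dynamo's actual is_allowed:

import torch


def looks_like_inbuilt_nn_module(cls: type) -> bool:
    # Hypothetical stand-in for the is_allowed(val.__class__) check in this PR:
    # treat any nn.Module subclass defined under torch.nn as "inbuilt".
    return issubclass(cls, torch.nn.Module) and cls.__module__.startswith("torch.nn.")


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


print(looks_like_inbuilt_nn_module(torch.nn.Linear))  # True  -> skip the nn module guard
print(looks_like_inbuilt_nn_module(MyModule))         # False -> keep the guard, so user-module changes still recompile
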
anijain2305 authored and pytorchmergebot committed Sep 29, 2023
1 parent 20dabea commit ce8b4f5
Showing 2 changed files with 36 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/dynamo/test_modules.py
@@ -2299,6 +2299,35 @@ def test_unspec_non_inlinable_module(self):
        expected = mod(x)
        self.assertEqual(actual, expected)

    def test_no_guard_on_torch_nn_modules(self):
        # https://github.com/pytorch/pytorch/issues/110048

        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

        mod = MockModule()

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def generate(x, c):
            return mod(x) + c

        for _ in range(0, 10):
            generate(torch.randn(10, 10), 0)
            generate(torch.randn(10, 10), 1)
        self.assertEqual(cnt.frame_count, 2)

        # Ensure that modification in user module causes recompile
        mod.eval()
        generate(torch.randn(10, 10), 0)
        self.assertEqual(cnt.frame_count, 3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
7 changes: 7 additions & 0 deletions torch/_dynamo/guards.py
@@ -20,6 +20,8 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from weakref import ReferenceType

from .allowed_functions import is_allowed

try:
    import numpy as np
except ModuleNotFoundError:
@@ -468,6 +470,11 @@ def NN_MODULE(self, guard: Guard):
            # leading to recompilations.
            log.warning("Skipping nn module guard on LSTMs")
            return

        # Dynamo does not trace inside the inbuilt torch nn modules. Skip
        # guarding on those. More rationale at https://github.com/pytorch/pytorch/issues/110048
        if is_allowed(val.__class__):
            return
        try:
            g = torch._C._dynamo.guards.nn_module_guard(val)
        except AttributeError:
