Skip to content

Commit

Permalink
Re-land: Break graph on manual_seed. (pytorch#108647)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored and pytorchmergebot committed Sep 7, 2023
1 parent 9f37aec commit c887309
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 26 deletions.
9 changes: 9 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,15 @@ def test_numpy_random():
x = np.random.randn(2, 2)
return x - x

def test_manual_seed(self):
@torch.compile
def foo():
torch.manual_seed(3)
return torch.randint(0, 5, (5,))

self.assertEqual(foo(), foo())
self.assertEqual(foo(), foo())


def global_func_with_default_tensor_args(
x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
Expand Down
16 changes: 6 additions & 10 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,14 +1814,6 @@ def fn(x, obj):
res = opt_fn(x, obj)
self.assertTrue(same(ref, res))

def test_manual_seed(self):
def fn(a, b):
x = a + b
torch.manual_seed(9000)
return x + 1

torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)

def test_usr_cls_staticmethod(self):
class Foo:
@staticmethod
Expand Down Expand Up @@ -2289,13 +2281,17 @@ def fn(x):
torch.manual_seed(attention_seed)
return (x,)

x = torch.randn(100, requires_grad=True)
x = torch.randn(10, requires_grad=True)
ref = fn(x)

opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
# Python code is needed here, since torch.manual_seed graph-breaks.
# Refs: https://github.com/pytorch/pytorch/issues/107187
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
res = opt_fn(x)

self.assertTrue(same(ref, res))
self.assertEqual(cnts.op_count, 1)
self.assertEqual(cnts.frame_count, 1)

def test_is_tensor_like(self):
cnts = torch._dynamo.testing.CompileCounter()
Expand Down
12 changes: 4 additions & 8 deletions test/inductor/test_torchinductor_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def format_op(op):
"nn.functional.instance_norm": {f16},
"nn.functional.local_response_norm": {f16},
"nn.functional.normalize": {f16},
"nn.functional.rrelu": {f16, f32, f64},
"nn.functional.soft_margin_loss": {f16},
"nn.functional.softsign": {f16},
"nn.functional.triplet_margin_loss": {f16},
Expand All @@ -278,7 +277,6 @@ def format_op(op):
"sparse.sampled_addmm": {f32, f64},
("std_mean", "unbiased"): {f16},
"to_sparse": {f16, f32, f64},
"uniform": {f16, f32, f64},
}


Expand Down Expand Up @@ -342,15 +340,13 @@ def get_skips_and_xfails(from_dict, xfails=True):
)


def wrapper_set_seed(op, *args, **kwargs):
"""Wrapper to set seed manually for some functions like dropout
See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
"""
torch.manual_seed(42)
def wrapper_noop_set_seed(op, *args, **kwargs):
    """Invoke *op* directly, deliberately skipping any manual seeding.

    Replacement for the common_methods_invocations ``wrapper_set_seed``
    hook: forwards all positional and keyword arguments unchanged and
    returns whatever *op* returns.
    """
    result = op(*args, **kwargs)
    return result


torch.testing._internal.common_methods_invocations.wrapper_set_seed = wrapper_set_seed
torch.testing._internal.common_methods_invocations.wrapper_set_seed = (
wrapper_noop_set_seed
)

# This file does a global patch to `disable_global_flags()` - which we should not invoke in non testing cases.
torch._dynamo.variables.torch.tensor_dunder_fns.append(
Expand Down
18 changes: 11 additions & 7 deletions torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,17 @@ def _dtype(self):
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes: Set[Type] = set()

################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################

# needs to be before from .functional import * to avoid circular dependencies
from ._compile import _disable_dynamo

################################################################################
# Import miscellaneous torch functions
################################################################################

# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from .serialization import save, load
Expand Down Expand Up @@ -1381,13 +1392,6 @@ def manager_path():



################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################

# needs to be before from .functional import * to avoid circular dependencies
from ._compile import _disable_dynamo

################################################################################
# Import interface functions defined in Python
################################################################################
Expand Down
13 changes: 12 additions & 1 deletion torch/_dynamo/exc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import textwrap
import warnings
from enum import auto, Enum
from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import cast, Optional
Expand All @@ -12,7 +13,17 @@
from .utils import counters

if is_fbcode():
from torch.fb.exportdb.logging import exportdb_error_message
try:
from torch.fb.exportdb.logging import exportdb_error_message
except ModuleNotFoundError as err:
warnings.warn(
f"is_fbcode() is True, but could not import exportdb_error_message function: {err}. "
"Creating dummy replacement for it."
)

def exportdb_error_message(case_name):
return ""

else:

def exportdb_error_message(case_name):
Expand Down
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ def call_function(
elif self.value is torch.nn.Parameter:
# https://github.com/pytorch/pytorch/issues/99569
unimplemented("torch.nn.Parameter not supported")
elif self.value is torch.manual_seed:
# https://github.com/pytorch/pytorch/issues/107187
unimplemented("torch.manual_seed not supported")
if (
self.value.__name__ == "get_state"
and hasattr(self.value, "__self__")
Expand Down
1 change: 1 addition & 0 deletions torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_rng_state() -> torch.Tensor:
return default_generator.get_state()


@torch._disable_dynamo
def manual_seed(seed) -> torch._C.Generator:
r"""Sets the seed for generating random numbers. Returns a
`torch.Generator` object.
Expand Down

0 comments on commit c887309

Please sign in to comment.