[reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (pytorch#88218) (pytorch#89221)

Summary: First half of pytorch#87990. This doesn't change any behavior; it is just a rename.

pytorch#88218 was reverted due to internal breakages. This is the reland, started from the internal diff.

Differential Revision: D41268423

LaMa Project: L1098534

Pull Request resolved: pytorch#89221
Approved by: https://github.com/meliy-meyada, https://github.com/zou3519
samdow authored and pytorchmergebot committed Jan 4, 2023
1 parent a5e2309 commit a7749ae
Showing 22 changed files with 42 additions and 41 deletions.
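The diffs below apply the rename mechanically: every call site that previously entered the `torch._C.DisableTorchFunction()` context manager now enters `torch._C.DisableTorchFunctionSubclass()`, with no behavioral change. As a minimal sketch of the post-rename usage pattern (the `MyTensor` subclass and the example values are illustrative only, not part of this commit; the context-manager pattern itself mirrors `torch/_tensor.py` and `torch/masked/maskedtensor/core.py` below):

```python
import torch

class MyTensor(torch.Tensor):
    """Illustrative subclass; mirrors the pattern used in torch/_tensor.py."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Disable subclass __torch_function__ handling while calling the real
        # op, so the call below does not re-enter this method recursively.
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

x = MyTensor(torch.randn(3))
print(torch.sum(x))  # dispatches through __torch_function__ exactly once
```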
2 changes: 1 addition & 1 deletion test/allowlist_for_publicAPI.json
@@ -1063,7 +1063,7 @@
"BFloat16Tensor",
"ComplexDoubleStorage",
"ComplexFloatStorage",
"DisableTorchFunction",
"DisableTorchFunctionSubclass",
"Generator",
"HalfStorage",
"HalfTensor",
2 changes: 1 addition & 1 deletion test/profiler/test_profiler_tree.py
@@ -26,7 +26,7 @@
"torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
"<built-in method __exit__ of torch._C.DisableTorchFunction object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
"<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
"cudaStreamIsCapturing": PRUNE_ALL,
"cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL,
}
4 changes: 2 additions & 2 deletions test/test_overrides.py
@@ -1447,7 +1447,7 @@ class B(torch.Tensor):

x = B(torch.randn(5))
with A():
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
self.assertNotIsInstance(torch.sum(x), B)

self.assertTrue(called)
@@ -1459,7 +1459,7 @@ class A(torch.Tensor):
pass

x = A(torch.randn(5))
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
g = torch._C._EnableTorchFunction()
try:
self.assertIsInstance(torch.sum(x), A)
2 changes: 1 addition & 1 deletion test/test_public_bindings.py
@@ -99,7 +99,7 @@ def test_no_new_bindings(self):
"device",
"DeviceObjType",
"DictType",
"DisableTorchFunction",
"DisableTorchFunctionSubclass",
"DispatchKey",
"DispatchKeySet",
"dtype",
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -114,7 +114,7 @@ class layout:
...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...
def DisableTorchFunctionSubclass(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -375,7 +375,7 @@ def sym_int(a):
if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type]
if (obj.__module__ != 'torch'):
# TODO: fix their module from C++ side
if name not in ['DisableTorchFunction', 'Generator']:
if name not in ['DisableTorchFunctionSubclass', 'Generator']:
obj.__module__ = 'torch'

if not TYPE_CHECKING:
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/builder.py
@@ -782,7 +782,7 @@ def _clone_input(value):
# The legacy behavior for real value cache with subclasses was
# to perform a clone WITHOUT preserving the subclass. It's
# not entirely clear this is what you actually want though.
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
proxy.tracer.real_value_cache[proxy.node] = _clone_input(
example_value
)
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/misc.py
@@ -613,7 +613,7 @@ def call_function(
options = VariableTracker.propagate(self, new_args, new_kwargs.values())
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
if isinstance(args[0], TorchVariable):
return wrap_fx_proxy(
tx=tx,
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
@@ -546,7 +546,7 @@ def inline_torch_function_unwrapped(

# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
return tx.inline_user_function_return(tf_func_var, tf_args, {})


2 changes: 1 addition & 1 deletion torch/_subclasses/fake_tensor.py
@@ -1159,5 +1159,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
memo[id(tensor)] = out
return out
else:
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
2 changes: 1 addition & 1 deletion torch/_subclasses/meta_utils.py
@@ -497,7 +497,7 @@ def __call__(
# by hand, e.g., as is done in Dynamo
ctx = contextlib.nullcontext()
if ignore_subclass:
ctx = torch._C.DisableTorchFunction()
ctx = torch._C.DisableTorchFunctionSubclass()
with ctx:
r = self.meta_tensor(
t, shape_env=shape_env, callback=callback, source=source
2 changes: 1 addition & 1 deletion torch/_tensor.py
@@ -1283,7 +1283,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):
return NotImplemented

with _C.DisableTorchFunction():
with _C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
4 changes: 2 additions & 2 deletions torch/autograd/profiler.py
@@ -503,7 +503,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
# TODO: Too slow with __torch_function__ handling enabled
# See https://github.com/pytorch/pytorch/issues/76410
if not torch.jit.is_scripting():
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
torch.ops.profiler._record_function_exit._RecordFunction(record)
else:
torch.ops.profiler._record_function_exit(record)
@@ -541,7 +541,7 @@ def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
# TODO: Too slow with __torch_function__ handling enabled
# See https://github.com/pytorch/pytorch/issues/76410
if not torch.jit.is_scripting():
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
record, fut)
else:
4 changes: 2 additions & 2 deletions torch/csrc/Module.cpp
@@ -1642,8 +1642,8 @@ Call this whenever a new thread is created in order to propagate values from
(PyObject*)THPDefaultCPUGenerator,
/* incref= */ false));
ASSERT_TRUE(set_module_attr(
"DisableTorchFunction",
(PyObject*)THPModule_DisableTorchFunctionType(),
"DisableTorchFunctionSubclass",
(PyObject*)THPModule_DisableTorchFunctionSubclassType(),
/* incref= */ false));
torch::set_disabled_torch_function_impl(
PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
1 change: 0 additions & 1 deletion torch/csrc/autograd/init.cpp
@@ -344,7 +344,6 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
_C_m, "_RestorePythonTLSSnapshot")
.def(py::init<>());

// TODO: line up this binding with DisableTorchFunction
py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
.def(py::init<>());
py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
32 changes: 17 additions & 15 deletions torch/csrc/utils/disable_torch_function.cpp
@@ -35,18 +35,20 @@ typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
bool old_state;
} DisableTorchFunction;
} DisableTorchFunctionSubclass;

PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
((DisableTorchFunction*)self)->old_state =
PyObject* DisableTorchFunctionSubclass__enter(
PyObject* self,
PyObject* unused) {
((DisableTorchFunctionSubclass*)self)->old_state =
at::impl::PythonTorchFunctionTLS::is_disabled();
at::impl::PythonTorchFunctionTLS::set_disabled(true);
Py_RETURN_NONE;
}

PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
at::impl::PythonTorchFunctionTLS::set_disabled(
((DisableTorchFunction*)self)->old_state);
((DisableTorchFunctionSubclass*)self)->old_state);
Py_RETURN_NONE;
}

@@ -58,16 +60,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) {
}
}

static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
{"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
{"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
{"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
{"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};

PyTypeObject DisableTorchFunctionType = {
PyTypeObject DisableTorchFunctionSubclassType = {
PyVarObject_HEAD_INIT(
nullptr,
0) "torch._C.DisableTorchFunction", /* tp_name */
sizeof(DisableTorchFunction), /* tp_basicsize */
0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
@@ -92,7 +94,7 @@ PyTypeObject DisableTorchFunctionType = {
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
DisableTorchFunction_methods, /* tp_methods */
DisableTorchFunctionSubclass_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
@@ -105,12 +107,12 @@ PyTypeObject DisableTorchFunctionType = {
PyType_GenericNew, /* tp_new */
};

PyObject* THPModule_DisableTorchFunctionType() {
if (PyType_Ready(&DisableTorchFunctionType) < 0) {
PyObject* THPModule_DisableTorchFunctionSubclassType() {
if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
return nullptr;
}

return (PyObject*)(&DisableTorchFunctionType);
return (PyObject*)(&DisableTorchFunctionSubclassType);
}

PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
2 changes: 1 addition & 1 deletion torch/csrc/utils/disable_torch_function.h
@@ -29,7 +29,7 @@ struct DisableTorchDispatch {
} // namespace torch

PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
PyObject* THPModule_DisableTorchFunctionType();
PyObject* THPModule_DisableTorchFunctionSubclassType();
PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);
PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args);
PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg);
4 changes: 2 additions & 2 deletions torch/distributed/_shard/common_op_utils.py
@@ -53,11 +53,11 @@ def tensor_default_op(types, args=(), kwargs=None, pg=None):
Handles ``__torch_function__`` dispatch for the default tensor ops that
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
``torch.Tensor.dtype``. We simply lower to the real op call with
DisableTorchFunction context like ``torch.Tensor.__torch_function__``
DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
to avoid recursions.
"""
if kwargs is None:
kwargs = {}

with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
return op(*args, **kwargs)
2 changes: 1 addition & 1 deletion torch/distributed/_shard/partial_tensor.py
@@ -236,7 +236,7 @@ def find_process_group(e):
# Need to disable all dispatch to print args and kwargs appropriately.
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
try:
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
raise RuntimeError(
f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for PartialTensor!")
4 changes: 2 additions & 2 deletions torch/distributed/_shard/replicated_tensor.py
@@ -109,7 +109,7 @@ def dispatch_arg(arg):
# We cann't do super().__torch_function__() as it implicitly convert the result
# back to tensor subclasses, where in our case, we need to control the output type
# base on the inter-op rules we defined.
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
rs = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return rs
@@ -157,7 +157,7 @@ def validate(self) -> bool:
return True

def __setstate__(self, state):
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
self.data = state
self.requires_grad = state.requires_grad
from torch.distributed._shard.api import _get_current_process_group
2 changes: 1 addition & 1 deletion torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -208,7 +208,7 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
local_shard.tensor.requires_grad_(requires_grad)

# update the wrapper class property
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
self_st.requires_grad_(requires_grad)
# update the metadata in the meanwhile
self_st._metadata.tensor_properties.requires_grad = requires_grad
2 changes: 1 addition & 1 deletion torch/masked/maskedtensor/core.py
@@ -270,7 +270,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

if not all(issubclass(cls, t) for t in types):
return NotImplemented
with torch._C.DisableTorchFunction():
with torch._C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
