Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (p…
Browse files Browse the repository at this point in the history
…ytorch#88218)"

This reverts commit 7f28be1.

Reverted pytorch#88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901
pytorchmergebot committed Nov 11, 2022
1 parent 4e5d7af · commit ba4d5aa
Showing 20 changed files with 38 additions and 39 deletions.
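Most of the Python call sites touched below share one pattern: inside a __torch_function__ override, the context manager is entered so the underlying op can be called without re-entering the override. A minimal sketch of that usage, with the torch._C.DisableTorchFunction name this commit restores (the LoggingTensor class is purely illustrative and not part of the diff):

import torch

class LoggingTensor(torch.Tensor):
    # Illustrative subclass, not part of this commit.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"intercepted {func.__name__}")
        # Enter the restored context manager so the call below reaches the
        # real op instead of recursing back into this override.
        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)

x = LoggingTensor(torch.randn(3))
y = torch.sum(x)  # prints "intercepted sum", then runs the real torch.sum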
2 changes: 1 addition & 1 deletion test/allowlist_for_publicAPI.json
@@ -1128,7 +1128,7 @@
"BFloat16Tensor",
"ComplexDoubleStorage",
"ComplexFloatStorage",
"DisableTorchFunctionSubclass",
"DisableTorchFunction",
"Generator",
"HalfStorage",
"HalfTensor",
2 changes: 1 addition & 1 deletion test/profiler/test_profiler_tree.py
@@ -26,7 +26,7 @@
"torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
"torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
"<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
"<built-in method __exit__ of torch._C.DisableTorchFunction object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
"cudaStreamIsCapturing": PRUNE_ALL,
"cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL,
}
4 changes: 2 additions & 2 deletions test/test_overrides.py
@@ -1448,7 +1448,7 @@ class B(torch.Tensor):

x = B(torch.randn(5))
with A():
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
self.assertNotIsInstance(torch.sum(x), B)

self.assertTrue(called)
@@ -1460,7 +1460,7 @@ class A(torch.Tensor):
pass

x = A(torch.randn(5))
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
g = torch._C._EnableTorchFunction()
try:
self.assertIsInstance(torch.sum(x), A)
2 changes: 1 addition & 1 deletion test/test_public_bindings.py
@@ -99,7 +99,7 @@ def test_no_new_bindings(self):
"device",
"DeviceObjType",
"DictType",
"DisableTorchFunctionSubclass",
"DisableTorchFunction",
"DispatchKey",
"DispatchKeySet",
"dtype",
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
@@ -108,7 +108,7 @@ class layout:
...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunctionSubclass(): ...
def DisableTorchFunction(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -315,7 +315,7 @@ def get_pyobj(self):
if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type]
if (obj.__module__ != 'torch'):
# TODO: fix their module from C++ side
if name not in ['DisableTorchFunctionSubclass', 'Generator']:
if name not in ['DisableTorchFunction', 'Generator']:
obj.__module__ = 'torch'

if not TYPE_CHECKING:
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/builder.py
@@ -506,7 +506,7 @@ def wrap_tensor(self, value: torch.Tensor):
)
# Disable __torch_function__ to prevent cloning of `value` to hit
# us
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
if is_constant_source(self.get_source()):
return self.tx.output.register_attr_or_module(
value,
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/misc.py
@@ -538,7 +538,7 @@ def call_function(
options = VariableTracker.propagate(self, new_args, new_kwargs.values())
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
if isinstance(args[0], TorchVariable):
return TensorVariable.create(
tx=tx,
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/tensor.py
@@ -743,7 +743,7 @@ def inline_torch_function_unwrapped(

# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
return tx.inline_user_function_return(tf_func_var, tf_args, {})


2 changes: 1 addition & 1 deletion torch/_subclasses/fake_tensor.py
@@ -1093,5 +1093,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
memo[id(tensor)] = out
return out
else:
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
return func(*args, **kwargs)
2 changes: 1 addition & 1 deletion torch/_tensor.py
@@ -1297,7 +1297,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):
return NotImplemented

with _C.DisableTorchFunctionSubclass():
with _C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
4 changes: 2 additions & 2 deletions torch/csrc/Module.cpp
@@ -1594,8 +1594,8 @@ Call this whenever a new thread is created in order to propagate values from
(PyObject*)THPDefaultCPUGenerator,
/* incref= */ false));
ASSERT_TRUE(set_module_attr(
"DisableTorchFunctionSubclass",
(PyObject*)THPModule_DisableTorchFunctionSubclassType(),
"DisableTorchFunction",
(PyObject*)THPModule_DisableTorchFunctionType(),
/* incref= */ false));
torch::set_disabled_torch_function_impl(
PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
1 change: 1 addition & 0 deletions torch/csrc/autograd/init.cpp
@@ -343,6 +343,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
_C_m, "_RestorePythonTLSSnapshot")
.def(py::init<>());

// TODO: line up this binding with DisableTorchFunction
py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch")
.def(py::init<>());
py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
32 changes: 15 additions & 17 deletions torch/csrc/utils/disable_torch_function.cpp
@@ -35,20 +35,18 @@ typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
bool old_state;
} DisableTorchFunctionSubclass;
} DisableTorchFunction;

PyObject* DisableTorchFunctionSubclass__enter(
PyObject* self,
PyObject* unused) {
((DisableTorchFunctionSubclass*)self)->old_state =
PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
((DisableTorchFunction*)self)->old_state =
at::impl::PythonTorchFunctionTLS::is_disabled();
at::impl::PythonTorchFunctionTLS::set_disabled(true);
Py_RETURN_NONE;
}

PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
at::impl::PythonTorchFunctionTLS::set_disabled(
((DisableTorchFunctionSubclass*)self)->old_state);
((DisableTorchFunction*)self)->old_state);
Py_RETURN_NONE;
}

@@ -60,16 +58,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) {
}
}

static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
{"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
{"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
{"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
{"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};

PyTypeObject DisableTorchFunctionSubclassType = {
PyTypeObject DisableTorchFunctionType = {
PyVarObject_HEAD_INIT(
nullptr,
0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
0) "torch._C.DisableTorchFunction", /* tp_name */
sizeof(DisableTorchFunction), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
@@ -94,7 +92,7 @@ PyTypeObject DisableTorchFunctionSubclassType = {
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
DisableTorchFunctionSubclass_methods, /* tp_methods */
DisableTorchFunction_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
@@ -107,12 +105,12 @@ PyTypeObject DisableTorchFunctionSubclassType = {
PyType_GenericNew, /* tp_new */
};

PyObject* THPModule_DisableTorchFunctionSubclassType() {
if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
PyObject* THPModule_DisableTorchFunctionType() {
if (PyType_Ready(&DisableTorchFunctionType) < 0) {
return nullptr;
}

return (PyObject*)(&DisableTorchFunctionSubclassType);
return (PyObject*)(&DisableTorchFunctionType);
}

PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
2 changes: 1 addition & 1 deletion torch/csrc/utils/disable_torch_function.h
@@ -29,7 +29,7 @@ struct DisableTorchDispatch {
} // namespace torch

PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
PyObject* THPModule_DisableTorchFunctionSubclassType();
PyObject* THPModule_DisableTorchFunctionType();
PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);
PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args);
PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg);
4 changes: 2 additions & 2 deletions torch/distributed/_shard/common_op_utils.py
@@ -53,11 +53,11 @@ def tensor_default_op(types, args=(), kwargs=None, pg=None):
Handles ``__torch_function__`` dispatch for the default tensor ops that
behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
``torch.Tensor.dtype``. We simply lower to the real op call with
DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__``
DisableTorchFunction context like ``torch.Tensor.__torch_function__``
to avoid recursions.
"""
if kwargs is None:
kwargs = {}

with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
return op(*args, **kwargs)
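The docstring above spells out why the handler needs the context manager: it lowers straight to the real op with __torch_function__ dispatch turned off, so default ops such as torch.Tensor.shape or torch.Tensor.dtype cannot recurse back into the wrapper's override. A rough, self-contained sketch of that handler shape (the name default_op_handler and the explicit op argument are illustrative; the actual _shard registration machinery is not shown):

import torch

def default_op_handler(op, types, args=(), kwargs=None, pg=None):
    # Hypothetical stand-in for tensor_default_op above: run the real op
    # with __torch_function__ dispatch disabled so default tensor ops do
    # not recurse into the wrapper subclass's override.
    if kwargs is None:
        kwargs = {}
    with torch._C.DisableTorchFunction():
        return op(*args, **kwargs)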
2 changes: 1 addition & 1 deletion torch/distributed/_shard/partial_tensor.py
@@ -236,7 +236,7 @@ def find_process_group(e):
# Need to disable all dispatch to print args and kwargs appropriately.
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
try:
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
raise RuntimeError(
f"torch function '{func.__name__}', with args: {args} and "
f"kwargs: {kwargs} not supported for PartialTensor!")
4 changes: 2 additions & 2 deletions torch/distributed/_shard/replicated_tensor.py
@@ -109,7 +109,7 @@ def dispatch_arg(arg):
# We cann't do super().__torch_function__() as it implicitly convert the result
# back to tensor subclasses, where in our case, we need to control the output type
# base on the inter-op rules we defined.
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
rs = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return rs
@@ -157,7 +157,7 @@ def validate(self) -> bool:
return True

def __setstate__(self, state):
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
self.data = state
self.requires_grad = state.requires_grad
from torch.distributed._shard.api import _get_current_process_group
2 changes: 1 addition & 1 deletion torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -203,7 +203,7 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
local_shard.tensor.requires_grad_(requires_grad)

# update the wrapper class property
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
self_st.requires_grad_(requires_grad)
# update the metadata in the meanwhile
self_st._metadata.tensor_properties.requires_grad = requires_grad
2 changes: 1 addition & 1 deletion torch/masked/maskedtensor/core.py
@@ -270,7 +270,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

if not all(issubclass(cls, t) for t in types):
return NotImplemented
with torch._C.DisableTorchFunctionSubclass():
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret