Make segment_reduce properly private. (pytorch#93166)
I am intentionally not changing the aten function, to reduce the number of BC issues on the TorchScript side.

Pull Request resolved: pytorch#93166
Approved by: https://github.com/ngimel
albanD authored and pytorchmergebot committed Feb 6, 2023
1 parent 9b3277c commit 496c0a2
Showing 16 changed files with 47 additions and 33 deletions.
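
The user-visible effect: the reduction op itself is unchanged, but its Python binding is now reachable only under the private name. A minimal usage sketch (a hedged example, not taken from this commit; expected output follows the op's documented semantics):

import torch

data = torch.tensor([1., 2., 3., 4., 5.])
lengths = torch.tensor([2, 3])  # two segments: [1, 2] and [3, 4, 5]

# Before this commit: torch.segment_reduce(...); after it, only the
# private binding exists on the torch module:
out = torch._segment_reduce(data, "sum", lengths=lengths)
print(out)  # tensor([3., 12.])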
4 changes: 2 additions & 2 deletions test/distributed/_tensor/test_dtensor_ops.py
@@ -531,8 +531,8 @@ def wrapped(fn):
skip("masked.std"),
skip("masked.normalize"),
skip("prod"),
skip("segment_reduce", "lengths"),
skip("segment_reduce", "offsets"),
skip("_segment_reduce", "lengths"),
skip("_segment_reduce", "offsets"),

# TODO: fix the following ops
skip("squeeze"),
6 changes: 3 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -2227,7 +2227,7 @@ def forward(self, x):

# Worked with real but not with fake
xfail('cholesky_inverse'),
- xfail('segment_reduce', 'lengths'),
+ xfail('_segment_reduce', 'lengths'),
skip('nn.functional.nll_loss', ''), # UBSAN failure!

# Misc
@@ -2399,8 +2399,8 @@ def forward(self, x):
xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
xfail('repeat_interleave', ''), # aten.repeat_interleave.Te...
xfail('roll', ''), # narrow() received an invalid combination of arguments - got (FakeTensor, int, torch._C...
- xfail('segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
- xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
+ xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
+ xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
12 changes: 6 additions & 6 deletions test/functorch/test_ops.py
@@ -1176,8 +1176,8 @@ def test():
xfail('index_reduce', ''),
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
- xfail('segment_reduce', 'offsets'),
- xfail('segment_reduce', 'lengths'),
+ xfail('_segment_reduce', 'offsets'),
+ xfail('_segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
@@ -1349,9 +1349,9 @@ def get_vjp(cotangents, *primals):
xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
- xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce
+ xfail('_segment_reduce', 'offsets'), # NYI: forward-AD for _segment_reduce
xfail('index_reduce', ''), # NYI: forward-AD for index_reduce
- xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce
+ xfail('_segment_reduce', 'lengths'), # NYI: forward-AD for _segment_reduce
xfail('native_dropout_backward'), # NYI
}))
@@ -1502,8 +1502,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
xfail('quantile'), # Batching rule not implemented for aten::equal
xfail('renorm'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'prod'), # Forward AD not implemented and no decomposition
- xfail('segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
- xfail('segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
+ xfail('_segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
+ xfail('_segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
xfail('sparse.sampled_addmm'), # RuntimeError: Sparse CSR tensors do not have strides
xfail('svd_lowrank'), # calls random op
xfail('take'), # vmap: inplace into regular tensor
4 changes: 2 additions & 2 deletions test/functorch/test_vmap.py
@@ -3704,15 +3704,15 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('special.bessel_y0'),
xfail('special.chebyshev_polynomial_u'),
xfail('special.modified_bessel_k1'),
- xfail('segment_reduce', 'offsets'),
+ xfail('_segment_reduce', 'offsets'),
xfail('special.bessel_j1'),
xfail('index_reduce', ''),
xfail('special.laguerre_polynomial_l'),
xfail('special.hermite_polynomial_h'),
xfail('jiterator_binary', device_type='cuda'),
xfail('special.modified_bessel_i0'),
xfail('jiterator_4inputs_with_extra_args', device_type='cuda'),
- xfail('segment_reduce', 'lengths'),
+ xfail('_segment_reduce', 'lengths'),
xfail('lu_solve', ''),
xfail('special.bessel_y1'),
xfail('special.hermite_polynomial_he'),
4 changes: 2 additions & 2 deletions test/inductor/test_torchinductor_opinfo.py
@@ -248,7 +248,7 @@ def process(device_type):
"scatter_add": {f16},
"scatter_reduce.sum": {f16},
"scatter_reduce.prod": {f16, f32, f64},
"segment_reduce.lengths": {f16, f32, f64},
"_segment_reduce.lengths": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"stft": {f32, f64},
"tensor_split": {b8, f16, f32, f64, i32, i64},
@@ -317,7 +317,7 @@ def process(device_type):
"repeat_interleave": {b8, f16, f32, f64, i32, i64},
"round.decimals_3": {f16},
"scatter_reduce.prod": {f16, f32, f64},
"segment_reduce.lengths": {f16, f32, f64},
"_segment_reduce.lengths": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"std_mean.unbiased": {f16},
"stft": {f32, f64},
2 changes: 1 addition & 1 deletion test/test_meta.py
@@ -630,7 +630,7 @@ def run_meta_crossref(
torch.nn.functional.one_hot : {i64},
torch.nn.functional.pdist : {f64, f32},
torch.polar : {f64, f32},
- torch.segment_reduce : {f64, f16, bf16, f32},
+ torch._segment_reduce : {f64, f16, bf16, f32},
torch.searchsorted : {f64, i32, i64, f16, u8, i16, bf16, i8, f32},
torch.cholesky : {f64, f32, c128, c64},
torch.cholesky_inverse : {f64, f32, c128, c64},
4 changes: 2 additions & 2 deletions test/test_ops.py
@@ -1900,7 +1900,7 @@ def test_refs_are_in_decomp_table(self, op):
"to_sparse", # Could not run 'aten::to_sparse' with arguments from the 'Meta' backend
"tensor_split", # The tensor has a non-zero number of elements, but its data is not allocated yet
"repeat_interleave", # cannot repeat_interleave a meta tensor without output_size
"segment_reduce.lengths", # Could not run 'aten::segment_reduce' with arguments from the 'Meta' backend.
"_segment_reduce.lengths", # Could not run 'aten::segment_reduce' with arguments from the 'Meta' backend.
"sparse.sampled.addmm", # sparsity not supported
# Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException
"nn.functional.one_hot",
@@ -1984,7 +1984,7 @@ def test_refs_are_in_decomp_table(self, op):
}

fake_backward_xfails = {xfail(stride_skip) for stride_skip in fake_backward_xfails} | {
xfail("segment_reduce", "lengths"),
xfail("_segment_reduce", "lengths"),
xfail("norm", "nuc"),
xfail("linalg.norm", "subgradients_at_zero"), # can accept vector inputs
skip('nn.functional.ctc_loss'),
4 changes: 2 additions & 2 deletions test/test_proxy_tensor.py
@@ -1198,7 +1198,7 @@ def f(a, b, c, d, e):

fake_tensor_failures = {
# FakeTensor fallback doesn't work
- xfail('segment_reduce', 'lengths'),
+ xfail('_segment_reduce', 'lengths'),
xfail('multinomial'),
xfail('cholesky'),
xfail('cholesky_inverse'),
@@ -1352,7 +1352,7 @@ def f(a, b, c, d, e):
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('roll', ''), # Tensors of type TensorImpl do not have numel
xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
- xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
+ xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
14 changes: 7 additions & 7 deletions test/test_segment_reductions.py
@@ -75,7 +75,7 @@ def _test_common(
segment_reduce_kwargs['lengths'] = lengths
else:
segment_reduce_kwargs['offsets'] = offsets
- actual_result = torch.segment_reduce(
+ actual_result = torch._segment_reduce(
data=data,
reduce=reduction,
**segment_reduce_kwargs
@@ -108,7 +108,7 @@ def _test_common(
)
self.assertTrue(
gradcheck(
- lambda x: torch.segment_reduce(
+ lambda x: torch._segment_reduce(
data=x,
reduce=reduction,
**segment_reduce_kwargs
@@ -385,7 +385,7 @@ def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
lengths = torch.diff(indptr, dim=dim)
expected = torch.tensor(test[reduce], dtype=val_dtype, device=device)

- actual_result = torch.segment_reduce(
+ actual_result = torch._segment_reduce(
data=data,
reduce=reduce,
lengths=lengths,
@@ -395,7 +395,7 @@ def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
self.assertEqual(actual_result, expected)

# test offsets
- actual_result = torch.segment_reduce(
+ actual_result = torch._segment_reduce(
data=data,
reduce=reduce,
offsets=indptr,
@@ -419,7 +419,7 @@ def fn(x, mode='lengths'):
segment_reduce_kwargs[mode] = lengths
elif mode == 'offsets':
segment_reduce_kwargs[mode] = indptr
- return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
+ return torch._segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))

@@ -502,13 +502,13 @@ def test_unsafe_flag(self, device, dtype):

# test for error on 1-D lengths
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
- torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
+ torch._segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

# test for error on multi-D lengths
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
- torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
+ torch._segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)



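As the tests above exercise, lengths and offsets are two encodings of the same segmentation (lengths are the successive differences of offsets), so both call forms agree. A small hedged sketch with illustrative values:

import torch

data = torch.arange(6, dtype=torch.float)   # [0., 1., 2., 3., 4., 5.]
offsets = torch.tensor([0, 2, 5, 6])        # segment boundaries
lengths = torch.diff(offsets)               # tensor([2, 3, 1])

by_lengths = torch._segment_reduce(data, "sum", lengths=lengths)
by_offsets = torch._segment_reduce(data, "sum", offsets=offsets)
assert torch.equal(by_lengths, by_offsets)  # both tensor([1., 9., 5.])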
6 changes: 6 additions & 0 deletions torch/__init__.py
@@ -1154,6 +1154,9 @@ def manager_path():
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
+ # Fixup segment_reduce visibility
+ _segment_reduce = segment_reduce
+ del segment_reduce

# Ops not to be exposed in `torch` namespace,
# mostly helper ops.
@@ -1166,6 +1169,9 @@ def manager_path():
continue
obj = getattr(_C._VariableFunctions, name)
obj.__module__ = 'torch'
+ # Hide some APIs that should not be public
+ if name == "segment_reduce":
+     name = "_" + name
globals()[name] = obj
if not name.startswith("_"):
__all__.append(name)
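A self-contained sketch of the hiding pattern added above. _FakeVariableFunctions is a hypothetical stand-in for torch._C._VariableFunctions; the loop mirrors the diff's logic:

class _FakeVariableFunctions:  # stand-in for the C binding module
    @staticmethod
    def segment_reduce(data, reduce="max", **kwargs):
        return "reduced"

namespace = {}
__all__ = []
for name in ["segment_reduce"]:
    obj = getattr(_FakeVariableFunctions, name)
    # Hide APIs that should not be public: re-bind under a leading underscore.
    if name == "segment_reduce":
        name = "_" + name
    namespace[name] = obj
    if not name.startswith("_"):
        __all__.append(name)

assert "_segment_reduce" in namespace
assert "segment_reduce" not in namespace
assert "_segment_reduce" not in __all__  # private names stay out of __all__

The leading underscore keeps the name out of __all__, so star-imports and the public API surface no longer see it, while the underlying binding is untouched.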
3 changes: 3 additions & 0 deletions torch/fx/node.py
@@ -73,6 +73,9 @@ def _get_qualified_name(func: Callable[..., Any]) -> str:
name = func.__name__
module = _find_module_of_method(func)
module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module
+ # Fixup segment_reduce mismatch
+ if module == "torch" and name == "segment_reduce":
+     name = "_" + name
return f'{module}.{name}'

def _format_arg(arg, max_list_len=float('inf')) -> str:
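Why the FX patch is needed: the C-level function's __name__ is still "segment_reduce", but the attribute torch now exposes is "_segment_reduce", so an unpatched qualified name would not resolve when an FX graph is printed or re-imported. A hedged sketch of the logic (the function name here is hypothetical):

def _get_qualified_name_sketch(func):
    name = func.__name__      # still "segment_reduce" at the binding level
    module = func.__module__  # e.g. "torch"
    # Point at the attribute torch actually exposes.
    if module == "torch" and name == "segment_reduce":
        name = "_" + name
    return f"{module}.{name}"  # -> "torch._segment_reduce"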
3 changes: 3 additions & 0 deletions torch/jit/_builtins.py
@@ -135,6 +135,9 @@ def register_all(mod):
for name in dir(mod):
v = getattr(mod, name)
if callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad and v is not torch.autocast:
+ # Fixup inconsistency in segment_reduce
+ if name == "_segment_reduce":
+     name = name[1:]
_builtin_ops.append((v, "aten::" + name))
for mod in _modules_containing_builtins:
register_all(mod)
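This hunk is the inverse mapping: TorchScript registers builtins under their aten schema names, and the aten op deliberately keeps its old name, so the Python-side underscore is stripped before the schema string is formed. A minimal sketch (register_all_sketch is a hypothetical stand-in):

_builtin_ops = []

def register_all_sketch(mod):
    for name in dir(mod):
        v = getattr(mod, name)
        if callable(v):
            # The Python name is private, but the aten op is not.
            if name == "_segment_reduce":
                name = name[1:]  # "segment_reduce"
            _builtin_ops.append((v, "aten::" + name))

This keeps existing TorchScript programs that reference aten::segment_reduce working, which is the BC goal stated in the commit message.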
2 changes: 1 addition & 1 deletion torch/masked/_ops.py
@@ -783,7 +783,7 @@ def _sparse_csr_segment_reduction_helper(
)
new_nnz = new_crow_indices[-1]
new_col_indices = col_indices.new_zeros(new_nnz)
- new_values = torch.segment_reduce(values, reduce, offsets=crow_indices)
+ new_values = torch._segment_reduce(values, reduce, offsets=crow_indices)  # type: ignore[attr-defined]
new_shape = [mask_input.size(0), 1]
else:
assert len(dims) == 2
4 changes: 2 additions & 2 deletions torch/overrides.py
@@ -971,7 +971,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
- torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
+ torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@@ -1614,7 +1614,7 @@ def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callab
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
("torch", torch, torch.__all__ + dir(torch._C._VariableFunctions)),
("torch", torch, torch.__all__),
("torch.functional", torch.functional, torch.functional.__all__),
("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
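The second hunk above matters because dir(torch._C._VariableFunctions) still reports "segment_reduce", which no longer exists as an attribute of torch; iterating torch.__all__ instead visits only names torch actually exposes. A short, hedged sketch of the failure mode this avoids:

import torch

for name in torch.__all__:
    func = getattr(torch, name)  # safe: __all__ lists only real attributes

# The old listing could include stale names from the C binding module:
# getattr(torch, "segment_reduce")  # raises AttributeError after this commit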
6 changes: 4 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -17493,7 +17493,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
sample_inputs_func=sample_inputs_scatter_reduce,
),
OpInfo(
- 'segment_reduce',
+ '_segment_reduce',
+ aten_name='segment_reduce',
variant_test_name='lengths',
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
@@ -17512,7 +17513,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
),
),
OpInfo(
- 'segment_reduce',
+ '_segment_reduce',
+ aten_name='segment_reduce',
variant_test_name='offsets',
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
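The OpInfo entries now split the lookup name from the schema name: name ("_segment_reduce") is how the test suite resolves the Python callable, while the new aten_name field keeps pointing at the unchanged aten::segment_reduce schema. A hedged sketch of that relationship (OpInfoSketch is a simplified stand-in, not the real OpInfo class):

from dataclasses import dataclass

@dataclass
class OpInfoSketch:
    name: str               # Python attribute looked up on torch
    aten_name: str          # aten schema name, unchanged for BC
    variant_test_name: str = ""

lengths_variant = OpInfoSketch("_segment_reduce", "segment_reduce", "lengths")
offsets_variant = OpInfoSketch("_segment_reduce", "segment_reduce", "offsets")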
2 changes: 1 addition & 1 deletion torchgen/static_runtime/generator.py
@@ -169,7 +169,7 @@ def has_alias(
"_test_warn_in_autograd",
"_test_autograd_multiple_dispatch_view",
"_test_autograd_multiple_dispatch_view_copy",
"segment_reduce",
"_segment_reduce",
"_segment_reduce_backward",
"_fw_primal_copy",
"_make_dual_copy",
