Skip to content

Commit

Permalink
[maskedtensor] port torch/_masked into torch/masked (pytorch#85515)
Browse files Browse the repository at this point in the history
  • Loading branch information
george-qi authored and pytorchmergebot committed Sep 26, 2022
1 parent 9026194 commit 686555b
Show file tree
Hide file tree
Showing 19 changed files with 182 additions and 146 deletions.
24 changes: 12 additions & 12 deletions functorch/op_analysis/public_api
Original file line number Diff line number Diff line change
Expand Up @@ -602,18 +602,18 @@ is_complex
is_floating_point
is_signed
sum_to_size
_masked.amax
_masked.amin
_masked.log_softmax
_masked.mean
_masked.norm
_masked.normalize
_masked.prod
_masked.softmax
_masked.softmin
_masked.std
_masked.sum
_masked.var
masked.amax
masked.amin
masked.log_softmax
masked.mean
masked.norm
masked.normalize
masked.prod
masked.softmax
masked.softmin
masked.std
masked.sum
masked.var
masked_fill
masked_scatter
masked_select
Expand Down
20 changes: 10 additions & 10 deletions functorch/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,14 @@ def f(inp, *args, **kwargs):
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
# Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed)
# The failure occurred for item [0]
xfail('_masked.prod')
xfail('masked.prod')
}))
@opsToleranceOverride('TestOperators', 'test_vjpvjp', (
tol1('nn.functional.conv_transpose3d',
{torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'),
tol1('prod',
{torch.float32: tol(atol=2e-05, rtol=1e-04)}),
tol1('_masked.cumprod',
tol1('masked.cumprod',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1('cumprod',
{torch.float32: tol(atol=5e-04, rtol=5e-04)}),
Expand Down Expand Up @@ -591,8 +591,8 @@ def fn(inp, *args, **kwargs):
xfail("take"), # vmap: inplace into a regular tensor
xfail("to"), # rank 4 tensor for channels_last
xfail("view_as_complex"), # RuntimeError: Tensor must have a last dimension with stride 1
xfail("_masked.softmax", device_type='cuda'), # Mismatch in values!
xfail("_masked.softmin", device_type='cuda'), # Mismatch in values!
xfail("masked.softmax", device_type='cuda'), # Mismatch in values!
xfail("masked.softmin", device_type='cuda'), # Mismatch in values!
# got a batched tensor as input while the running_mean or running_var,
# which will be updated in place, were not batched.
xfail("nn.functional.batch_norm", 'without_cudnn'),
Expand Down Expand Up @@ -763,7 +763,7 @@ def test_vmapvjp(self, device, dtype, op):
# ---------------------------- BUGS ------------------------------------
# The following are bugs that we should fix
skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda
xfail('_masked.mean'), # silent incorrectness (nan difference)
xfail('masked.mean'), # silent incorrectness (nan difference)

xfail('nn.functional.soft_margin_loss', ''), # soft_margin_loss_backward does not support forward-ad
xfail('tensor_split'), # data_ptr composite compliance
Expand Down Expand Up @@ -836,7 +836,7 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('masked_fill'),
xfail('copysign'),
xfail('complex'),
skip('_masked.mean'), # ???
skip('masked.mean'), # ???
xfail('masked_scatter'),
xfail('index_fill'),
xfail('put'),
Expand Down Expand Up @@ -874,7 +874,7 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('linalg.lu_solve', ''),
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
xfail('_masked.cumprod', ''),
xfail('masked.cumprod', ''),
xfail('linalg.vecdot', ''),
}))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
Expand Down Expand Up @@ -982,7 +982,7 @@ def test():
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
xfail('segment_reduce', 'offsets'),
xfail('_masked.cumprod', ''),
xfail('masked.cumprod', ''),
xfail('linalg.vecdot', ''),
xfail('segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
Expand Down Expand Up @@ -1148,9 +1148,9 @@ def get_vjp(cotangents, *primals):
xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce
}))
@opsToleranceOverride('TestOperators', 'test_jvpvjp', (
tol1('_masked.prod',
tol1('masked.prod',
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
tol1('_masked.cumprod',
tol1('masked.cumprod',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1('cumprod',
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'),
Expand Down
16 changes: 8 additions & 8 deletions test/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
args0 = args

# dimensions along which the reduction operation is applied:
dim_ = torch._masked._canonical_dim(dim, input.ndim)
dim_ = torch.masked._canonical_dim(dim, input.ndim)
# slices in product(*ranges) define all elementary slices:
ranges: List[Any] = []
# shape of output for the keepdim=True case:
Expand All @@ -116,7 +116,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
if mask is None:
inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
else:
inpmask = torch._masked._input_mask(input, mask=mask)
inpmask = torch.masked._input_mask(input, mask=mask)
for s in itertools.product(*ranges):
# data of an elementary slice is 1D sequence and has only
# masked-in elements:
Expand Down Expand Up @@ -148,7 +148,7 @@ def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
if mask is None:
inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
else:
inpmask = torch._masked._input_mask(input, mask=mask)
inpmask = torch.masked._input_mask(input, mask=mask)
dim_ = dim % input.ndim
left_ranges = tuple(map(range, input.shape[:dim_]))
right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
Expand All @@ -169,7 +169,7 @@ def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
)

masked_ops = [op for op in op_db if op.name.startswith('_masked.')]
masked_ops = [op for op in op_db if op.name.startswith('masked.')]
masked_ops_with_references = [op for op in masked_ops if op.name.rsplit('.', 1)[-1] in reference_functions]
masked_ops_with_non_strided_support = [op for op in masked_ops if op.supports_sparse or op.supports_sparse_csr]

Expand Down Expand Up @@ -287,7 +287,7 @@ def test_reference_masked(self, device, dtype, op):
if t_kwargs.get('mask') is None:
outmask = None
else:
outmask = torch._masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
self.assertEqualMasked(actual, expected, outmask)

@mask_layouts()
Expand All @@ -309,7 +309,7 @@ def test_mask_layout(self, layout, device, dtype, op, sample_inputs):
if r_kwargs.get('mask') is None:
outmask = None
else:
outmask = torch._masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
expected = op.op(r_inp, *r_args, **r_kwargs)
self.assertEqualMasked(actual, expected, outmask)

Expand Down Expand Up @@ -393,8 +393,8 @@ def set_values(sparse, index, value):
tmp = to_sparse(tmp)


sparse = torch._masked._where(mask, input,
torch.tensor(fill_value, dtype=input.dtype, device=input.device))
sparse = torch.masked._where(mask, input,
torch.tensor(fill_value, dtype=input.dtype, device=input.device))

if tmp.layout == torch.sparse_coo:
expected_sparse = torch.sparse_coo_tensor(
Expand Down
3 changes: 1 addition & 2 deletions test/test_maskedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
unary_ufuncs,
)

from torch._masked import _combine_input_and_mask
from torch.masked import as_masked_tensor, masked_tensor
from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
from torch.masked.maskedtensor.core import _masks_match, _tensors_match
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
Expand Down
42 changes: 21 additions & 21 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6428,16 +6428,16 @@ class TestConsistency(TestCase):
'__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'],
'__rpow__': ['f16'],
'__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'],
'_masked.argmax': ['i16', 'i64', 'u8'],
'_masked.argmin': ['i16', 'i64', 'u8'],
'_masked.log_softmax': ['f32'],
'_masked.logaddexp': ['f32'],
'_masked.norm': ['f16', 'f32'],
'_masked.normalize': ['f16', 'f32'],
'_masked.softmax': ['f32'],
'_masked.softmin': ['f32'],
'_masked.std': ['f32'],
'_masked.var': ['f32'],
'masked.argmax': ['i16', 'i64', 'u8'],
'masked.argmin': ['i16', 'i64', 'u8'],
'masked.log_softmax': ['f32'],
'masked.logaddexp': ['f32'],
'masked.norm': ['f16', 'f32'],
'masked.normalize': ['f16', 'f32'],
'masked.softmax': ['f32'],
'masked.softmin': ['f32'],
'masked.std': ['f32'],
'masked.var': ['f32'],
'abs': ['f16', 'f32', 'i16', 'i32', 'u8'],
'acos': ['f32', 'i16', 'i32', 'u8'],
'acosh': ['f32', 'i16', 'i32', 'u8'],
Expand Down Expand Up @@ -6666,12 +6666,12 @@ class TestConsistency(TestCase):
'__rdiv__': ['f16', 'f32'],
'__rmatmul__': ['f32'],
'__rmul__': ['f16', 'f32'],
'_masked.log_softmax': ['f32'],
'_masked.logaddexp': ['f32'],
'_masked.softmax': ['f32'],
'_masked.softmin': ['f32'],
'_masked.std': ['f32'],
'_masked.var': ['f32'],
'masked.log_softmax': ['f32'],
'masked.logaddexp': ['f32'],
'masked.softmax': ['f32'],
'masked.softmin': ['f32'],
'masked.std': ['f32'],
'masked.var': ['f32'],
'abs': ['f16', 'f32'],
'acos': ['f32'],
'acosh': ['f32'],
Expand Down Expand Up @@ -6837,9 +6837,9 @@ class TestConsistency(TestCase):
# Functions that hang
'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool],
# + forward when requires_grad=True or running backward
'_masked.mean': [torch.bool, torch.float16],
'_masked.prod': [torch.bool],
'_masked.sum': [torch.bool],
'masked.mean': [torch.bool, torch.float16],
'masked.prod': [torch.bool],
'masked.sum': [torch.bool],

# Functions that hard crash
'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
Expand All @@ -6851,8 +6851,8 @@ class TestConsistency(TestCase):
'index_select': [torch.float16],
'nn.functional.embedding': [torch.float32, torch.float16],
'__rpow__': [torch.int64],
'_masked.std': [torch.int32],
'_masked.var': [torch.int32],
'masked.std': [torch.int32],
'masked.var': [torch.int32],
'as_strided_scatter': [torch.uint8],
'atan2': [torch.int64],
'bfloat16': None,
Expand Down
36 changes: 18 additions & 18 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,25 +1008,25 @@ def f(a, b, c, d, e):
xfail('polar'),
xfail('linalg.eig'),
xfail('linalg.eigvals'),
skip('_masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
xfail('_masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition
xfail('_masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('_masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
xfail('_masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
xfail('_masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('_masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('_masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('_masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ...
xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition
xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition
xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d...
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition
Expand Down
2 changes: 1 addition & 1 deletion test/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3361,7 +3361,7 @@ def to_numpy(input):
expected = op.ref(to_numpy(t), *sample_input.args,
**dict(
# `identity` is mapped to numpy reduction `initial` argument
identity=torch._masked._reduction_identity(op.name, t),
identity=torch.masked._reduction_identity(op.name, t),
**sample_input.kwargs))

# Workaround https://github.com/pytorch/pytorch/issues/66556
Expand Down
6 changes: 3 additions & 3 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3985,12 +3985,12 @@ def test_future_empty_dim(self, device, dtype, op):
torch.testing._internal.common_methods_invocations._generate_reduction_kwargs
is made to generate samples with `dim=()` for non-scalar
inputs. With this and after gh-29137 is resolved, this test
can be deleted. See also `torch._masked._canonical_dim`
can be deleted. See also `torch.masked._canonical_dim`
implementation about changing the `dim=()` behavior.
"""

samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
op_name = op.name.replace('_masked.', '')
op_name = op.name.replace('masked.', '')
for sample_input in samples:
if sample_input.kwargs.get('dim') != 0:
continue
Expand All @@ -4002,7 +4002,7 @@ def test_future_empty_dim(self, device, dtype, op):
if mask is None and op_name in {'prod', 'amax', 'amin'}:
# FIXME: for now reductions with non-zero reduction identity and
# unspecified mask are not supported for sparse COO
# tensors, see torch._masked.prod implementation
# tensors, see torch.masked.prod implementation
# for details.
continue
sparse_op_kwargs = dict(sample_input_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions test/test_sparse_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,10 @@ def test_consistency(self, layout, device, dtype, op):

# FIXME: remove in followup once integer support is landed for segment_reduce
if (layout == torch.sparse_csr and not dtype.is_floating_point
and op.name in ('_masked.mean', '_masked.amax', '_masked.amin')):
and op.name in ('masked.mean', 'masked.amax', 'masked.amin')):
self.skipTest(f"{op.name} does not support input with {layout} layout")

require_mask = isinstance(op, ReductionOpInfo) and '_masked.' in op.name
require_mask = isinstance(op, ReductionOpInfo) and 'masked.' in op.name
if require_mask and layout in {torch.sparse_bsr, torch.sparse_bsc}:
self.skipTest(f"{op.name} does not support input with {layout} layout")

Expand Down Expand Up @@ -627,7 +627,7 @@ def test_consistency(self, layout, device, dtype, op):
assert torch.is_tensor(output)
strided_output = output.to_dense()
if require_mask:
output_mask = torch._masked._output_mask(op.op, sample.input, **sample.kwargs)
output_mask = torch.masked._output_mask(op.op, sample.input, **sample.kwargs)
expected.masked_fill_(~output_mask, 0)
self.assertEqual(strided_output, expected)
count += 1
Expand Down
12 changes: 6 additions & 6 deletions tools/update_masked_docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This script updates the file torch/_masked/_docs.py that contains
"""This script updates the file torch/masked/_docs.py that contains
the generated doc-strings for various masked operations. The update
should be triggered whenever a new masked operation is introduced to
torch._masked package. Running the script requires that torch package
torch.masked package. Running the script requires that torch package
is functional.
"""

Expand All @@ -10,7 +10,7 @@

def main() -> None:

target = os.path.join("torch", "_masked", "_docs.py")
target = os.path.join("torch", "masked", "_docs.py")

try:
import torch
Expand Down Expand Up @@ -40,9 +40,9 @@ def main() -> None:
"""
)

for func_name in sorted(torch._masked.__all__):
func = getattr(torch._masked, func_name)
func_doc = torch._masked._generate_docstring(func)
for func_name in sorted(torch.masked._ops.__all__):
func = getattr(torch.masked._ops, func_name)
func_doc = torch.masked._generate_docstring(func)
_new_content.append(f'{func_name}_docstring = """{func_doc}"""\n')

new_content = "\n".join(_new_content)
Expand Down
2 changes: 1 addition & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def compiled_with_cxx11_abi():
# Import experimental masked operations support. See
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
# information.
from . import _masked
from . import masked

# Import removed ops with error message about removal
from ._linalg_utils import ( # type: ignore[misc]
Expand Down
Loading

0 comments on commit 686555b

Please sign in to comment.