Skip to content

Commit

Permalink
[maskedtensor] negative testing (pytorch#85938)
Browse files Browse the repository at this point in the history
  • Loading branch information
george-qi authored and pytorchmergebot committed Sep 30, 2022
1 parent 0a7d8b4 commit b60ad2e
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 34 deletions.
154 changes: 129 additions & 25 deletions test/test_maskedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,14 @@ def _compare_mts(mt1, mt2):
if not _tensors_match(a, b, exact=False):
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")

def _create_random_mask(shape, device):
def _make_tensor_mask(shape, device):
return make_tensor(
shape, device=device, dtype=torch.bool, low=0, high=1, requires_grad=False
)

def _create_random_mask(shape, device):
return torch.randint(0, 2, shape, device=device).bool()

def _generate_sample_data(
device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
Expand All @@ -86,7 +89,7 @@ def _generate_sample_data(
inputs = []
for s in shapes:
data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type]
mask = _create_random_mask(s, device)
mask = _make_tensor_mask(s, device)
if layout == torch.sparse_coo:
mask = mask.to_sparse_coo().coalesce()
data = data.sparse_mask(mask).requires_grad_(requires_grad)
Expand All @@ -105,23 +108,63 @@ def _fix_fn_name(fn_name):


class TestBasics(TestCase):
def test_add(self):
data = torch.arange(5.0)
mask = torch.tensor([True, True, False, True, False])
def test_invalid_tensor_inputs(self, device):
    """Passing a non-Tensor (or an existing MaskedTensor) as data or mask raises TypeError."""
    data = torch.randn((3, 4), device=device)
    mask = _create_random_mask((3, 4), device=device)
    mt = masked_tensor(data, mask)

    # Same four invalid (data, mask) combinations as before, checked in order.
    invalid_inputs = [
        (mt, mask, "data must be a Tensor"),
        (0, mask, "data must be a Tensor"),
        (data, mt, "mask must be a Tensor"),
        (data, 0, "mask must be a Tensor"),
    ]
    for bad_data, bad_mask, msg in invalid_inputs:
        with self.assertRaisesRegex(TypeError, msg):
            masked_tensor(bad_data, bad_mask)

def test_diff_layouts(self, device):
    """A sparse COO data tensor paired with a strided mask must be rejected."""
    sparse_data = torch.randn((3, 4), device=device).to_sparse_coo()
    dense_mask = _create_random_mask((3, 4), device=device)
    expected = "data and mask must have the same layout"
    with self.assertRaisesRegex(TypeError, expected):
        masked_tensor(sparse_data, dense_mask)

def test_diff_dim(self, device):
    """data and mask with different numbers of dimensions must be rejected."""
    data_3d = torch.randn((3, 4, 5), device=device)
    mask_2d = _create_random_mask((3, 4), device=device)
    expected = "data.dim\\(\\) must equal mask.dim\\(\\)"
    with self.assertRaisesRegex(ValueError, expected):
        masked_tensor(data_3d, mask_2d)

def test_diff_sizes(self, device):
    """data and mask with matching dim but different sizes must be rejected."""
    data = torch.randn((3, 4), device=device)
    wrong_size_mask = _create_random_mask((3, 3), device=device)
    expected = "data.size\\(\\) must equal mask.size\\(\\)"
    with self.assertRaisesRegex(ValueError, expected):
        masked_tensor(data, wrong_size_mask)

def test_grad_warning(self, device):
    """Creating a MaskedTensor from a ``requires_grad`` tensor should warn.

    The returned MaskedTensor is intentionally discarded: only the
    UserWarning emitted during construction is under test, so the previous
    unused ``mt = ...`` assignment has been dropped.
    """
    data = torch.randn((3, 4), device=device, requires_grad=True)
    mask = _create_random_mask((3, 4), device=device)
    msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad."
    with self.assertWarnsRegex(UserWarning, msg):
        masked_tensor(data, mask)

def test_add(self, device):
    """Adding MaskedTensors: mismatched masks raise; matching masks add data.

    Fix: the diff residue left a duplicated ``_compare_mts`` line whose
    expected tensor was built without ``device=device``; only the
    device-aware version is kept.
    """
    data = torch.arange(5.0, device=device)
    mask = torch.tensor([True, True, False, True, False], device=device)
    m0 = masked_tensor(data, mask)
    m1 = masked_tensor(data, ~mask)
    # Adding tensors whose masks disagree is undefined and must fail loudly.
    with self.assertRaisesRegex(ValueError, "Input masks must match."):
        m0 + m1
    _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask))

def test_softmax(self):
data = torch.randn(3, 4) * 0.1
def test_softmax(self, device):
data = torch.randn((3, 4), device=device) * 0.1
mask = torch.tensor(
[
[True, True, True, False],
[False, True, False, True],
[True, True, False, False],
]
],
device=device
)
mt = masked_tensor(data, mask, requires_grad=True)
masked_res = torch.softmax(mt, -1)
Expand All @@ -133,8 +176,8 @@ def test_softmax(self):
_compare_mt_t(masked_res, tensor_res)
_compare_mt_t(mt.grad, xinf.grad)

def test_where(self):
data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100])
def test_where(self, device):
data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
mask = data < 0

mx = masked_tensor(data, mask, requires_grad=True)
Expand All @@ -143,16 +186,16 @@ def test_where(self):
masked_res.sum().backward()

x = data.detach().clone().requires_grad_()
y = torch.ones_like(x, requires_grad=True)
y = torch.ones_like(x, device=device, requires_grad=True)
tensor_res = torch.where(mask, torch.exp(x), y)
tensor_res.sum().backward()

_compare_mt_t(masked_res, tensor_res)
_compare_mt_t(mx.grad, x.grad)
_compare_mt_t(my.grad, y.grad)

def test_to_sparse(self):
for sample in _generate_sample_data():
def test_to_sparse(self, device):
for sample in _generate_sample_data(device=device):
data = sample.input
mask = sample.kwargs["mask"]
mt = masked_tensor(data.clone().detach(), mask, requires_grad=True)
Expand All @@ -164,10 +207,11 @@ def test_to_sparse(self):
_compare_mt_t(sparse_mt, data)
_compare_mt_t(mt.grad, data.grad)

def test_to_dense(self):
def test_to_dense(self, device):
samples = _generate_sample_data(
device=device,
layout=torch.sparse_coo
) + _generate_sample_data(layout=torch.sparse_csr)
) + _generate_sample_data(device=device, layout=torch.sparse_csr)
for sample in samples:
data = sample.input
mask = sample.kwargs["mask"]
Expand All @@ -181,8 +225,8 @@ def test_to_dense(self):
_compare_mt_t(dense_mt, dense_data)
_compare_mt_t(mt.grad.to_dense(), dense_data.grad)

def test_to_dense_and_sparse_coo(self):
for sample in _generate_sample_data(layout=torch.strided):
def test_to_dense_and_sparse_coo(self, device):
for sample in _generate_sample_data(device=device, layout=torch.strided):
data = sample.input
mask = sample.kwargs["mask"]
ms = mask.to_sparse_coo().coalesce()
Expand All @@ -199,8 +243,8 @@ def test_to_dense_and_sparse_coo(self):
_compare_mts(converted, converted2)
_compare_mts(mt.grad, mts.grad.to_dense())

def test_to_dense_and_sparse_csr(self):
for sample in _generate_sample_data(layout=torch.strided):
def test_to_dense_and_sparse_csr(self, device):
for sample in _generate_sample_data(device=device, layout=torch.strided):
data = sample.input
mask = sample.kwargs["mask"]
if data.ndim != 2:
Expand All @@ -219,8 +263,68 @@ def test_to_dense_and_sparse_csr(self):
_compare_mts(converted, converted2)
_compare_mts(mt.grad, mts.grad.to_dense())

def test_contiguous(self):
data = torch.randn(3, 3)
def test_invalid_sparse_layout(self, device):
    """Sparse CSC layout is not supported for MaskedTensor construction."""
    csc_data = torch.randn((3, 4), device=device).to_sparse_csc()
    csc_mask = _create_random_mask((3, 4), device=device).to_sparse_csc()
    expected = "data layout of torch.sparse_csc is not supported"
    with self.assertRaisesRegex(TypeError, expected):
        masked_tensor(csc_data, csc_mask)

def test_invalid_sparse_coo_values(self, device):
    """Sparse COO data and mask whose index sets differ must be rejected."""
    values = torch.tensor([3, 4, 5], dtype=torch.float32)
    data_indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    # Second column differs from data_indices: (1, 0) vs (1, 1).
    mask_indices = torch.tensor([[0, 1, 1], [2, 1, 2]])

    data = torch.sparse_coo_tensor(data_indices, values, (2, 4), device=device)
    mask = torch.sparse_coo_tensor(
        mask_indices, torch.tensor([True, True, True]), (2, 4), device=device
    )

    msg = "data and mask are both sparse COO tensors but do not have the same indices."
    with self.assertRaisesRegex(ValueError, msg):
        masked_tensor(data, mask)

def test_invalid_sparse_csr_values(self, device):
    """Sparse CSR data/mask that disagree on crow or col indices must be rejected.

    Fix: the CSR tensors are now built with ``device=device``. Previously the
    ``device`` parameter was ignored, so on a CUDA run this test silently
    exercised CPU tensors only.
    """
    crow_indices1 = [0, 2, 3]
    crow_indices2 = [0, 1, 3]
    col_indices1 = [0, 1, 2]
    col_indices2 = [1, 2, 3]

    values = [2, 3, 4]
    mask_values = [True, True, True]

    # Pair 1: data and mask differ only in their crow (row-pointer) indices.
    t1 = torch.sparse_csr_tensor(
        torch.tensor(crow_indices1, dtype=torch.int64),
        torch.tensor(col_indices1, dtype=torch.int64),
        torch.tensor(values),
        size=(2, 4),
        device=device,
    )
    mask1 = torch.sparse_csr_tensor(
        torch.tensor(crow_indices2, dtype=torch.int64),
        torch.tensor(col_indices1, dtype=torch.int64),
        torch.tensor(mask_values),
        dtype=torch.bool,
        size=(2, 4),
        device=device,
    )
    # Pair 2: data and mask differ only in their col indices.
    t2 = torch.sparse_csr_tensor(
        torch.tensor(crow_indices2, dtype=torch.int64),
        torch.tensor(col_indices1, dtype=torch.int64),
        torch.tensor(values),
        size=(2, 4),
        device=device,
    )
    mask2 = torch.sparse_csr_tensor(
        torch.tensor(crow_indices2, dtype=torch.int64),
        torch.tensor(col_indices2, dtype=torch.int64),
        torch.tensor(mask_values),
        dtype=torch.bool,
        size=(2, 4),
        device=device,
    )

    msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices."
    with self.assertRaisesRegex(ValueError, msg):
        masked_tensor(t1, mask1)
    with self.assertRaisesRegex(ValueError, msg):
        masked_tensor(t2, mask2)

def test_contiguous(self, device):
data = torch.randn((3, 3), device=device)

contiguous_data = data.clone()
mask1 = (contiguous_data > 0).bool()
Expand Down Expand Up @@ -699,7 +803,7 @@ def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
input = sample.input
sample_args, sample_kwargs = sample.args, sample.kwargs
mask = (
_create_random_mask(input.shape, device)
_make_tensor_mask(input.shape, device)
if "mask" not in sample_kwargs
else sample_kwargs.pop("mask")
)
Expand Down Expand Up @@ -745,7 +849,7 @@ def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
if input.dim() == 0 or input.numel() == 0:
continue

mask = _create_random_mask(input.shape, device)
mask = _make_tensor_mask(input.shape, device)

if torch.count_nonzero(mask) == 0:
continue
Expand Down Expand Up @@ -799,7 +903,7 @@ def test_reduction_all(self, device, dtype, op, layout):
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

instantiate_parametrized_tests(TestBasics)
instantiate_device_type_tests(TestBasics, globals(), only_for=only_for)
instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
instantiate_parametrized_tests(TestReductions)
Expand Down
6 changes: 2 additions & 4 deletions torch/masked/maskedtensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def _maybe_get_mask(a):
class MaskedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, mask, requires_grad=False):
if not torch.is_tensor(data):
if is_masked_tensor(data) or not torch.is_tensor(data):
raise TypeError("data must be a Tensor")
if not torch.is_tensor(mask):
if is_masked_tensor(mask) or not torch.is_tensor(mask):
raise TypeError("mask must be a Tensor")
# Use a Tensor of the given size for the wrapper.
kwargs = {}
Expand Down Expand Up @@ -208,8 +208,6 @@ def _validate_members(self):
raise ValueError("data.dim() must equal mask.dim()")
if data.size() != mask.size():
raise ValueError("data.size() must equal mask.size()")
if mask.requires_grad:
raise ValueError("mask cannot have requires_grad=True")

def __init__(self, data, mask, requires_grad=False):
self._preprocess_data(data, mask)
Expand Down
6 changes: 1 addition & 5 deletions torch/masked/maskedtensor/creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

from .core import MaskedTensor, is_masked_tensor
from .core import MaskedTensor

__all__ = [
"as_masked_tensor",
Expand All @@ -15,11 +15,7 @@
"""

def masked_tensor(data, mask, requires_grad=False):
    """Construct a MaskedTensor from ``data`` and a boolean ``mask``.

    Fix: the stale ``assert not is_masked_tensor(...)`` lines are removed —
    ``is_masked_tensor`` is no longer imported by this module, so they would
    raise NameError. Input validation (data/mask must be plain Tensors) is
    performed by ``MaskedTensor.__new__``, which raises TypeError instead.
    """
    return MaskedTensor(data, mask, requires_grad)

def as_masked_tensor(data, mask):
    """Construct a MaskedTensor from ``data`` and ``mask`` via ``MaskedTensor._from_values``."""
    return MaskedTensor._from_values(data, mask)

0 comments on commit b60ad2e

Please sign in to comment.