[maskedtensor] adding reductions (pytorch#82839)
george-qi authored and pytorchmergebot committed Sep 6, 2022
1 parent f125bd2 commit 5e9c26c
Showing 11 changed files with 882 additions and 136 deletions.
332 changes: 325 additions & 7 deletions test/test_maskedtensor.py
@@ -12,10 +12,9 @@
SampleInput,
)

-from torch.masked import masked_tensor
+from torch.masked import MaskedTensor, masked_bmm
from torch.masked.maskedtensor.core import _masks_match, _tensors_match
from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS

from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS


@@ -126,7 +125,7 @@ def _get_sample_kwargs(self, fn_name):

def _get_sample_args(self, fn_name, data, mask):
fn_name = _fix_fn_name(fn_name)
-mt = masked_tensor(data, mask)
+mt = MaskedTensor(data, mask)
t_args = [data]
mt_args = [mt]
if fn_name in ["pow"]:
@@ -185,8 +184,8 @@ def _yield_sample_args(self, fn_name, data0, data1, mask):
while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
"""
fn_name = _fix_fn_name(fn_name)
-mt0 = masked_tensor(data0, mask)
-mt1 = masked_tensor(data1, mask)
+mt0 = MaskedTensor(data0, mask)
+mt1 = MaskedTensor(data1, mask)

t_args = [data0, data1]
mt_args = [mt0, mt1]
@@ -227,8 +226,8 @@ def test_masks_match(self, fn_name):
data0, data1, mask = self._get_test_data(fn_name)
mask0 = mask
mask1 = torch.rand(mask.size()) > 0.5
-mt0 = masked_tensor(data0, mask0)
-mt1 = masked_tensor(data1, mask1)
+mt0 = MaskedTensor(data0, mask0)
+mt1 = MaskedTensor(data1, mask1)
try:
fn(mt0, mt1)
raise AssertionError()
@@ -238,8 +237,327 @@
== str(e)
)

class TestReductions(TestCase):
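# Reductions over a MaskedTensor only consider elements whose mask is True;
# an output element whose contributing inputs are all masked out gets mask False.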
def test_max_not_implemented(self):
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m)
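# max is not among the implemented reductions, so dispatch fails with a TypeError.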
with self.assertRaisesRegex(TypeError, "no implementation found for 'torch.ops.aten.max'"):
mt.max()

def test_sum(self):
d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
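# Only the unmasked values participate: 0 + 6 + 4 + 7 == 17;
# column 2 is fully masked, so its dim=0 sum stays masked.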
_compare_mts(MaskedTensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
_compare_mts(
MaskedTensor(
torch.tensor([0.0, 4.0, 1.0, 13]),
torch.tensor([True, True, False, True]),
),
mt.sum(dim=0),
)

def test_sum_grad(self):
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
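# The gradient of sum is 1 at every unmasked element.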
mt.sum().backward()
_compare_mts(mt.grad, MaskedTensor(torch.tensor(1.0).expand_as(m), m))

def test_mean(self):
d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
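# Mean over the four unmasked values: (0 + 2 + 4 + 4) / 4 == 2.5.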
_compare_mts(MaskedTensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
_compare_mts(
MaskedTensor(
torch.tensor([0.0, 4.0, 1.0, 3]),
torch.tensor([True, True, False, True]),
),
mt.mean(dim=0),
)

"""
The following block of tests "test_mean_grad_case_1[a through e] are used to test the functionality of
the two different ways of constructing MaskedTensors:
MaskedTensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf
MaskedTensor.from_values(data, mask) -- differentiable constructor
Like torch.tensor(data), MaskedTensor(data, mask) will provide a UserWarning if data.requires_grad=True
MaskedTensor.from_values does not take in requires_grad -- it just takes on the requires_grad from data
Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations
Assuming mt.mean().backward() is run after each constructor:
Case 1a:
values.requires_grad = True
mt = MaskedTensor(values, mask, requires_grad=True)
yields
- Provide a UserWarning because values.requires_grad=True
- values.grad = None
- mt.grad is a MaskedTensor with the correct gradient
Case 1b:
values.requires_grad = False
mt = MaskedTensor(values, mask, requires_grad=True)
yields
- values.grad = None
- mt.grad is a MaskedTensor with the correct gradient
Case 2a/2b:
values.requires_grad = True/False
mt = MaskedTensor(values, mask, requires_grad=False)
will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
as expected. When values.requires_grad=True, we will also get a UserWarning
Case 3a:
values.requires_grad = True
mt = MaskedTensor.from_values(values, mask)
yields
- values.grad is a MaskedTensor with the correct gradient
- mt.grad is None and gives a UserWarning that
"The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
Case 3b:
values.requires_grad = False
mt = MaskedTensor.from_values(values, mask)
will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
as expected.
"""
def test_mean_grad_case_1a(self):
""" values.requires_grad = True
mt = MaskedTensor(values, mask, requires_grad=True)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
m = torch.tensor([[True, False, False], [False, True, False]])
with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
mt = MaskedTensor(d, m, requires_grad=True)
mt.mean().backward()
self.assertIsNone(d.grad)
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

def test_mean_grad_case_1b(self):
""" values.requires_grad = False
mt = MaskedTensor(values, mask, requires_grad=True)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
mt.mean().backward()
self.assertIsNone(d.grad)
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))

def test_mean_grad_case_1c(self):
""" values.requires_grad = True
mt = MaskedTensor(values, mask, requires_grad=False)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
m = torch.tensor([[True, False, False], [False, True, False]])
with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
mt = MaskedTensor(d, m, requires_grad=False)
result = mt.mean()
msg = "element 0 of tensors does not require grad and does not have a grad_fn"
with self.assertRaisesRegex(RuntimeError, msg):
result.backward()


def test_mean_grad_case_1d(self):
""" values.requires_grad = False
mt = MaskedTensor(values, mask, requires_grad=False)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=False)
result = mt.mean()
msg = "element 0 of tensors does not require grad and does not have a grad_fn"
with self.assertRaisesRegex(RuntimeError, msg):
result.backward()

def test_mean_grad_case_1e(self):
""" values.requires_grad = True
mt = MaskedTensor.from_values(values, mask)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor.from_values(d, m)
mt.mean().backward()
_compare_mts(d.grad, MaskedTensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
with self.assertWarnsRegex(UserWarning, msg):
self.assertIsNone(mt.grad)

def test_mean_grad_case_1f(self):
""" values.requires_grad = False
mt = MaskedTensor.from_values(values, mask)
"""
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor.from_values(d, m)
result = mt.mean()
msg = "element 0 of tensors does not require grad and does not have a grad_fn"
with self.assertRaisesRegex(RuntimeError, msg):
result.backward()

def test_mean_dim_grad(self):
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, True, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
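# mean(1) averages 2 unmasked elements in row 0 (grad 0.5 each) and 1 in row 1 (grad 1).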
mt.mean(1).sum().backward()
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))

def test_amax(self):
d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
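# The unmasked values are {0, -3} in row 0 and {-4, 3} in row 1; their overall amax is 3.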
_compare_mts(MaskedTensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
_compare_mts(
MaskedTensor(
torch.tensor([0.0, -4.0, 1.0, 3]),
torch.tensor([True, True, False, True]),
),
mt.amax(dim=0),
)

def test_amax_grad(self):
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
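# The largest unmasked value is 4 at position (1, 1), so all gradient flows there.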
mt.amax().backward()
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))

def test_amin(self):
d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
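# The overall amin of the unmasked values {0, -3, -4, 3} is -4.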
_compare_mts(MaskedTensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
_compare_mts(
MaskedTensor(
torch.tensor([0.0, -4.0, 1.0, -3]),
torch.tensor([True, True, False, True]),
),
mt.amin(dim=0),
)

def test_amin_grad(self):
d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
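# The smallest unmasked value is 0 at position (0, 0).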
mt.amin().backward()
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))

def test_prod(self):
d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
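# The NaN sits at a masked position, so it cannot poison the product;
# the unmasked values are 0, 0, 4, and 5.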
_compare_mts(MaskedTensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
_compare_mts(
MaskedTensor(
torch.tensor([0.0, 4.0, 1.0, 0.0]),
torch.tensor([True, True, False, True]),
),
mt.prod(dim=0),
)

def test_prod_grad(self):
d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
m = torch.tensor([[True, False, False], [False, True, False]])
mt = MaskedTensor(d, m, requires_grad=True)
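# prod of the unmasked values 2 and 4 is 8; d(prod)/d(2) == 4 and d(prod)/d(4) == 2,
# and the masked NaN does not contribute.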
mt.prod().backward()
_compare_mts(mt.grad, MaskedTensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))

def test_all(self):
d = torch.tensor([[True, True, False, False], [False, True, True, True]])
m = torch.tensor([[True, False, False, True], [False, True, False, True]])
mt = MaskedTensor(d, m)
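# all() only inspects unmasked entries; the False at (0, 3) makes the global result False.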
_compare_mts(MaskedTensor(torch.tensor(False), torch.tensor(True)), mt.all())
_compare_mts(
MaskedTensor(
torch.tensor([True, True, True, False]),
torch.tensor([True, True, False, True]),
),
mt.all(dim=0),
)

m = torch.tensor([[True, False, True, False], [False, True, False, False]])
mt = MaskedTensor(d, m)
_compare_mts(
MaskedTensor(
torch.tensor([True, True, False, True]),
torch.tensor([True, True, True, False]),
),
mt.all(dim=0),
)

def test_grad_dtype(self):
d = torch.tensor([[True, True, False], [False, True, True]])
m = torch.tensor([[True, False, False], [False, True, False]])
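# Boolean tensors cannot require gradients, matching plain torch.Tensor behavior.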
msg = "Only Tensors of floating point and complex dtype can require gradients"
with self.assertRaisesRegex(RuntimeError, msg):
MaskedTensor(d, m, requires_grad=True)

class TestMatMul(TestCase):
def test_bmm(self):
x = torch.rand(3, 2, 1)
key_padding_mask = torch.tensor(
[
[False, False, False],
[False, True, True],
]
)
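# key_padding_mask uses True for padding; transpose/unsqueeze it to (3, 2, 1) to match x,
# then invert so that True marks valid entries, as MaskedTensor expects.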
x_mt = MaskedTensor(x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1)))
x = x.masked_fill(~x_mt.get_mask(), 0)
attn_2 = torch.bmm(x, x.transpose(-2, -1))
attn_3 = torch.bmm(x_mt, x_mt.transpose(-2, -1))
self.assertEqual(attn_3.get_data().masked_fill(~attn_3.get_mask(), 0), attn_2) # type: ignore[attr-defined]

def test_masked_bmm(self):
key_padding_mask = torch.tensor(
[
[False, False, False, True],
[False, True, True, True],
[False, True, False, True],
]
)
x = torch.arange(4 * 3 * 2).reshape(4, 3, 2).float()
x_mt = MaskedTensor(
x,
~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x)),
requires_grad=True,
)
attn_mask_bool = torch.tensor(
[
[False, True, True],
[False, False, True],
[True, False, False],
]
)
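# Convert the boolean mask to the conventional additive attention mask:
# True positions become -inf so they are excluded from the result.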
attn_mask = attn_mask_bool.float().masked_fill_(attn_mask_bool, float("-inf"))
v = masked_bmm(x, x_mt.transpose(1, 2), attn_mask)
v.sum().backward()

def test_linear(self):
x = torch.arange(4 * 3 * 2).reshape(4, 3, 2)
w_x = torch.arange(10).reshape(5, 2) + x.amax()
linear = torch.nn.functional.linear
key_padding_mask = torch.tensor(
[
[False, False, False, True],
[False, True, True, True],
[False, True, False, True],
]
)
x_mt = MaskedTensor(
x, ~(key_padding_mask.transpose(0, 1).unsqueeze(-1).expand_as(x))
)

instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
instantiate_parametrized_tests(TestReductions)
instantiate_parametrized_tests(TestMatMul)

if __name__ == '__main__':
run_tests()