Skip to content

Commit

Permalink
[maskedtensor] use masked_softmax for forward/backward instead of reg…
Browse files Browse the repository at this point in the history
  • Loading branch information
george-qi authored and pytorchmergebot committed Sep 30, 2022
1 parent 1c97084 commit a4d1034
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
10 changes: 5 additions & 5 deletions test/test_maskedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torch.masked.maskedtensor.reductions import REDUCE_NAMES


def _compare_mt_t(mt_result, t_result):
def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
mask = mt_result.get_mask()
mt_result_data = mt_result.get_data()
if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
Expand All @@ -35,10 +35,10 @@ def _compare_mt_t(mt_result, t_result):
mt_result_data = mt_result_data.to_dense()
a = mt_result_data.detach().masked_fill_(~mask, 0)
b = t_result.detach().masked_fill_(~mask, 0)
if not _tensors_match(a, b, exact=False):
if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
raise ValueError("The data in MaskedTensor a and Tensor b do not match")

def _compare_mts(mt1, mt2):
def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
mt_data1 = mt1.get_data()
mt_data2 = mt2.get_data()
if mt_data1.layout != mt_data2.layout:
Expand All @@ -61,7 +61,7 @@ def _compare_mts(mt1, mt2):
a = mt_data1.detach().masked_fill_(~mask, 0)
b = mt_data2.detach().masked_fill_(~mask, 0)

if not _tensors_match(a, b, exact=False):
if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")

def _make_tensor_mask(shape, device):
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_softmax(self, device):
tensor_res.sum().backward()

_compare_mt_t(masked_res, tensor_res)
_compare_mt_t(mt.grad, xinf.grad)
_compare_mt_t(mt.grad, xinf.grad, atol=1e-06)

def test_where(self, device):
data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
Expand Down
12 changes: 7 additions & 5 deletions torch/masked/maskedtensor/_ops_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,7 @@ def _softmax(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
input_data = data.masked_fill(~mask, float("-inf"))
result_data = func(input_data, args[1], args[2])
result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
return MaskedTensor(result_data, mask)


Expand All @@ -332,9 +331,12 @@ def _softmax_backward_data(func, *args, **kwargs):
if is_masked_tensor(grad) and is_masked_tensor(output):
if not _masks_match(grad, output):
raise ValueError("__torch_dispatch__, {func}: expected the masks of grad and output to match")
grad_data = _get_data(grad).masked_fill(~_maybe_get_mask(grad), 1)
output_data = _get_data(output).masked_fill(~_maybe_get_mask(output), 0)
new_grad_data = func(grad_data, output_data, dim, input_dtype)
new_grad_data = torch.ops.aten._masked_softmax_backward(
_get_data(grad),
_get_data(output),
~_maybe_get_mask(grad),
dim
)
res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
return res
else:
Expand Down
4 changes: 2 additions & 2 deletions torch/masked/maskedtensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def is_masked_tensor(a):
return isinstance(a, MaskedTensor)


def _tensors_match(a, b, exact=True):
def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
if is_masked_tensor(a) or is_masked_tensor(b):
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
if a.layout != b.layout:
Expand All @@ -51,7 +51,7 @@ def _tensors_match(a, b, exact=True):
)
if exact:
return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
return (a.dim() == b.dim()) and torch.allclose(a, b)
return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)


def _masks_match(a, b):
Expand Down

0 comments on commit a4d1034

Please sign in to comment.