[maskedtensor] add basic tests and unary/binary/reduction tests from common_method_invocations (pytorch#82841)

We decided offline on the following invariant:

- `masked_tensor` calls `MaskedTensor()` and is analogous to `torch.tensor`.
- `as_masked_tensor` calls `MaskedTensor._from_values()` and is analogous to `torch.as_tensor`.
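A minimal usage sketch of that invariant (illustrative only; it assumes a PyTorch build that includes this PR, where both factories are exported from `torch.masked`):

```python
import torch
from torch.masked import as_masked_tensor, masked_tensor

data = torch.arange(4.0, requires_grad=True)
mask = torch.tensor([True, False, True, True])

# Like torch.tensor: builds a fresh MaskedTensor leaf via MaskedTensor().
mt_leaf = masked_tensor(data.detach(), mask, requires_grad=True)

# Like torch.as_tensor: wraps `data` through the differentiable
# MaskedTensor._from_values() constructor, preserving its autograd history.
mt_wrapped = as_masked_tensor(data, mask)
```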
Pull Request resolved: pytorch#82841
Approved by: https://github.com/cpuhrsch, https://github.com/bhosmer
george-qi authored and pytorchmergebot committed Sep 22, 2022
1 parent 2bc8216 commit 0c46e3e
Showing 10 changed files with 449 additions and 432 deletions.
484 changes: 365 additions & 119 deletions test/test_maskedtensor.py

Large diffs are not rendered by default.

72 changes: 36 additions & 36 deletions torch/_masked/__init__.py
@@ -7,7 +7,7 @@

import torch
from torch import Tensor
-from torch.masked import is_masked_tensor, MaskedTensor
+from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
from . import _docs

if TYPE_CHECKING:
@@ -1020,7 +1020,7 @@ def backward(ctx, grad_output):
grad_data = (
grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
)
-result = MaskedTensor.from_values(grad_data, mask)
+result = as_masked_tensor(grad_data, mask)
return result, None

return (
@@ -1067,19 +1067,19 @@ def sum(
dtype = torch.int64
dim_ = _canonical_dim(dim, input.ndim)
mask_input = _combine_input_and_mask(sum, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
-elif input.layout == torch.sparse_coo:
+elif mask_input.layout == torch.sparse_coo:
return _sparse_coo_scatter_reduction_helper(
torch.sum, mask_input, dim_, bool(keepdim), dtype
)
-elif input.layout == torch.sparse_csr:
+elif mask_input.layout == torch.sparse_csr:
return torch._sparse_csr_sum(
mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
)
else:
raise ValueError(
f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)"
f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
)


@@ -1120,14 +1120,14 @@ def prod(
dtype = torch.int64
dim_ = _canonical_dim(dim, input.ndim)
mask_input = _combine_input_and_mask(prod, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
# Workaround https://github.com/pytorch/pytorch/issues/56586
result = mask_input
result = result.to(dtype=dtype)
for d in reversed(dim_):
result = result.prod(dim=d, keepdim=bool(keepdim))
return result
-elif input.layout == torch.sparse_coo:
+elif mask_input.layout == torch.sparse_coo:
if mask is None:
# See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
raise ValueError(
@@ -1136,7 +1136,7 @@
return _sparse_coo_scatter_reduction_helper(
torch.prod, mask_input, dim_, bool(keepdim), dtype
)
-elif input.layout == torch.sparse_csr:
+elif mask_input.layout == torch.sparse_csr:
if mask is None:
# mask is None corresponds to all-True mask. The
# unspecified elements in the CSR tensor correspond to
@@ -1156,7 +1156,7 @@
)
else:
raise ValueError(
f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)"
f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
)


@@ -1172,11 +1172,11 @@ def cumsum(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
mask_input = _combine_input_and_mask(sum, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
else:
raise ValueError(
f"masked cumsum expects strided tensor (got {input.layout} tensor)"
f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1192,11 +1192,11 @@ def cumprod(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
mask_input = _combine_input_and_mask(prod, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
else:
raise ValueError(
f"masked cumprod expects strided tensor (got {input.layout} tensor)"
f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1224,9 +1224,9 @@ def amax(

mask_input = _combine_input_and_mask(amax, input, mask)
dim_ = _canonical_dim(dim, mask_input.ndim)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
-elif input.layout == torch.sparse_coo:
+elif mask_input.layout == torch.sparse_coo:
if mask is None:
# See comment in the sparse_csr branch of prod, a similar issue arises here
# where unspecified elements along a dimension may need to be reduced with the result
@@ -1236,7 +1236,7 @@
return _sparse_coo_scatter_reduction_helper(
torch.amax, mask_input, dim_, bool(keepdim), dtype
)
-elif input.layout == torch.sparse_csr:
+elif mask_input.layout == torch.sparse_csr:
if mask is None:
raise ValueError(
"masked amax expects explicit mask for sparse_csr tensor input"
@@ -1246,7 +1246,7 @@
)
else:
raise ValueError(
f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {input.layout} tensor)"
f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
)


@@ -1318,11 +1318,11 @@ def argmax(
if dtype is None:
dtype = input.dtype
mask_input = _combine_input_and_mask(argmax, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
else:
raise ValueError(
f"masked argmax expects strided tensor (got {input.layout} tensor)"
f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1344,11 +1344,11 @@ def argmin(
if dtype is None:
dtype = input.dtype
mask_input = _combine_input_and_mask(argmin, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
else:
raise ValueError(
f"masked argmin expects strided tensor (got {input.layout} tensor)"
f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1441,7 +1441,7 @@ def median(
if not is_float:
input = input.to(dtype=torch.float)
mask_input = _combine_input_and_mask(median, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
output = torch.nanmedian(mask_input, dim_, keepdim).values
if is_float:
return output
@@ -1453,7 +1453,7 @@
)
else:
raise ValueError(
f"masked median expects strided tensor (got {input.layout} tensor)"
f"masked median expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1470,11 +1470,11 @@ def logsumexp(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)
mask_input = _combine_input_and_mask(logsumexp, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
else:
raise ValueError(
f"masked logsumexp expects strided tensor (got {input.layout} tensor)"
f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1525,14 +1525,14 @@ def norm(
if dtype is None:
dtype = input.dtype
mask_input = _combine_input_and_mask(norm, input, mask, ord)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
dim_ = _canonical_dim(dim, input.ndim)
return torch.linalg.vector_norm(
mask_input, ord, dim_, bool(keepdim), dtype=dtype
)
else:
raise ValueError(
f"masked norm expects strided tensor (got {input.layout} tensor)"
f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1667,11 +1667,11 @@ def softmax(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
mask_input = _combine_input_and_mask(amax, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
else:
raise ValueError(
f"masked softmax expects strided tensor (got {input.layout} tensor)"
f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1687,11 +1687,11 @@ def log_softmax(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
mask_input = _combine_input_and_mask(amax, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
else:
raise ValueError(
f"masked log_softmax expects strided tensor (got {input.layout} tensor)"
f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1707,11 +1707,11 @@ def softmin(
dtype = input.dtype
dim_ = _canonical_dim(dim, input.ndim)[0]
mask_input = _combine_input_and_mask(amin, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
else:
raise ValueError(
f"masked softmin expects strided tensor (got {input.layout} tensor)"
f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
)


@@ -1730,13 +1730,13 @@ def normalize(
dim_ = _canonical_dim(dim, input.ndim)[0]
# TODO: eliminate mask_input as unnecessary when using masked divide.
mask_input = _combine_input_and_mask(sum, input, mask)
-if input.layout == torch.strided:
+if mask_input.layout == torch.strided:
nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
# TODO: replace torch.maximum with masked maximum when available.
denom = torch.maximum(nrm_, nrm_.new_full([], eps))
# TODO: replace torch.divide with masked divide when available.
return torch.divide(mask_input, denom)
else:
raise ValueError(
f"masked normalize expects strided tensor (got {input.layout} tensor)"
f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
)
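The recurring change in this file swaps `input.layout` for `mask_input.layout`, so both the layout dispatch and the error message refer to the tensor that is actually reduced. A standalone sketch of that dispatch pattern for the strided branch only (the helper name here is hypothetical, not part of the PR):

```python
import torch

def _masked_sum_strided(mask_input, dim, keepdim=False, dtype=None):
    # Branch on the layout of the already-combined mask_input, as in the patch above.
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim, bool(keepdim), dtype=dtype)
    raise ValueError(
        f"masked sum expects strided tensor (got {mask_input.layout} tensor)"
    )

x = torch.arange(6.0).reshape(2, 3)
print(_masked_sum_strided(x, (1,)))  # tensor([ 3., 12.])
```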
2 changes: 1 addition & 1 deletion torch/masked/__init__.py
@@ -1,2 +1,2 @@
from .maskedtensor.core import is_masked_tensor, MaskedTensor
-from .maskedtensor.matmul import masked_bmm
+from .maskedtensor.creation import as_masked_tensor, masked_tensor
2 changes: 0 additions & 2 deletions torch/masked/maskedtensor/__init__.py
@@ -3,8 +3,6 @@

from .binary import _apply_native_binary, _is_native_binary
from .core import is_masked_tensor, MaskedTensor
-from .functions import multi_head_attention_forward
-from .matmul import _apply_native_matmul, _is_native_matmul, masked_bmm
from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
from .reductions import _apply_reduction, _is_reduction
from .unary import _apply_native_unary, _is_native_unary
5 changes: 4 additions & 1 deletion torch/masked/maskedtensor/binary.py
@@ -94,7 +94,10 @@ def _binary_helper(fn, args, kwargs, inplace):
)

args0_layout = data_args[0].layout
-same_layout = args0_layout == data_args[1].layout
+same_layout = (
+    (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and
+    (args0_layout == data_args[1].layout)
+)

if args0_layout == torch.sparse_coo:
if same_layout:
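The `same_layout` change above guards the `.layout` access, presumably because the second operand of a binary op need not be a tensor (for example, a Python scalar), in which case the layouts simply do not match rather than raising. A minimal sketch of the guard outside the MaskedTensor machinery (the function name is illustrative):

```python
import torch

def layouts_match(lhs, rhs):
    # Only compare layouts when rhs is tensor-like; the PR additionally accepts
    # MaskedTensor via is_masked_tensor, omitted here for brevity.
    return torch.is_tensor(rhs) and lhs.layout == rhs.layout

a = torch.randn(3)
print(layouts_match(a, torch.randn(3)))  # True
print(layouts_match(a, 2.0))             # False, instead of an AttributeError
```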
18 changes: 2 additions & 16 deletions torch/masked/maskedtensor/core.py
@@ -337,12 +337,12 @@ def __init__(self, data, mask, requires_grad=False):
self._validate_members()

@staticmethod
-def from_values(data, mask):
+def _from_values(data, mask):
""" Differentiable constructor for MaskedTensor """
class Constructor(torch.autograd.Function):
@staticmethod
def forward(ctx, data, mask):
-return MaskedTensor(data.clone(), mask.clone())
+return MaskedTensor(data, mask)

@staticmethod
def backward(ctx, grad_output):
@@ -383,10 +383,6 @@ def __repr__(self):
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}

-if func is torch.nn.functional.multi_head_attention_forward:
-    from .functions import multi_head_attention_forward as mha_mt
-    return mha_mt(*args, **kwargs)

from .reductions import _apply_reduction, _is_reduction
if _is_reduction(func):
return _apply_reduction(func, *args, **kwargs)
@@ -439,16 +435,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if _is_native_binary(func):
return _apply_native_binary(func, *args, **kwargs)

-from .matmul import _apply_native_matmul, _is_native_matmul
-
-if _is_native_matmul(func):
-    return _apply_native_matmul(func, *args, **kwargs)
-
-if func in [torch.ops.aten.mm, torch.ops.aten.bmm]:
-    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2, len_kwargs=0)
-    return cls.matmul(args[0], args[1], func) # type: ignore[call-arg]

# Doesn't work for addmm where the first argument is a Tensor
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
if func is torch.ops.aten.stride:
25 changes: 25 additions & 0 deletions torch/masked/maskedtensor/creation.py
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

from .core import MaskedTensor, is_masked_tensor

__all__ = [
"as_masked_tensor",
"masked_tensor",
]


""""
These two factory functions are intended to mirror
torch.tensor - guaranteed to be a leaf node
torch.as_tensor - differentiable constructor that preserves the autograd history
"""

def masked_tensor(data, mask, requires_grad=False):
assert not is_masked_tensor(data)
assert not is_masked_tensor(mask)
return MaskedTensor(data, mask, requires_grad)

def as_masked_tensor(data, mask):
assert not is_masked_tensor(data)
assert not is_masked_tensor(mask)
return MaskedTensor._from_values(data, mask)
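A short sketch of the autograd difference these factories encode, assuming a build that includes this PR (checked via `grad_fn` only, without running a backward pass):

```python
import torch
from torch.masked import as_masked_tensor, masked_tensor

data = torch.arange(4.0, requires_grad=True)
mask = torch.tensor([True, False, True, True])

wrapped = as_masked_tensor(data, mask)
print(wrapped.grad_fn is not None)  # True: `data`'s autograd history is preserved

leaf = masked_tensor(data.detach(), mask, requires_grad=True)
print(leaf.grad_fn is None)         # True: a fresh leaf, like torch.tensor
```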