Skip to content

Commit

Permalink
[maskedtensor] first commit, core and creation (pytorch#82836)
Browse files Browse the repository at this point in the history
  • Loading branch information
george-qi authored and pytorchmergebot committed Aug 16, 2022
1 parent 84146f3 commit 94ba085
Show file tree
Hide file tree
Showing 7 changed files with 707 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Features described in this documentation are classified by release status:
quantization
rpc
torch.random <random>
masked
nested
sparse
storage
Expand Down
11 changes: 11 additions & 0 deletions docs/source/masked.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
torch.masked
============

.. automodule:: torch.masked
.. automodule:: torch.masked.maskedtensor

Introduction
++++++++++++

WIP. For more information, you can go to github.com/pytorch/maskedtensor for the source code
or http://pytorch.org/maskedtensor for a number of tutorials
100 changes: 100 additions & 0 deletions test/test_maskedtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Owner(s): ["module: masked operators"]

import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
make_tensor,
instantiate_parametrized_tests,
)

from torch.testing._internal.common_methods_invocations import (
SampleInput,
)

from torch.masked.maskedtensor.core import _masks_match, _tensors_match


def _compare_mt_t(mt_result, t_result):
mask = mt_result.get_mask()
mt_result_data = mt_result.get_data()
if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
mask = mask.to_dense()
if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}:
mt_result_data = mt_result_data.to_dense()
a = mt_result_data.detach().masked_fill_(~mask, 0)
b = t_result.detach().masked_fill_(~mask, 0)
if not _tensors_match(a, b, exact=False):
raise ValueError("The data in MaskedTensor a and Tensor b do not match")

def _compare_mts(mt1, mt2):
mt_data1 = mt1.get_data()
mt_data2 = mt2.get_data()
if mt_data1.layout != mt_data2.layout:
raise ValueError("mt1's data and mt2's data do not have the same layout. "
f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}")

mask = mt1.get_mask()
mask2 = mt2.get_mask()
if not _masks_match(mt1, mt2):
raise ValueError("mt1 and mt2 must have matching masks")
if mask.layout != mask2.layout:
raise ValueError("mt1's mask and mt2's mask do not have the same layout. "
f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}")
if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
mask = mask.to_dense()

if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}:
mt_data1 = mt_data1.to_dense()
mt_data2 = mt_data2.to_dense()
a = mt_data1.detach().masked_fill_(~mask, 0)
b = mt_data2.detach().masked_fill_(~mask, 0)

if not _tensors_match(a, b, exact=False):
raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")

def _create_random_mask(shape, device):
return make_tensor(
shape, device=device, dtype=torch.bool, low=0, high=1, requires_grad=False
)

def _generate_sample_data(
device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
assert layout in {
torch.strided,
torch.sparse_coo,
torch.sparse_csr,
}, "Layout must be strided/sparse_coo/sparse_csr"
shapes = [
[],
[2],
[3, 5],
[3, 2, 1, 2],
]
inputs = []
for s in shapes:
data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type]
mask = _create_random_mask(s, device)
if layout == torch.sparse_coo:
mask = mask.to_sparse_coo().coalesce()
data = data.sparse_mask(mask).requires_grad_(requires_grad)
elif layout == torch.sparse_csr:
if data.ndim != 2 and mask.ndim != 2:
continue
mask = mask.to_sparse_csr()
data = data.sparse_mask(mask)
inputs.append(SampleInput(data, kwargs={"mask": mask}))
return inputs


class TestBasics(TestCase):
def sample_test(self):
return


instantiate_parametrized_tests(TestBasics)


if __name__ == '__main__':
run_tests()
2 changes: 2 additions & 0 deletions torch/masked/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .maskedtensor.core import is_masked_tensor, MaskedTensor
from .maskedtensor.creation import as_masked_tensor, masked_tensor
Empty file.
Loading

0 comments on commit 94ba085

Please sign in to comment.