Skip to content

Commit

Permalink
Implement correction argument in torch.masked.{std,var} (pytorch#87118)
Browse files Browse the repository at this point in the history
This makes the signatures of `torch.masked.std` and `torch.masked.var` more consistent with their global namespace variants, and also updates the sample inputs to reuse the existing `sample_inputs_std_var` samples, which fully exercise the `correction` argument.

Pull Request resolved: pytorch#87118
Approved by: https://github.com/cpuhrsch
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Dec 8, 2022
1 parent a6593d6 commit 4543614
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 56 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def process(device_type):
"linalg.pinv.singular": {f32, f64},
"masked.norm": {f16},
"masked.normalize": {f16},
"masked.var": {f16},
"masked_fill": {f16},
"masked_scatter": {f16, f32, f64},
"masked_select": {b8, f16, f32, f64, i32, i64},
Expand Down
32 changes: 22 additions & 10 deletions torch/masked/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,14 +1538,22 @@ def norm(

def _std_var(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
dim: DimOrDims,
unbiased: Optional[bool],
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
take_sqrt: Optional[bool] = False,
correction: Optional[int],
keepdim: Optional[bool],
dtype: Optional[DType],
mask: Optional[Tensor],
take_sqrt: Optional[bool],
) -> Tensor:
assert (unbiased is None or correction is None), "Only one of unbiased and correction may be given"
correction_int = 1
if unbiased is not None:
correction_int = 1 if unbiased else 0
if correction is not None:
correction_int = correction

if dtype is None:
dtype = input.dtype
if not (dtype.is_floating_point or dtype.is_complex):
Expand Down Expand Up @@ -1584,8 +1592,8 @@ def _std_var(
)
if not keepdim:
count = count.reshape(total.shape)
if unbiased:
count = torch.subtract(count, 1)
if correction_int != 0:
count = torch.subtract(count, correction_int)
count = torch.maximum(count, count.new_zeros([]))
output = torch.divide(total, count).to(dtype=dtype)
if take_sqrt:
Expand All @@ -1601,8 +1609,9 @@ def _std_var(
def var(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
unbiased: Optional[bool] = None,
*,
correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
Expand All @@ -1619,6 +1628,7 @@ def var(
input=input,
dim=dim,
unbiased=unbiased,
correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
Expand All @@ -1630,8 +1640,9 @@ def var(
def std(
input: Union[Tensor, MaskedTensor],
dim: DimOrDims = None,
unbiased: Optional[bool] = False,
unbiased: Optional[bool] = None,
*,
correction: Optional[int] = None,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None,
Expand All @@ -1648,6 +1659,7 @@ def std(
input=input,
dim=dim,
unbiased=unbiased,
correction=correction,
keepdim=keepdim,
dtype=dtype,
mask=mask,
Expand Down
133 changes: 92 additions & 41 deletions torch/testing/_internal/opinfo/definitions/_masked.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from collections.abc import Sequence
from functools import partial
from typing import List

Expand Down Expand Up @@ -223,51 +224,101 @@ def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
)


def reference_masked_std_var(
    numpy_fn,
):
    """Build a reference callable for masked std/var from a NumPy function.

    ``numpy_fn`` is wrapped with ``reference_reduction_numpy`` and the
    returned callable accepts ``unbiased`` / ``correction`` arguments,
    forwarding them to the wrapped reference as a single ``ddof`` keyword.

    Args:
        numpy_fn: the NumPy reduction to wrap (the visible call sites pass
            ``np.var`` and ``np.std``).

    Returns:
        A function ``func(input, dim=None, unbiased=None, *, correction=None,
        **kwargs)`` that calls the wrapped reference with ``ddof`` set.
    """
    wrapped_reference = reference_reduction_numpy(numpy_fn)

    def func(
        input,
        dim=None,
        unbiased=None,
        *,
        correction=None,
        **kwargs,
    ):
        # Resolve ddof: it is 1 when neither argument is given, ``unbiased``
        # maps to 1/0, and an explicit ``correction`` overrides both.
        ddof = 1 if unbiased is None else (1 if unbiased else 0)
        if correction is not None:
            ddof = correction

        # Any sequence of dims (e.g. a list) is handed over as a tuple;
        # other values (None, a single int) are passed through untouched.
        axes = tuple(dim) if isinstance(dim, Sequence) else dim

        return wrapped_reference(input, axes, ddof=ddof, **kwargs)

    return func


def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
"""Sample inputs for masked std/var."""
for unbiased in [False, True]:
for sample_input in sample_inputs_masked_reduction(
kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
from torch.testing._internal.common_methods_invocations import sample_inputs_std_var

def masked_samples():
for sample_input in sample_inputs_std_var(
op_info, device, dtype, requires_grad, **kwargs
):
if sample_input.args:
dim = sample_input.args[0]
sample_input_args = (
sample_input.args[:1] + (unbiased,) + sample_input.args[1:]
if len(sample_input.args) and isinstance(sample_input.args[0], bool):
continue # masked.{std, var} doesn't support `.var(unbiased)`

for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs
):
sample_input_args, sample_input_kwargs = sample_input.args, dict(
mask=mask, **sample_input.kwargs
)
sample_input_kwargs = sample_input.kwargs.copy()
else:
dim = sample_input.kwargs.get("dim")
sample_input_args = sample_input.args
sample_input_kwargs = dict(sample_input.kwargs, unbiased=unbiased)
if requires_grad:
if sample_input_kwargs.get("mask") is None:
orig_count = torch.masked.sum(
torch.ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
)
else:
inmask = torch.masked._input_mask(
sample_input.input, *sample_input_args, **sample_input_kwargs
)
orig_count = torch.masked.sum(
inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
mask=inmask,
)
if orig_count.min() <= int(unbiased) + 1:
# Skip samples that lead to singularities in var
# computation resulting nan values both in var and
# autograd output that test_grad_fn cannot handle
# correctly. Also, skip samples when the autograd output
# for std could not be handled correctly due to torch.sqrt
continue
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
yield SampleInput(
sample_input.input.detach().requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
)
if (
not requires_grad
and dtype.is_floating_point
and sample_input.input.ndim == 2
and mask is not None
and mask.shape == sample_input.input.shape
):
for v in [torch.inf, -torch.inf, torch.nan]:
t = sample_input.input.detach()
t.diagonal(0, -2, -1).fill_(v)
yield SampleInput(
t.requires_grad_(requires_grad),
args=sample_input_args,
kwargs=sample_input_kwargs,
)

for sample_input in masked_samples():
correction = sample_input.kwargs.get("correction")
if correction is None:
correction = int(sample_input.kwargs.get("unbiased", True))

dim = sample_input.kwargs.get("dim", None)

if sample_input.kwargs.get("mask") is None:
orig_count = torch.masked.sum(
torch.ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
)
else:
inmask = torch.masked._input_mask(
sample_input.input, *sample_input.args, **sample_input.kwargs
)
orig_count = torch.masked.sum(
inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
dim,
keepdim=True,
mask=inmask,
)
if orig_count.min() <= correction + 1:
# Skip samples that lead to nans in var computation
continue

yield sample_input


def sample_inputs_masked_softmax(
Expand Down Expand Up @@ -860,7 +911,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
),
ReductionOpInfo(
"masked.var",
ref=reference_reduction_numpy(np.var)
ref=reference_masked_std_var(np.var)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
Expand Down Expand Up @@ -938,7 +989,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
),
ReductionOpInfo(
"masked.std",
ref=reference_reduction_numpy(np.std)
ref=reference_masked_std_var(np.std)
if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
else None,
method_variant=None,
Expand Down
5 changes: 0 additions & 5 deletions torch/testing/_internal/opinfo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,6 @@ def wrapper(x: np.ndarray, *args, **kwargs):
identity = identity.cpu()
kwargs["initial"] = identity.numpy()

if "unbiased" in keys:
unbiased = kwargs.pop("unbiased")
if unbiased is not None:
kwargs["ddof"] = int(unbiased)

result = f(x, *args, **kwargs)

# Unsqueeze reduced dimensions if NumPy does not support keepdims
Expand Down

0 comments on commit 4543614

Please sign in to comment.