Skip to content

Commit

Permalink
std/var: support floating point correction value (pytorch#94073)
Browse files Browse the repository at this point in the history
Ref pytorch#61492 (comment)

The array API specifies correction to be `Union[int, float]` while we currently only support integers.
https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html

As std/var is calculated currently, the final count of elements is already done
in floating point so we can make the correction floating point without any loss
of precision or generality.

Pull Request resolved: pytorch#94073
Approved by: https://github.com/ezyang
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Feb 23, 2023
1 parent 56aed2a commit bc438af
Show file tree
Hide file tree
Showing 32 changed files with 220 additions and 206 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d29eb67c27af0f18d4f487d76b86f43b0a69aade
503401a24e532a9019ef140199319221294045ee
57 changes: 28 additions & 29 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,7 @@ TORCH_IMPL_FUNC(argmin_out)
argmax_argmin_impl(self, dim, keepdim, result, argmin_stub);
}

static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_sqrt) {
static double std_var_all_cpu(const Tensor& self, double correction, bool take_sqrt) {
const auto dtype = self.scalar_type();
TORCH_CHECK(dtype == kDouble || dtype == kFloat,
"std_var_all: Unsupported dtype ", dtype);
Expand Down Expand Up @@ -1645,7 +1645,7 @@ static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_
0, iter.numel(), at::internal::GRAIN_SIZE, 0.0, reduction, std::plus<>{});

const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
return sum_dx2 / std::max(int64_t{0}, self.numel() - correction);
return sum_dx2 / std::max(0.0, self.numel() - correction);
}();
const auto result = take_sqrt ? std::sqrt(var) : var;

Expand All @@ -1659,7 +1659,7 @@ static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_

static Tensor& std_var_out(
const char* fname, Tensor& result, const Tensor& self,
at::OptionalIntArrayRef dim, c10::optional<int64_t> correction_opt,
at::OptionalIntArrayRef dim, const c10::optional<Scalar>& correction_opt,
bool keepdim, bool take_sqrt) {
TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda(),
"std and var only supports tensors on a CPU or CUDA device, but got: ",
Expand Down Expand Up @@ -1703,7 +1703,7 @@ static Tensor& std_var_out(
}

// Computation for floating point
const auto correction = correction_opt.value_or(1);
const auto correction = correction_opt.value_or(1).toDouble();
ScalarType dtype = get_dtype_from_result(result, {});
auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
Expand All @@ -1730,7 +1730,7 @@ static Tensor& std_var_out(

static std::tuple<Tensor&, Tensor&> std_var_mean_out(
const char* fname, Tensor& result1, Tensor& result2, const Tensor& self,
at::OptionalIntArrayRef dim, c10::optional<int64_t> correction_opt,
at::OptionalIntArrayRef dim, const c10::optional<Scalar>& correction_opt,
bool keepdim, bool take_sqrt) {
AT_ASSERT(result1.defined() && result2.defined());
TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
Expand Down Expand Up @@ -1784,7 +1784,7 @@ static std::tuple<Tensor&, Tensor&> std_var_mean_out(
}

// Computation for floating point
const auto correction = correction_opt.value_or(1);
const auto correction = correction_opt.value_or(1).toDouble();
ScalarType dtype = get_dtype_from_result(result1, {});
auto iter =
make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
Expand All @@ -1803,30 +1803,29 @@ std::tuple<Tensor, Tensor> var_mean(
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::var_mean(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0),
keepdim);
}

std::tuple<Tensor, Tensor> std_mean(
const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::std_mean(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0),
keepdim);
}

std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
return at::std_mean(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0));
}

std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
return at::var_mean(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0));
}

std::tuple<Tensor&, Tensor&> var_mean_out(
Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
int64_t correction, bool keepdim) {
Expand All @@ -1841,7 +1840,7 @@ static TensorOptions options_to_value_type(TensorOptions opts) {

std::tuple<Tensor, Tensor> var_mean(
const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
Tensor result2 = at::empty({0}, self.options());
return std_var_mean_out(
Expand All @@ -1850,7 +1849,7 @@ std::tuple<Tensor, Tensor> var_mean(

std::tuple<Tensor, Tensor> std_mean(
const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
Tensor result2 = at::empty({0}, self.options());
return std_var_mean_out(
Expand All @@ -1860,59 +1859,59 @@ std::tuple<Tensor, Tensor> std_mean(
Tensor var(const Tensor& self, bool unbiased) {
return at::var(
self, /*dim=*/c10::nullopt,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0));
}

Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::var(
self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0),
keepdim);
}

Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
return at::var_out(
result, self, /*dim=*/at::OptionalIntArrayRef(dim),
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}),
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0),
keepdim);
}

Tensor std(const Tensor& self, bool unbiased) {
return at::std(
self, /*dim=*/c10::nullopt, /*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}));
self, /*dim=*/c10::nullopt, /*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0));
}

Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
return at::std(self, dim,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
}

Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
return at::std_out(result, self, opt_dim,
/*correction=*/c10::make_optional<int64_t>({unbiased ? 1 : 0}), keepdim);
/*correction=*/c10::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
}

Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
Tensor result = at::empty({0}, options_to_value_type(self.options()));
return std_var_out("std", result, self, dim, correction, keepdim, true);
}

Tensor& std_out(
const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim, Tensor& result) {
const c10::optional<Scalar>& correction, bool keepdim, Tensor& result) {
return std_var_out("std", result, self, dim, correction, keepdim, true);
}

Tensor& var_out(
const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim, Tensor& result) {
const c10::optional<Scalar>& correction, bool keepdim, Tensor& result) {
return std_var_out("var", result, self, dim, correction, keepdim, false);
}

Tensor var(
const Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
Tensor result = at::empty({0}, options_to_value_type(self.options()));
return std_var_out("var", result, self, dim, correction, keepdim, false);
}
Expand Down Expand Up @@ -1942,32 +1941,32 @@ std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim, bool unb
return at::std_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim);
}

Tensor std(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction, bool keepdim) {
Tensor std(const Tensor& self, DimnameList dim, const c10::optional<Scalar>& correction, bool keepdim) {
return at::std(self, dimnames_to_positions(self, dim), correction, keepdim);
}

Tensor& std_out(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction,
Tensor& std_out(const Tensor& self, DimnameList dim, const c10::optional<Scalar>& correction,
bool keepdim, Tensor& result) {
return at::std_out(result, self, dimnames_to_positions(self, dim), correction, keepdim);
}

Tensor var(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction, bool keepdim) {
Tensor var(const Tensor& self, DimnameList dim, const c10::optional<Scalar>& correction, bool keepdim) {
return at::var(self, dimnames_to_positions(self, dim), correction, keepdim);
}

Tensor& var_out(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction,
Tensor& var_out(const Tensor& self, DimnameList dim, const c10::optional<Scalar>& correction,
bool keepdim, Tensor& result) {
return at::var_out(
result, self, dimnames_to_positions(self, dim), correction, keepdim);
}

std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
return at::var_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
}

std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim,
c10::optional<int64_t> correction, bool keepdim) {
const c10::optional<Scalar>& correction, bool keepdim) {
return at::std_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
}

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/ReduceOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ DECLARE_DISPATCH(reduce_fn, argmax_stub);
DECLARE_DISPATCH(reduce_fn, argmin_stub);

using reduce_std_var_function =
void (*)(TensorIterator&, int64_t correction, bool take_sqrt);
void (*)(TensorIterator&, double correction, bool take_sqrt);
DECLARE_DISPATCH(reduce_std_var_function, std_var_stub);

using reduce_norm_fn =
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/SharedReduceOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ struct WelfordData {

template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
struct WelfordOps {
index_t correction;
acc_scalar_t correction;
bool take_sqrt;
public:
using acc_t = WelfordData<acc_scalar_t, index_t>;
Expand Down Expand Up @@ -154,7 +154,7 @@ struct WelfordOps {
};
}
#endif
C10_HOST_DEVICE WelfordOps(index_t correction, bool take_sqrt)
C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
: correction(correction), take_sqrt(take_sqrt) {}
};

Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static void mean_kernel_impl(TensorIterator& iter) {
});
}

static void std_var_kernel_impl(TensorIterator& iter, int64_t correction, bool take_sqrt) {
static void std_var_kernel_impl(TensorIterator& iter, double correction, bool take_sqrt) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "std_cpu", [&] {
binary_kernel_reduce(
iter,
Expand Down
13 changes: 4 additions & 9 deletions aten/src/ATen/native/cuda/ReduceMomentKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,16 @@
namespace at::native {

template <typename scalar_t, typename out_t=scalar_t>
void std_var_kernel_impl(TensorIterator& iter, int32_t correction, bool take_sqrt) {
void std_var_kernel_impl(TensorIterator& iter, double correction, bool take_sqrt) {
// reducing unrolling factor to 2 for welford kernel
// This is necessary to lower register usage that leads to register spills.
using accscalar_t = at::acc_type<scalar_t, true>;
using ops_t = WelfordOps<scalar_t, accscalar_t, int32_t, thrust::pair<out_t, out_t>>;
gpu_reduce_kernel<scalar_t, out_t, 2>(
iter, ops_t{correction, take_sqrt}, typename ops_t::acc_t{});
ops_t ops(static_cast<accscalar_t>(correction), take_sqrt);
gpu_reduce_kernel<scalar_t, out_t, 2>(iter, ops, typename ops_t::acc_t{});
}

static void std_var_kernel_cuda(TensorIterator& iter, int64_t correction, bool take_sqrt) {
using limits = std::numeric_limits<int32_t>;
TORCH_CHECK(
correction < limits::max() && correction > limits::min(),
"The correction argument for std and var computation on CUDA must "
"fit within a 32-bit integer, but got ", correction);
static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool take_sqrt) {
const auto input_dtype = iter.input_dtype();
if (input_dtype == kHalf && iter.dtype() == kFloat) {
// type promotion that does cast and reduction in a single kernel
Expand Down
13 changes: 7 additions & 6 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c1
Tensor std_var_common_impl_mps(
const Tensor & input_t,
at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction,
const c10::optional<Scalar>& correction,
bool keepdim,
StdVarType stdVarType) {
using CachedGraph = MPSUnaryCachedGraph;
Expand All @@ -737,8 +737,8 @@ Tensor std_var_common_impl_mps(
}
}

bool use_correction = !(correction.has_value() && correction.value() == 0);
const auto correction_value = correction.value_or(1);
bool use_correction = !(correction.has_value() && correction.value().toDouble() == 0);
const auto correction_value = correction.value_or(1.0).toDouble();
int64_t correction_n = 1;

MPSGraphCache* cache_ = MPSGraphCache::getInstance();
Expand Down Expand Up @@ -858,7 +858,8 @@ Tensor std_var_common_impl_mps(
return output_t;
}

double bessel_correction = static_cast<double>(correction_n) / static_cast<double>(correction_n - correction_value);
double dof = std::max(0.0, correction_n - correction_value);
double bessel_correction = correction_n / dof;
auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {
Expand Down Expand Up @@ -929,7 +930,7 @@ Tensor std_var_common_impl_mps(
Tensor var_mps(
const Tensor & input_t,
at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction,
const c10::optional<Scalar>& correction,
bool keepdim)
{
return std_var_common_impl_mps(input_t, dim, correction, keepdim, STANDARD_VARIANCE);
Expand All @@ -938,7 +939,7 @@ Tensor var_mps(
Tensor std_mps(
const Tensor & input_t,
at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction,
const c10::optional<Scalar>& correction,
bool keepdim)
{
return std_var_common_impl_mps(input_t, dim, correction, keepdim, STANDARD_DEVIATION);
Expand Down
Loading

0 comments on commit bc438af

Please sign in to comment.