unify reduction types from different operators: scatter, scatter_reduce, segment_reduce (pytorch#91499)

The goal of this PR is to unify `ReductionType` across the reduce operators so that they can share a single set of reduce utilities (e.g. `init` and `update`) for vectorization.
Pull Request resolved: pytorch#91499
Approved by: https://github.com/ngimel
mingfeima authored and pytorchmergebot committed Jan 13, 2023
1 parent a70387f commit eb7b897
Showing 12 changed files with 211 additions and 218 deletions.
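
To make the shared reduce utilities concrete: once every operator dispatches on one ReductionType, helpers such as initial-value selection can be written once and reused by scatter, scatter_reduce, and segment_reduce alike. The sketch below is illustrative only, not code from this PR; init_value is a hypothetical name, it assumes a floating-point scalar_t, and it mirrors the initialization logic visible in the SegmentReduce.cpp hunks further down.

    #include <limits>

    enum ReductionType {MAX, MEAN, MIN, SUM, PROD};  // as introduced in ReductionType.h below

    // Hypothetical shared helper: the identity element each reduction starts from.
    template <typename scalar_t>
    scalar_t init_value(ReductionType reduction) {
      switch (reduction) {
        case MAX:  return -std::numeric_limits<scalar_t>::infinity();
        case MIN:  return std::numeric_limits<scalar_t>::infinity();
        case SUM:
        case MEAN: return static_cast<scalar_t>(0);
        case PROD: return static_cast<scalar_t>(1);
      }
      return static_cast<scalar_t>(0);  // unreachable for a valid enum value
    }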
40 changes: 40 additions & 0 deletions aten/src/ATen/native/ReductionType.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include <c10/core/Scalar.h>
+
+namespace at { namespace native {
+
+enum ReductionType {MAX, MEAN, MIN, SUM, PROD};
+
+static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
+  if (reduce == "amax") {
+    return ReductionType::MAX;
+  } else if (reduce == "mean") {
+    return ReductionType::MEAN;
+  } else if (reduce == "amin") {
+    return ReductionType::MIN;
+  } else if (reduce == "sum") {
+    return ReductionType::SUM;
+  } else if (reduce == "prod") {
+    return ReductionType::PROD;
+  } else {
+    TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
+  }
+}
+
+// used for `scatter_reduce`, old options for BC.
+static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
+  if (use_new_options) {
+    return get_reduction_enum(reduce);
+  } else {
+    if (reduce == "add") {
+      return ReductionType::SUM;
+    } else if (reduce == "multiply") {
+      return ReductionType::PROD;
+    } else {
+      TORCH_CHECK(false, "reduce argument must be either add or multiply.")
+    }
+  }
+}
+
+}} // at::native
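
For illustration, a hypothetical caller of the two helpers above (not part of this diff); the resulting enum values follow directly from the code:

    // New-style names, shared by scatter_reduce and segment_reduce:
    ReductionType r1 = get_reduction_enum("amax");                           // ReductionType::MAX
    // Legacy scatter/scatter_ names ("add"/"multiply") go through the BC path:
    ReductionType r2 = get_operator_enum("add", /*use_new_options=*/false);  // ReductionType::SUM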
64 changes: 24 additions & 40 deletions aten/src/ATen/native/SegmentReduce.cpp
@@ -28,25 +28,9 @@ DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
 
 namespace {
 
-SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
-  if (reduce == "max") {
-    return SegmentReductionType::MAX;
-  } else if (reduce == "mean") {
-    return SegmentReductionType::MEAN;
-  } else if (reduce == "min") {
-    return SegmentReductionType::MIN;
-  } else if (reduce == "sum") {
-    return SegmentReductionType::SUM;
-  } else if (reduce == "prod") {
-    return SegmentReductionType::PROD;
-  } else {
-    TORCH_CHECK(false, "unsupported reduction given! ", reduce);
-  }
-}
-
 template <typename T, bool is_offsets_like=false>
 void _segment_reduce_lengths_cpu_kernel1(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const T* lengths_data,
     int64_t axis,
@@ -90,15 +74,15 @@ void _segment_reduce_lengths_cpu_kernel1(
           scalar_t initial_value;
           if (initial.has_value()) {
             initial_value = initial.value().to<scalar_t>();
-          } else if (reduction == SegmentReductionType::MAX) {
+          } else if (reduction == ReductionType::MAX) {
             initial_value = -std::numeric_limits<scalar_t>::infinity();
           } else if (
-              reduction == SegmentReductionType::MEAN ||
-              reduction == SegmentReductionType::SUM) {
+              reduction == ReductionType::MEAN ||
+              reduction == ReductionType::SUM) {
             initial_value = 0;
-          } else if (reduction == SegmentReductionType::MIN) {
+          } else if (reduction == ReductionType::MIN) {
             initial_value = std::numeric_limits<scalar_t>::infinity();
-          } else if (reduction == SegmentReductionType::PROD) {
+          } else if (reduction == ReductionType::PROD) {
             initial_value = 1;
           }
 
@@ -107,19 +91,19 @@
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                 + j * data_stride_axis + inner_idx;
             const auto val = values_data[data_index];
-            if (reduction == SegmentReductionType::MAX) {
+            if (reduction == ReductionType::MAX) {
               initial_value = at::_isnan(val)
                   ? val
                   : std::max<scalar_t>(initial_value, val);
             } else if (
-                reduction == SegmentReductionType::MEAN ||
-                reduction == SegmentReductionType::SUM) {
+                reduction == ReductionType::MEAN ||
+                reduction == ReductionType::SUM) {
               initial_value = initial_value + val;
-            } else if (reduction == SegmentReductionType::MIN) {
+            } else if (reduction == ReductionType::MIN) {
               initial_value = at::_isnan(val)
                   ? val
                   : std::min<scalar_t>(initial_value, val);
-            } else if (reduction == SegmentReductionType::PROD) {
+            } else if (reduction == ReductionType::PROD) {
               initial_value = initial_value * val;
             }
           }
@@ -128,10 +112,10 @@
           TORCH_CHECK(segment_length >= 0);
 
           if (segment_length == 0 && !initial.has_value() &&
-              reduction == SegmentReductionType::MEAN) {
+              reduction == ReductionType::MEAN) {
             initial_value = static_cast<scalar_t>(NAN);
           } else if (
-              reduction == SegmentReductionType::MEAN &&
+              reduction == ReductionType::MEAN &&
               segment_length > 0 && !at::_isnan(initial_value)) {
             initial_value = initial_value / segment_length;
           }
@@ -145,7 +129,7 @@
 }
 
 Tensor _segment_reduce_lengths_cpu_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
@@ -171,7 +155,7 @@ Tensor _segment_reduce_lengths_cpu_kernel(
 }
 
 Tensor _segment_reduce_offsets_cpu_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& offsets,
     int64_t axis,
@@ -201,7 +185,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const T* lengths_data,
     int64_t axis,
     const c10::optional<Scalar>& initial,
@@ -234,7 +218,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
   const auto* values_data = data_contig.data_ptr<scalar_t>();
   // Used to calculate exclusive prod
   scalar_t initial_prod_value;
-  if (reduction == SegmentReductionType::PROD) {
+  if (reduction == ReductionType::PROD) {
     if (initial.has_value()) {
       initial_prod_value = initial.value().to<scalar_t>();
     } else {
@@ -265,8 +249,8 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
       for (const auto inner_idx : c10::irange(inner_offset)) {
         int64_t output_index = outer_idx * output_stride_axis * output_size_axis
             + dim_idx * output_stride_axis + inner_idx;
-        if (reduction == SegmentReductionType::MAX ||
-            reduction == SegmentReductionType::MIN) {
+        if (reduction == ReductionType::MAX ||
+            reduction == ReductionType::MIN) {
           int64_t counter = 0;
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -290,21 +274,21 @@
                   grad_input_data[data_index] / counter;
             }
           }
-        } else if (reduction == SegmentReductionType::MEAN) {
+        } else if (reduction == ReductionType::MEAN) {
           auto grad_val = grad_data[output_index] / segment_length;
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                 + j * data_stride_axis + inner_idx;
             grad_input_data[data_index] = grad_val;
           }
-        } else if (reduction == SegmentReductionType::SUM) {
+        } else if (reduction == ReductionType::SUM) {
           const auto& grad_val = grad_data[output_index];
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                 + j * data_stride_axis + inner_idx;
             grad_input_data[data_index] = grad_val;
           }
-        } else if (reduction == SegmentReductionType::PROD) {
+        } else if (reduction == ReductionType::PROD) {
           const auto& grad_val = grad_data[output_index] * output_data[output_index];
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -337,7 +321,7 @@ Tensor _segment_reduce_cpu_lengths_backward_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
@@ -370,7 +354,7 @@ Tensor _segment_reduce_cpu_offsets_backward_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& offsets_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
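A note on the PROD backward path above: the kernel multiplies the incoming gradient by the saved forward output because, for y = x_0 * ... * x_{n-1}, the gradient is dy/dx_j = y / x_j whenever x_j != 0; the exclusive-prod machinery covers the zero case. A minimal standalone sketch of that identity (illustrative, not code from this PR):

    #include <cstddef>
    #include <vector>

    // Gradient of a product reduction y = x[0] * ... * x[n-1].
    // Assumes every x[j] != 0; the real kernel falls back to an exclusive product otherwise.
    std::vector<double> prod_backward(const std::vector<double>& x, double y, double grad_y) {
      std::vector<double> grad_x(x.size());
      for (std::size_t j = 0; j < x.size(); ++j) {
        grad_x[j] = grad_y * y / x[j];  // dy/dx[j] = y / x[j]
      }
      return grad_x;
    }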
11 changes: 5 additions & 6 deletions aten/src/ATen/native/SegmentReduce.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/ReductionType.h>
 #include <c10/core/Scalar.h>
 #include <c10/util/Optional.h>
 
@@ -9,18 +10,16 @@ class Tensor;
 
 namespace native {
 
-enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
-
 using segment_reduce_lengths_fn = Tensor (*)(
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
 DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
 
 using segment_reduce_offsets_fn = Tensor (*)(
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
@@ -31,7 +30,7 @@ using segment_reduce_lengths_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
@@ -41,7 +40,7 @@ using segment_reduce_offsets_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);