Skip to content

Commit

Permalink
Consistent compute numel/contiguous strategy with SymInts (pytorch#85858
Browse files Browse the repository at this point in the history
)

Previously, our handling for contiguity was inconsistent in the following ways:

- is_strides_like 2d/3d and is_non_overlapping_and_dense always were computed
  based on sizes_and_strides_, even if you had symbolic ints
- Furthermore, even if you set custom policy for strides, these quantities were
  not overridable by subclasses
- Furthermore, we didn't even store these fields on ExtraMeta
- We duplicate implementations of compute_contiguous (plain, channels last,
  channels last 3d)
- We inconsistently called refresh_numel()/refresh_contiguous(), versus
  recomputing it ourselves

This factor makes a consistent strategy for all of the boolean fields, and
for numel computation.  After this refactor:

- All layout boolean fields are interposable via strides policy
  and can be overridden from Python; you will never access a garbage field
- All layout boolean fields are on ExtraMeta
- You can always call refresh_numel/contiguous, no matter if your Tensor is
  contiguous or not
- The numel/layout boolean fields are always populated consistently with
  the sizes strides fields (either on Tensor or ExtraMeta), even if you
  have custom policy
- There is only one implementation of the actual computation logic

Signed-off-by: Edward Z. Yang <[email protected]>

Differential Revision: [D39907696](https://our.internmc.facebook.com/intern/diff/D39907696)
Pull Request resolved: pytorch#85858
Approved by: https://github.com/albanD
  • Loading branch information
ezyang authored and pytorchmergebot committed Sep 30, 2022
1 parent 84a06d7 commit 3b6588a
Show file tree
Hide file tree
Showing 18 changed files with 617 additions and 192 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/FunctionalizeFallbackKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef s
auto inferred_size = at::infer_size_dv(size, self.numel());
auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size);
TORCH_INTERNAL_ASSERT(stride.has_value());
out.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride.value());
out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
return out;
}

Expand Down
36 changes: 26 additions & 10 deletions c10/core/MemoryFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,11 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
// input
// 3. All helper functions have similar comments, only 1st helper function is
// commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
// special case for trivial C dimension. default to NCHW
if (strides[1] == 0) {
return false;
Expand Down Expand Up @@ -155,10 +156,11 @@ inline bool is_channels_last_strides_2d_s4(
return true;
}

template <typename T>
inline bool is_channels_last_strides_3d_s5(
const IntArrayRef sizes,
const IntArrayRef strides) {
int64_t min = 0;
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
T min = 0;
if (strides[1] == 0) {
return false;
}
Expand Down Expand Up @@ -230,9 +232,10 @@ inline bool is_channels_last_strides_3d_s5(
// implementation. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.

template <typename T>
inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 4:
return is_channels_last_strides_2d_s4(sizes, strides);
Expand All @@ -244,9 +247,10 @@ inline bool is_channels_last_strides_2d(
}
}

template <typename T>
inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
const ArrayRef<T> sizes,
const ArrayRef<T> strides) {
switch (sizes.size()) {
case 5:
return is_channels_last_strides_3d_s5(sizes, strides);
Expand All @@ -258,4 +262,16 @@ inline bool is_channels_last_strides_3d(
}
}

inline bool is_channels_last_strides_2d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
const IntArrayRef sizes,
const IntArrayRef strides) {
return is_channels_last_strides_3d<int64_t>(sizes, strides);
}

} // namespace c10
Loading

0 comments on commit 3b6588a

Please sign in to comment.