Ported reshape to symints and added a shim for BC (pytorch#85998)
ezyang authored and pytorchmergebot committed Oct 2, 2022
1 parent f88bf8d commit 3638089
Showing 21 changed files with 118 additions and 49 deletions.
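
At a high level, the port swaps plain int64_t shape arguments for c10::SymInt so that reshape and the view ops underneath it can carry symbolic sizes through tracing. A rough sketch of the two call paths this leaves C++ callers with, assuming the reshape_symint overload spelling used throughout the files below (the SymInt construction here is purely illustrative):

    #include <ATen/ATen.h>
    #include <vector>

    void reshape_overloads_sketch() {
      at::Tensor t = at::rand({2, 3, 4});
      // Existing integer-shape callers keep the old overload.
      at::Tensor a = at::reshape(t, {6, 4});
      // SymInt-aware callers (tracing, functionalization) pass symbolic sizes.
      std::vector<c10::SymInt> shape = {c10::SymInt(6), c10::SymInt(4)};
      at::Tensor b = at::reshape_symint(t, shape);
    }
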
6 changes: 3 additions & 3 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -145,17 +145,17 @@ Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor
   return at::functionalization::permute_copy_inverse(mutated_view, dims, reapply_views);
 }

-Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, at::IntArrayRef stride) {
+Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride) {
   // Note that I'm directly calling reshape(), and ignoring the strides.
   // _reshape_alias() isn't available from user code, and is an implementation detail of reshape().
   // Specifically, passing in the strides directly can get us into trouble in cases like:
   // b = a[0]; c = b.reshape(...); c.add_(1); print(a)
   // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides,
   // The call would fail because `mutated_view` doesn't have enough bytes of storage.
   if (reapply_views) {
-    return at::_reshape_alias(mutated_view, base.sizes(), base.strides());
+    return at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides());
   } else {
-    return at::_reshape_alias_copy(mutated_view, base.sizes(), base.strides());
+    return at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides());
   }
 }

14 changes: 7 additions & 7 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -256,33 +256,33 @@ at::Tensor _to_copy_functionalize(
 // The idea with _unsafe_view is that you're guaranteed that the input
 // is a temporary, and don't actually have to worry about propagating
 // mutations between the input and output.
-at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::IntArrayRef size) {
+at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymIntArrayRef size) {
   if (!at::functionalization::impl::isFunctionalTensor(self)) {
     at::AutoDispatchSkipFunctionalize guard;
-    return at::_unsafe_view(self, size);
+    return at::_unsafe_view_symint(self, size);
   }

   auto self_ = at::functionalization::impl::from_functional_tensor(self);
   at::Tensor tmp_output;
   {
     at::AutoDispatchSkipFunctionalize guard;
-    tmp_output = at::_unsafe_view(self_, size);
+    tmp_output = at::_unsafe_view_symint(self_, size);
   }

   at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
     [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
-      return at::_unsafe_view(base, size);
+      return at::_unsafe_view_symint(base, size);
     },
     [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
-      return at::_unsafe_view(mutated_view, base.sizes());
+      return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
     }
   );

   auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, view_meta);
   // See Note [Propagating strides in the functionalization pass]
   // (for _unsafe_view, I'm just manually doing the shape inference rule here instead of calling the meta function for unsafe_view)
-  auto inferred_size = at::infer_size_dv(size, self.numel());
-  auto stride = at::detail::computeStride(self.sizes(), self.strides(), inferred_size);
+  auto inferred_size = at::infer_size_dv(size, self.sym_numel());
+  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size);
   TORCH_INTERNAL_ASSERT(stride.has_value());
   out.unsafeGetTensorImpl()->set_sizes_and_strides(inferred_size, stride.value());
   return out;
2 changes: 1 addition & 1 deletion aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -186,7 +186,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE(ravel);
   OP_DECOMPOSE2(repeat_interleave, self_int);
   OP_DECOMPOSE2(repeat_interleave, self_Tensor);
-  OP_DECOMPOSE(reshape);
+  m.impl("reshape", native::reshape_symint);
   OP_DECOMPOSE(resolve_conj);
   OP_DECOMPOSE(resolve_neg);
   OP_DECOMPOSE(row_stack);
26 changes: 13 additions & 13 deletions aten/src/ATen/functorch/BatchRulesViews.cpp
@@ -102,15 +102,15 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
 std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
     const Tensor& self,
     optional<int64_t> self_bdim,
-    IntArrayRef sizes) {
+    c10::SymIntArrayRef sizes) {

-  VmapDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
+  SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
   sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
   auto self_ = moveBatchDimToFront(self, self_bdim);
   while (self_.dim() < (int64_t)sizes_with_bdim.size()) {
     self_ = self_.unsqueeze(1);
   }
-  return std::make_tuple(self_.repeat(sizes_with_bdim), 0);
+  return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0);
 }


@@ -136,22 +136,22 @@ std::tuple<Tensor,optional<int64_t>> diag_batch_rule(
 std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
     const Tensor& self,
     optional<int64_t> self_bdim,
-    IntArrayRef size) {
+    c10::SymIntArrayRef size) {
   auto self_ = moveBatchDimToFront(self, self_bdim);
-  VmapDimVector view_size(size);
+  SymDimVector view_size(size);
   view_size.insert(view_size.begin(), self_.size(0));

   // See if the view is valid. If it's not, then we copy.
   // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used
   // anymore.
-  const at::DimVector inferred_size = at::infer_size_dv(view_size, self_.numel());
-  const auto stride = at::detail::computeStride(self_.sizes(),
-      self_.strides(),
+  const at::SymDimVector inferred_size = at::infer_size_dv(view_size, self_.sym_numel());
+  const auto stride = at::detail::computeStride(self_.sym_sizes(),
+      self_.sym_strides(),
       inferred_size);
   if (!stride.has_value()) {
     self_ = self_.contiguous();
   }
-  return std::make_tuple(at::_unsafe_view(self_, view_size), 0);
+  return std::make_tuple(at::_unsafe_view_symint(self_, view_size), 0);
 }

 std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
@@ -286,15 +286,15 @@ std::tuple<Tensor, optional<int64_t>> select_batching_rule(const Tensor& self, o
   return std::make_tuple(result, 0);
 }

-std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, optional<int64_t> bdim, const IntArrayRef shape, const IntArrayRef strides) {
+std::tuple<Tensor, optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
   (void) strides;
   TORCH_INTERNAL_ASSERT(bdim.has_value());

   auto self_ = moveBatchDimToFront(self, bdim);
-  c10::SmallBuffer<int64_t, 5> new_shape(shape.size() + 1);
-  new_shape[0] = self_.size(0);
+  c10::SymDimVector new_shape(shape.size() + 1);
+  new_shape[0] = self_.sym_size(0);
   std::copy(shape.begin(), shape.end(), new_shape.begin() + 1);
-  return std::make_tuple(at::reshape(self_, new_shape), 0);
+  return std::make_tuple(at::reshape_symint(self_, new_shape), 0);
 }

 std::tuple<Tensor, optional<int64_t>> roll_batch_rule(const Tensor& self, optional<int64_t> bdim, IntArrayRef shifts, IntArrayRef dims) {
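
For intuition, the reshape-style batch rules above all follow the same pattern: move the vmap batch dimension to the front, prepend the batch size to the requested shape, and call the symint variant once on the whole batch. A minimal sketch of that pattern, outside the functorch machinery and purely for illustration:

    #include <ATen/ATen.h>

    void batched_reshape_sketch() {
      at::Tensor batched = at::rand({8, 6});     // 8 examples of 6 elements each
      c10::SymDimVector new_shape(3);
      new_shape[0] = batched.sym_size(0);        // prepend the batch size
      new_shape[1] = c10::SymInt(2);
      new_shape[2] = c10::SymInt(3);
      at::Tensor out = at::reshape_symint(batched, new_shape);  // shape {8, 2, 3}
    }
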
12 changes: 12 additions & 0 deletions aten/src/ATen/native/NonSymbolicBC.h
@@ -0,0 +1,12 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+#include <c10/util/irange.h>
+#include <ATen/core/IListRef.h>
+
+namespace at {
+namespace native {
+// This file contains non-symbolic signatures for ops that we have sym-intified the signature of.
+// However, in certain cases (such as static runtime), we call the native versions of the ops directly.
+// In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
+TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
+}}
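
A short sketch of how the shim above is meant to be used: a caller that only has concrete int64_t sizes (such as static runtime) keeps calling the non-symbolic native function directly, while symbolic callers go through reshape_symint. The helper name below is made up for illustration:

    #include <ATen/native/NonSymbolicBC.h>

    at::Tensor flatten_to_pairs(const at::Tensor& t) {
      // Plain IntArrayRef shape, no SymInt involved: this resolves to the
      // duplicated non-symbolic at::native::reshape declared above.
      return at::native::reshape(t, {-1, 2});
    }
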
41 changes: 41 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -1331,6 +1331,47 @@ Tensor alias_with_sizes_and_strides(
   return self_;
 }

+Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
+  if (self.is_sparse()) {
+    AT_ERROR("reshape is not implemented for sparse tensors");
+  }
+  c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
+
+  if (self.is_mkldnn()) {
+    return at::_mkldnn_reshape(self, c10::asIntArrayRefSlow(shape));
+  }
+
+  // `computeStride` returns the proper strides to use if this
+  // `reshape` can be just a view.
+  auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
+
+  // NB: Even though we have viewable geometry and the target strides here,
+  // we do not just call `as_strided` on `self` because the backward
+  // for `as_strided` is not as efficient as that of `view` (since the
+  // former is meant to handle general cases).
+  //
+  // Similarly we don't call `view` because it duplicates some of the work
+  // we've already done, and instead call our internal/private operator
+  // `_reshape_alias` that essentially does the same thing as `view` and
+  // `as_strided` without any of the extra overhead.
+  if (stride.has_value()) {
+    // Temporary check to revert to the old behavior/view in cases where the
+    // device is not supported (e.g. for XLA the operation is not supported
+    // so we use `view` instead).
+    //
+    // We need to do the checks here instead of in `native_functions.yaml`
+    // to preserve backwards compatibility.
+    if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
+      return self._reshape_alias_symint(shape, stride.value());
+    } else {
+      return self.view_symint(shape);
+    }
+  }
+  return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape);
+}
+
+// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to
+// minimize breakages.
 Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
   if (self.is_sparse()) {
     AT_ERROR("reshape is not implemented for sparse tensors");
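
To make the view-vs-copy branch in reshape_symint concrete, a small illustrative example of the behavior as I read the code above (not part of the commit):

    #include <ATen/ATen.h>

    void reshape_view_or_copy_sketch() {
      at::Tensor x = at::rand({4, 6});
      // Contiguous geometry: computeStride succeeds, so this stays a view
      // (routed through _reshape_alias rather than a copy).
      at::Tensor y = x.reshape({2, 12});
      // After a transpose the tensor cannot be viewed as {24}; computeStride
      // returns nullopt and reshape falls back to clone + _unsafe_view (a copy).
      at::Tensor z = x.t().reshape({24});
    }
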
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/SoftMax.cu
@@ -1064,7 +1064,7 @@ Tensor masked_softmax_backward_cuda(
   auto grad = grad_.contiguous();
   auto output = output_.contiguous();
   auto mask = mask_.contiguous();
-  int64_t dim = dim_.has_value() ? dim_.value() : output.dim() - 1;
+  int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) : output.dim() - 1;

   grad = grad.dim() == 0 ? grad.view(1) : grad;
   mask = mask.dim() == 0 ? mask.view(1) : mask;
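
The only behavioral change here is that an explicitly passed negative dim is now wrapped before use instead of reaching the kernel as a negative index; a minimal sketch of what maybe_wrap_dim does (values are illustrative):

    #include <c10/core/WrapDimMinimal.h>

    // For a 3-d output, dim = -1 resolves to the last dimension.
    int64_t wrapped = c10::maybe_wrap_dim(/*dim=*/-1, /*dim_post_expr=*/3);  // == 2
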
4 changes: 4 additions & 0 deletions aten/src/ATen/native/metal/MetalTensorImpl.h
@@ -31,6 +31,10 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
     return strides_;
   }

+  c10::SymIntArrayRef sym_strides_custom() const override {
+    return c10::fromIntArrayRefKnownNonNegative(strides_);
+  }
+
   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
     return true;
   }
14 changes: 7 additions & 7 deletions aten/src/ATen/native/native_functions.yaml
@@ -4203,7 +4203,7 @@

 - func: negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

-- func: repeat(Tensor self, int[] repeats) -> Tensor
+- func: repeat(Tensor self, SymInt[] repeats) -> Tensor
   variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
   dispatch:
     CompositeExplicitAutograd: repeat
@@ -4224,18 +4224,18 @@
 - func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor
   variants: function, method

-- func: reshape(Tensor(a) self, int[] shape) -> Tensor(a)
+- func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CompositeImplicitAutograd: reshape
+    CompositeImplicitAutograd: reshape_symint
     CompositeImplicitAutogradNestedTensor: reshape_nested

 # NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
 # They are not user-facing, hence the leading underscore. Please don't use it
 # anywhere else.
-- func: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
+- func: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
   variants: function, method
   device_check: NoCheck
   device_guard: False
@@ -5450,7 +5450,7 @@
   tags: dynamic_output_shape
   autogen: _unique2.out

-- func: _unsafe_view(Tensor self, int[] size) -> Tensor
+- func: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
   dispatch:
     CompositeExplicitAutograd: _unsafe_view
   autogen: _unsafe_view.out
@@ -12792,7 +12792,7 @@
     CompositeExplicitAutogradNonFunctional: permute_copy
   tags: view_copy

-- func: _reshape_alias_copy(Tensor self, int[] size, int[] stride) -> Tensor
+- func: _reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor
   variants: function
   dispatch:
     CompositeExplicitAutogradNonFunctional: _reshape_alias_copy
@@ -13004,7 +13004,7 @@
     CompositeExplicitAutograd: permute_copy_out


-- func: _reshape_alias_copy.out(Tensor self, int[] size, int[] stride, *, Tensor(a!) out) -> Tensor(a!)
+- func: _reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
     CompositeExplicitAutograd: _reshape_alias_copy_out
1 change: 1 addition & 0 deletions aten/src/ATen/native/ts_native_functions.yaml
@@ -211,6 +211,7 @@ symint:
   - slice_scatter
   - empty_strided
   - new_empty_strided
+  - _reshape_alias_copy
 autograd:
   - max_pool3d
   - native_group_norm
4 changes: 4 additions & 0 deletions aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h
@@ -29,6 +29,10 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
     return strides_;
   }

+  SymIntArrayRef sym_strides_custom() const override {
+    return c10::fromIntArrayRefKnownNonNegative(strides_);
+  }
+
   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
     return true;
   }
2 changes: 0 additions & 2 deletions functorch/test/test_aotdispatch.py
@@ -996,7 +996,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition
     xfail('ravel', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
-    xfail('repeat', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('repeat_interleave', ''), # aten.repeat_interleave.Te...
     xfail('reshape_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('reshape', ''), # Cannot call numel() on tensor with symbolic sizes/strides
@@ -1042,7 +1041,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _):
     xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
     xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('tile', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('topk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
@@ -1242,7 +1242,6 @@ def f(a, b, c, d, e):
     xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition
     xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
     xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('reshape', ''), # Tensors of type TensorImpl do not have numel
     xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
     xfail('roll', ''), # Tensors of type TensorImpl do not have numel
8 changes: 4 additions & 4 deletions tools/autograd/derivatives.yaml
@@ -1324,8 +1324,8 @@
 - name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
   self: renorm_backward(grad, self, p, dim, maxnorm)

-- name: repeat(Tensor self, int[] repeats) -> Tensor
-  self: repeat_backward(grad, repeats, self.sizes())
+- name: repeat(Tensor self, SymInt[] repeats) -> Tensor
+  self: repeat_backward(grad, repeats, self.sym_sizes())
   result: auto_linear

 - name: special_entr(Tensor self) -> Tensor
@@ -1352,7 +1352,7 @@
 # making it impossible (hard) to detect when it is actually a view.
 # - name: reshape(Tensor self, IntArrayRef shape)

-- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
+- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
   self: grad.reshape(self.sizes())
   result: auto_linear

@@ -1697,7 +1697,7 @@
   output_differentiability: [True, False, False]
   self: not_implemented("_unique2")

-- name: _unsafe_view(Tensor self, int[] size) -> Tensor
+- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
   self: grad.reshape(self.sizes())
   result: auto_linear

13 changes: 7 additions & 6 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -1383,11 +1383,11 @@ Tensor renorm_backward(

 Tensor repeat_backward(
     Tensor grad,
-    IntArrayRef repeats,
-    IntArrayRef input_shape) {
+    c10::SymIntArrayRef repeats,
+    c10::SymIntArrayRef input_shape) {
   auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
   if (find_iter != repeats.cend()) {
-    return at::zeros(input_shape, grad.options());
+    return at::zeros_symint(input_shape, grad.options());
   }
   const auto input_dims = input_shape.size();
   int64_t num_unsqueezed = grad.dim() - input_dims;
@@ -1396,9 +1396,10 @@ Tensor repeat_backward(
     grad = grad.sum(0, false);
   }

-  at::DimVector grad_size, sum_dims;
+  at::SymDimVector grad_size;
+  at::DimVector sum_dims;
   for (const auto dim : c10::irange(input_dims)) {
-    int64_t repeat = repeats[dim + num_unsqueezed];
+    auto repeat = repeats[dim + num_unsqueezed];
     // Reshape gradient (repeat > 1)
     // Index: [..., dim , ...] [..., dim , dim+1 , ...]
     // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
@@ -1457,7 +1458,7 @@ Tensor repeat_backward(
   // reduce the whole grad tensor into a scalar rather than keeping original
   // dimensions.
   if (!sum_dims.empty()) {
-    grad = grad.reshape(grad_size);
+    grad = grad.reshape_symint(grad_size);
     grad = grad.sum(sum_dims);
   }
   return grad;
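
For intuition on the reshape-then-sum scheme in repeat_backward: for an input of shape [2] repeated 3 times, the incoming gradient has shape [6]; it is reshaped to [3, 2] and summed over the repeat dimension, so each input element accumulates the gradient of its 3 copies. A small sketch of the end-to-end effect using the C++ frontend (illustrative, not from this commit):

    #include <torch/torch.h>

    void repeat_backward_sketch() {
      torch::Tensor x = torch::ones({2}, torch::requires_grad());
      torch::Tensor y = x.repeat({3});   // shape [6]
      y.sum().backward();
      // x.grad() is tensor([3., 3.]): each element collects the gradient of its 3 copies.
    }
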
4 changes: 2 additions & 2 deletions torch/csrc/autograd/FunctionsManual.h
@@ -291,8 +291,8 @@ at::Tensor renorm_backward(
     const at::Scalar& maxnorm);
 at::Tensor repeat_backward(
     at::Tensor grad,
-    at::IntArrayRef repeats,
-    at::IntArrayRef input_shape);
+    at::SymIntArrayRef repeats,
+    at::SymIntArrayRef input_shape);
 at::Tensor _fused_dropout_backward(
     at::Tensor grad,
     at::Tensor mask,