From 9e13d13159a6637156d173aeeaf0df92b4001394 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Wed, 13 Nov 2024 18:28:54 -0800 Subject: [PATCH] [XLA:Python] Modify DLPack behavior with unit dimensions. As discovered in https://github.com/jax-ml/jax/issues/24680, when a PyTorch tensor has a dimension with size `1`, it seems to report the DLPack stride for that dimension as `1`. This means that even when the torch Tensor is formally row-major, the imported array isn't. This shouldn't really matter (the placement of unit dimensions can be arbitrary!), but in practice (since XLA:CPU ignores layouts - that's another issue that is being worked on!) it can be annoying. This change updates the behavior to always produce row-major layouts for unit dimensions wrt to their neighbors. PiperOrigin-RevId: 696341186 --- xla/python/BUILD | 28 +++++++++ xla/python/dlpack.cc | 32 +--------- xla/python/dlpack_strides.cc | 98 +++++++++++++++++++++++++++++++ xla/python/dlpack_strides.h | 32 ++++++++++ xla/python/dlpack_strides_test.cc | 59 +++++++++++++++++++ 5 files changed, 218 insertions(+), 31 deletions(-) create mode 100644 xla/python/dlpack_strides.cc create mode 100644 xla/python/dlpack_strides.h create mode 100644 xla/python/dlpack_strides_test.cc diff --git a/xla/python/BUILD b/xla/python/BUILD index fbcc5dc424850..bfc9737d1544f 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -514,6 +514,33 @@ cc_library( ), ) +cc_library( + name = "dlpack_strides", + srcs = ["dlpack_strides.cc"], + hdrs = ["dlpack_strides.h"], + deps = [ + "//xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "dlpack_strides_test", + srcs = ["dlpack_strides_test.cc"], + deps = [ + ":dlpack_strides", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -525,6 +552,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":dlpack_strides", ":nb_class_ptr", ":py_client", ":python_ref_manager", diff --git a/xla/python/dlpack.cc b/xla/python/dlpack.cc index a4bf30dbfb73b..9b99ab6ce355a 100644 --- a/xla/python/dlpack.cc +++ b/xla/python/dlpack.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" +#include "xla/python/dlpack_strides.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" @@ -212,37 +213,6 @@ absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { } } -absl::StatusOr> StridesToLayout( - absl::Span dims, absl::Span strides) { - CHECK_EQ(dims.size(), strides.size()); - std::vector minor_to_major(dims.size()); - std::iota(minor_to_major.begin(), minor_to_major.end(), 0); - absl::c_sort(minor_to_major, [&](int a, int b) { - if (strides[a] < strides[b]) { - return true; - } - if (strides[a] > strides[b]) { - return false; - } - // If two dimensions have the same stride, prefer the major-to-minor - // interpretation of the ordering, since that's what JAX wants. - return b < a; - }); - int64_t stride = 1; - for (int64_t d : minor_to_major) { - if (dims[d] > 1 && strides[d] != stride) { - return Unimplemented( - "Only DLPack tensors with trivial (compact) striding are supported; " - "i.e., tensors whose striding represents a transposition of the " - "underlying buffer but not broadcasting. Dimensions were: [%s], " - "strides were [%s].", - absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); - } - stride *= dims[d]; - } - return minor_to_major; -} - absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { if (device.client()->platform_id() == CpuId()) { return kDLCPU; diff --git a/xla/python/dlpack_strides.cc b/xla/python/dlpack_strides.cc new file mode 100644 index 0000000000000..b148b18e912a0 --- /dev/null +++ b/xla/python/dlpack_strides.cc @@ -0,0 +1,98 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/dlpack_strides.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/optimization.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" + +namespace xla { + +absl::StatusOr> HandleUnitDimensions( + absl::Span dims, absl::Span strides) { + int64_t stride = 1; + for (std::size_t i = dims.size(); i > 0; --i) { + if (dims[i - 1] > 1) { + if (strides[i - 1] < stride) { + return Unimplemented("Not row-major."); + } + stride = strides[i - 1]; + } + } + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + std::reverse(minor_to_major.begin(), minor_to_major.end()); + return minor_to_major; +} + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + if (dims.empty()) { + return std::vector(); + } + + // A special case: if any dimension has size 1, then the stride in that + // dimension is arbitrary. If all the other dimensions are row-major, then + // we choose to return the full row-major layout. + if (ABSL_PREDICT_FALSE( + absl::c_any_of(dims, [](int64_t d) { return d <= 1; }))) { + auto maybe_minor_to_major = HandleUnitDimensions(dims, strides); + if (maybe_minor_to_major.ok()) { + return maybe_minor_to_major.value(); + } + } + + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides[a] < strides[b]) { + return true; + } + if (strides[a] > strides[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +} // namespace xla diff --git a/xla/python/dlpack_strides.h b/xla/python/dlpack_strides.h new file mode 100644 index 0000000000000..2e08c061fc62c --- /dev/null +++ b/xla/python/dlpack_strides.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_DLPACK_STRIDES_H_ +#define XLA_PYTHON_DLPACK_STRIDES_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" + +namespace xla { + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides); + +} // namespace xla + +#endif // XLA_PYTHON_DLPACK_STRIDES_H_ diff --git a/xla/python/dlpack_strides_test.cc b/xla/python/dlpack_strides_test.cc new file mode 100644 index 0000000000000..dce7836592ee7 --- /dev/null +++ b/xla/python/dlpack_strides_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/dlpack_strides.h" + +#include +#include +#include + +#include "absl/types/span.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +typedef std::tuple, std::vector, + std::vector> + StridesToLayoutTestCase; + +class DlpackStridesTestSuite + : public testing::TestWithParam {}; + +TEST_P(DlpackStridesTestSuite, StridesToLayout) { + auto [dims, strides, expected_layout] = GetParam(); + auto layout = StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides)); + EXPECT_TRUE(layout.ok()); + EXPECT_EQ(layout.value(), expected_layout); +} + +INSTANTIATE_TEST_SUITE_P(StridesToLayout, DlpackStridesTestSuite, + testing::ValuesIn({ + {{}, {}, {}}, + {{2, 3, 4}, {12, 4, 1}, {2, 1, 0}}, + {{2, 3, 4}, {1, 2, 6}, {0, 1, 2}}, + {{2, 1, 3, 4}, {12, 12, 4, 1}, {3, 2, 1, 0}}, + {{2, 1, 3, 4}, {12, 1, 4, 1}, {3, 2, 1, 0}}, + {{1, 1}, {1, 100}, {1, 0}}, + {{1, 1, 4}, {1, 100, 1}, {2, 1, 0}}, + {{4, 1, 1}, {1, 100, 1}, {2, 1, 0}}, + // When there is a unit dimension, but the other + // strides are not row-major, we choose to make + // the layout as close to row-major as possible. + {{2, 1, 3, 4}, {1, 2, 2, 6}, {0, 2, 1, 3}}, + })); + +} // namespace +} // namespace xla