diff --git a/xla/python/BUILD b/xla/python/BUILD index 1d45df59aaf0ad..4d6c04f88c7b95 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -513,6 +513,32 @@ cc_library( ), ) +cc_library( + name = "dlpack_strides", + srcs = ["dlpack_strides.cc"], + hdrs = ["dlpack_strides.h"], + deps = [ + "//xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "dlpack_strides_test", + srcs = ["dlpack_strides_test.cc"], + deps = [ + ":dlpack_strides", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "dlpack", srcs = ["dlpack.cc"], @@ -524,6 +550,7 @@ cc_library( ], features = ["-use_header_modules"], deps = [ + ":dlpack_strides", ":nb_class_ptr", ":py_client", ":python_ref_manager", diff --git a/xla/python/dlpack.cc b/xla/python/dlpack.cc index a4bf30dbfb73bc..9b99ab6ce355a3 100644 --- a/xla/python/dlpack.cc +++ b/xla/python/dlpack.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_layout.h" +#include "xla/python/dlpack_strides.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" @@ -212,37 +213,6 @@ absl::StatusOr DLDataTypeToPrimitiveType(DLDataType type) { } } -absl::StatusOr> StridesToLayout( - absl::Span dims, absl::Span strides) { - CHECK_EQ(dims.size(), strides.size()); - std::vector minor_to_major(dims.size()); - std::iota(minor_to_major.begin(), minor_to_major.end(), 0); - absl::c_sort(minor_to_major, [&](int a, int b) { - if (strides[a] < strides[b]) { - return true; - } - if (strides[a] > strides[b]) { - return false; - } - // If two dimensions have the same stride, prefer the major-to-minor - // interpretation of the ordering, since that's what JAX wants. - return b < a; - }); - int64_t stride = 1; - for (int64_t d : minor_to_major) { - if (dims[d] > 1 && strides[d] != stride) { - return Unimplemented( - "Only DLPack tensors with trivial (compact) striding are supported; " - "i.e., tensors whose striding represents a transposition of the " - "underlying buffer but not broadcasting. Dimensions were: [%s], " - "strides were [%s].", - absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); - } - stride *= dims[d]; - } - return minor_to_major; -} - absl::StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { if (device.client()->platform_id() == CpuId()) { return kDLCPU; diff --git a/xla/python/dlpack_strides.cc b/xla/python/dlpack_strides.cc new file mode 100644 index 00000000000000..b424d35a186a3e --- /dev/null +++ b/xla/python/dlpack_strides.cc @@ -0,0 +1,78 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/dlpack_strides.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/util.h" +#include "tsl/platform/logging.h" + +namespace xla { + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides) { + CHECK_EQ(dims.size(), strides.size()); + + // Handle unit dimensions by inserting the previous stride. This has the + // effect of always producing row-major layouts for unit dimensions, which + // isn't strictly necessary, but is convenient since XLA defaults to row-major + // layouts. + std::vector strides_(strides.size()); + for (int64_t i = 0; i < strides.size(); ++i) { + if (i == 0 || dims[i] > 1) { + strides_[i] = strides[i]; + } else { + strides_[i] = strides[i - 1]; + } + } + + std::vector minor_to_major(dims.size()); + std::iota(minor_to_major.begin(), minor_to_major.end(), 0); + absl::c_sort(minor_to_major, [&](int a, int b) { + if (strides_[a] < strides_[b]) { + return true; + } + if (strides_[a] > strides_[b]) { + return false; + } + // If two dimensions have the same stride, prefer the major-to-minor + // interpretation of the ordering, since that's what JAX wants. + return b < a; + }); + int64_t stride = 1; + for (int64_t d : minor_to_major) { + if (dims[d] > 1 && strides[d] != stride) { + return Unimplemented( + "Only DLPack tensors with trivial (compact) striding are supported; " + "i.e., tensors whose striding represents a transposition of the " + "underlying buffer but not broadcasting. Dimensions were: [%s], " + "strides were [%s].", + absl::StrJoin(dims, ","), absl::StrJoin(strides, ",")); + } + stride *= dims[d]; + } + return minor_to_major; +} + +} // namespace xla diff --git a/xla/python/dlpack_strides.h b/xla/python/dlpack_strides.h new file mode 100644 index 00000000000000..2e08c061fc62c5 --- /dev/null +++ b/xla/python/dlpack_strides.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_DLPACK_STRIDES_H_ +#define XLA_PYTHON_DLPACK_STRIDES_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" + +namespace xla { + +absl::StatusOr> StridesToLayout( + absl::Span dims, absl::Span strides); + +} // namespace xla + +#endif // XLA_PYTHON_DLPACK_STRIDES_H_ diff --git a/xla/python/dlpack_strides_test.cc b/xla/python/dlpack_strides_test.cc new file mode 100644 index 00000000000000..b05c4d3708f2c8 --- /dev/null +++ b/xla/python/dlpack_strides_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/dlpack_strides.h" + +#include +#include + +#include "absl/types/span.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(DlpackStridesTest, basic) { + std::vector dims = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + auto layout = StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides)); + EXPECT_TRUE(layout.ok()); + EXPECT_EQ(layout.value(), std::vector({2, 1, 0})); + + std::vector strides_cm = {1, 2, 6}; + auto layout_cm = + StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides_cm)); + EXPECT_TRUE(layout_cm.ok()); + EXPECT_EQ(layout_cm.value(), std::vector({0, 1, 2})); +} + +TEST(DlpackStridesTest, unitDim) { + // Row-major + std::vector dims = {2, 1, 3, 4}; + std::vector strides = {12, 12, 4, 1}; + auto layout = StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides)); + EXPECT_TRUE(layout.ok()); + EXPECT_EQ(layout.value(), std::vector({3, 2, 1, 0})); + + std::vector strides2 = {12, 1, 4, 1}; + auto layout2 = + StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides2)); + EXPECT_TRUE(layout2.ok()); + EXPECT_EQ(layout2.value(), std::vector({3, 2, 1, 0})); + + // Column-major. Note that in these cases, since one of the dimensions is 1, + // there are several valid layouts that we could produce. We choose to prefer + // row-major whenever there are multiple valid layouts, so the output layouts + // here aren't completely column-major. + std::vector strides_cm = {1, 2, 2, 6}; + auto layout_cm = + StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides_cm)); + EXPECT_TRUE(layout_cm.ok()); + EXPECT_EQ(layout_cm.value(), std::vector({1, 0, 2, 3})); + + std::vector strides2_cm = {1, 1, 2, 6}; + auto layout2_cm = + StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides2_cm)); + EXPECT_TRUE(layout2_cm.ok()); + EXPECT_EQ(layout2_cm.value(), std::vector({1, 0, 2, 3})); +} + +} // namespace +} // namespace xla