Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:Python] Modify DLPack behavior with unit dimensions. #19327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[XLA:Python] Modify DLPack behavior with unit dimensions.
As discovered in jax-ml/jax#24680, when a PyTorch tensor has a dimension with size `1`, it seems to report the DLPack stride for that dimension as `1`. This means that even when the torch Tensor is formally row-major, the imported array isn't. This shouldn't really matter (the placement of unit dimensions can be arbitrary!), but in practice (since XLA:CPU ignores layouts - that's another issue that is being worked on!) it can be annoying. This change updates the behavior to always produce row-major layouts for unit dimensions wrt to their neighbors.

PiperOrigin-RevId: 696341186
  • Loading branch information
dfm authored and Google-ML-Automation committed Nov 25, 2024
commit 9e13d13159a6637156d173aeeaf0df92b4001394
28 changes: 28 additions & 0 deletions xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,33 @@ cc_library(
),
)

cc_library(
name = "dlpack_strides",
srcs = ["dlpack_strides.cc"],
hdrs = ["dlpack_strides.h"],
deps = [
"//xla:util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
],
)

xla_cc_test(
name = "dlpack_strides_test",
srcs = ["dlpack_strides_test.cc"],
deps = [
":dlpack_strides",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "dlpack",
srcs = ["dlpack.cc"],
Expand All @@ -525,6 +552,7 @@ cc_library(
],
features = ["-use_header_modules"],
deps = [
":dlpack_strides",
":nb_class_ptr",
":py_client",
":python_ref_manager",
Expand Down
32 changes: 1 addition & 31 deletions xla/python/dlpack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/dlpack_strides.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/nb_class_ptr.h"
Expand Down Expand Up @@ -212,37 +213,6 @@ absl::StatusOr<PrimitiveType> DLDataTypeToPrimitiveType(DLDataType type) {
}
}

absl::StatusOr<std::vector<int64_t>> StridesToLayout(
absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
CHECK_EQ(dims.size(), strides.size());
std::vector<int64_t> minor_to_major(dims.size());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
absl::c_sort(minor_to_major, [&](int a, int b) {
if (strides[a] < strides[b]) {
return true;
}
if (strides[a] > strides[b]) {
return false;
}
// If two dimensions have the same stride, prefer the major-to-minor
// interpretation of the ordering, since that's what JAX wants.
return b < a;
});
int64_t stride = 1;
for (int64_t d : minor_to_major) {
if (dims[d] > 1 && strides[d] != stride) {
return Unimplemented(
"Only DLPack tensors with trivial (compact) striding are supported; "
"i.e., tensors whose striding represents a transposition of the "
"underlying buffer but not broadcasting. Dimensions were: [%s], "
"strides were [%s].",
absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
}
stride *= dims[d];
}
return minor_to_major;
}

absl::StatusOr<DLDeviceType> DLDeviceTypeForDevice(const PjRtDevice& device) {
if (device.client()->platform_id() == CpuId()) {
return kDLCPU;
Expand Down
98 changes: 98 additions & 0 deletions xla/python/dlpack_strides.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/dlpack_strides.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"

namespace xla {

absl::StatusOr<std::vector<int64_t>> HandleUnitDimensions(
absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
int64_t stride = 1;
for (std::size_t i = dims.size(); i > 0; --i) {
if (dims[i - 1] > 1) {
if (strides[i - 1] < stride) {
return Unimplemented("Not row-major.");
}
stride = strides[i - 1];
}
}
std::vector<int64_t> minor_to_major(dims.size());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
std::reverse(minor_to_major.begin(), minor_to_major.end());
return minor_to_major;
}

absl::StatusOr<std::vector<int64_t>> StridesToLayout(
absl::Span<int64_t const> dims, absl::Span<int64_t const> strides) {
CHECK_EQ(dims.size(), strides.size());
if (dims.empty()) {
return std::vector<int64_t>();
}

// A special case: if any dimension has size 1, then the stride in that
// dimension is arbitrary. If all the other dimensions are row-major, then
// we choose to return the full row-major layout.
if (ABSL_PREDICT_FALSE(
absl::c_any_of(dims, [](int64_t d) { return d <= 1; }))) {
auto maybe_minor_to_major = HandleUnitDimensions(dims, strides);
if (maybe_minor_to_major.ok()) {
return maybe_minor_to_major.value();
}
}

std::vector<int64_t> minor_to_major(dims.size());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
absl::c_sort(minor_to_major, [&](int a, int b) {
if (strides[a] < strides[b]) {
return true;
}
if (strides[a] > strides[b]) {
return false;
}
// If two dimensions have the same stride, prefer the major-to-minor
// interpretation of the ordering, since that's what JAX wants.
return b < a;
});

int64_t stride = 1;
for (int64_t d : minor_to_major) {
if (dims[d] > 1 && strides[d] != stride) {
return Unimplemented(
"Only DLPack tensors with trivial (compact) striding are supported; "
"i.e., tensors whose striding represents a transposition of the "
"underlying buffer but not broadcasting. Dimensions were: [%s], "
"strides were [%s].",
absl::StrJoin(dims, ","), absl::StrJoin(strides, ","));
}
stride *= dims[d];
}
return minor_to_major;
}

} // namespace xla
32 changes: 32 additions & 0 deletions xla/python/dlpack_strides.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_DLPACK_STRIDES_H_
#define XLA_PYTHON_DLPACK_STRIDES_H_

#include <cstdint>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"

namespace xla {

absl::StatusOr<std::vector<int64_t>> StridesToLayout(
absl::Span<int64_t const> dims, absl::Span<int64_t const> strides);

} // namespace xla

#endif // XLA_PYTHON_DLPACK_STRIDES_H_
59 changes: 59 additions & 0 deletions xla/python/dlpack_strides_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/dlpack_strides.h"

#include <cstdint>
#include <tuple>
#include <vector>

#include "absl/types/span.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {

typedef std::tuple<std::vector<int64_t>, std::vector<int64_t>,
std::vector<int64_t>>
StridesToLayoutTestCase;

class DlpackStridesTestSuite
: public testing::TestWithParam<StridesToLayoutTestCase> {};

TEST_P(DlpackStridesTestSuite, StridesToLayout) {
auto [dims, strides, expected_layout] = GetParam();
auto layout = StridesToLayout(absl::MakeSpan(dims), absl::MakeSpan(strides));
EXPECT_TRUE(layout.ok());
EXPECT_EQ(layout.value(), expected_layout);
}

INSTANTIATE_TEST_SUITE_P(StridesToLayout, DlpackStridesTestSuite,
testing::ValuesIn<StridesToLayoutTestCase>({
{{}, {}, {}},
{{2, 3, 4}, {12, 4, 1}, {2, 1, 0}},
{{2, 3, 4}, {1, 2, 6}, {0, 1, 2}},
{{2, 1, 3, 4}, {12, 12, 4, 1}, {3, 2, 1, 0}},
{{2, 1, 3, 4}, {12, 1, 4, 1}, {3, 2, 1, 0}},
{{1, 1}, {1, 100}, {1, 0}},
{{1, 1, 4}, {1, 100, 1}, {2, 1, 0}},
{{4, 1, 1}, {1, 100, 1}, {2, 1, 0}},
// When there is a unit dimension, but the other
// strides are not row-major, we choose to make
// the layout as close to row-major as possible.
{{2, 1, 3, 4}, {1, 2, 2, 6}, {0, 2, 1, 3}},
}));

} // namespace
} // namespace xla