[XLA:Python] Modify DLPack behavior with unit dimensions. #19327
+218
−31
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[XLA:Python] Modify DLPack behavior with unit dimensions.
As discovered in jax-ml/jax#24680, when a PyTorch tensor has a dimension with size
1
, it seems to report the DLPack stride for that dimension as1
. This means that even when the torch Tensor is formally row-major, the imported array isn't. This shouldn't really matter (the placement of unit dimensions can be arbitrary!), but in practice (since XLA:CPU ignores layouts - that's another issue that is being worked on!) it can be annoying. This change updates the behavior to always produce row-major layouts for unit dimensions wrt to their neighbors.