Commit 5263f0f

hawkinsp authored and Google-ML-Automation committed
[JAX] [XLA:Python] Move ShapeIndex bindings out of JAX and into XLA.

JAX does not use this class any more.

PiperOrigin-RevId: 751584043
1 parent f84e548 commit 5263f0f
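
For context, the deleted binding exposed ShapeIndex to Python as a small hashable value type. Below is a minimal sketch of the surface that goes away, reconstructed from the stub removed in this commit; it is illustrative only and no longer runs afterwards, since jaxlib stops exporting the class:

# Reconstructed from the deleted stub in jaxlib/_jax/__init__.pyi; the
# xla_client.ShapeIndex alias is removed by this commit, so this now fails.
from jaxlib import xla_client

idx = xla_client.ShapeIndex([1, 0])   # __init__(self, indices: list[int])
same = xla_client.ShapeIndex([1, 0])

assert idx == same                    # __eq__ compared index paths
assert hash(idx) == hash(same)        # __hash__ was consistent with __eq__
print(repr(idx))                      # __repr__ delegated to ShapeIndex::ToString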

File tree: 3 files changed (+59 / -81 lines)

jaxlib/_jax/__init__.pyi

Lines changed: 59 additions & 46 deletions
@@ -16,11 +16,11 @@
from __future__ import annotations

import builtins
+from collections.abc import Callable, Iterator, Mapping, Sequence
import enum
import inspect
import types
from typing import Any, ClassVar, TypeVar, overload
-from collections.abc import Callable, Mapping, Iterator, Sequence

import numpy as np

@@ -94,9 +94,12 @@ class Layout:
  @overload
  def __init__(self, minor_to_major: tuple[int, ...]): ...
  @overload
-  def __init__(self, minor_to_major: tuple[int, ...],
-               tiling: tuple[tuple[int, ...], ...],
-               element_size_in_bits: int): ...
+  def __init__(
+      self,
+      minor_to_major: tuple[int, ...],
+      tiling: tuple[tuple[int, ...], ...],
+      element_size_in_bits: int,
+  ): ...
  def minor_to_major(self) -> tuple[int, ...]: ...
  def tiling(self) -> Sequence[tuple[int, ...]]: ...
  def element_size_in_bits(self) -> int: ...
@@ -148,13 +151,6 @@ class ProgramShape:
  def result_shape(self) -> Shape: ...
  def __repr__(self) -> str: ...

-class ShapeIndex:
-  def __init__(self, indices: list[int]) -> None: ...
-  def __eq__(self, other: Any) -> bool: ...
-  def __ne__(self, other: Any) -> bool: ...
-  def __hash__(self) -> int: ...
-  def __repr__(self) -> str: ...
-
class Literal:
  def __init__(self, shape: Shape) -> None: ...
  def __repr__(self) -> str: ...
@@ -253,7 +249,10 @@ class CompileOptions:
  env_option_overrides: list[tuple[str, str]]

def register_custom_call_target(
-    fn_name: str, capsule: Any, platform: str, api_version: int = ...,
+    fn_name: str,
+    capsule: Any,
+    platform: str,
+    api_version: int = ...,
) -> _Status: ...
def register_custom_call_partitioner(
    name: str,
@@ -268,7 +267,6 @@ def register_custom_call_as_batch_partitionable(
    target_name: str,
    c_api: Any | None = ...,
) -> None: ...
-
def register_custom_type_id(type_name: str, type_id: Any) -> None: ...

class AutotuneCacheMode(enum.IntEnum):
@@ -346,7 +344,9 @@ class ExecutableBuildOptions:
  auto_spmd_partitioning_mesh_shape: list[int]
  auto_spmd_partitioning_mesh_ids: list[int]
  use_shardy_partitioner: bool
-  def compilation_environments_from_serialized_proto(self, serialized_proto: bytes) -> None: ...
+  def compilation_environments_from_serialized_proto(
+      self, serialized_proto: bytes
+  ) -> None: ...

class OpSharding_Type(enum.IntEnum):
  REPLICATED = ...
@@ -402,8 +402,8 @@ class HloSharding:
  def unknown() -> HloSharding: ...
  @staticmethod
  def subgroup_with_device_ordering(
-      tile_assignment: np.ndarray,
-      subgroup_types: Sequence[OpSharding_Type]) -> HloSharding: ...
+      tile_assignment: np.ndarray, subgroup_types: Sequence[OpSharding_Type]
+  ) -> HloSharding: ...
  def __eq__(self, other: Any) -> bool: ...
  def __hash__(self) -> int: ...
  def __repr__(self) -> str: ...
@@ -549,7 +549,6 @@ class MpiCollectives(CpuCollectives):
  def Finalize(self): ...

def make_mpi_collectives() -> MpiCollectives: ...
-
def get_tfrt_cpu_client(
    asynchronous: bool = ...,
    distributed_client: DistributedRuntimeClient | None = ...,
@@ -593,7 +592,9 @@ def get_c_api_topology(
    options: dict[str, str | int | list[int] | float],
) -> DeviceTopology: ...
def get_topology_for_devices(devices: list[Device]) -> DeviceTopology: ...
-def load_pjrt_plugin(platform_name: str, library_path: str | None, c_api: Any | None) -> _Status: ...
+def load_pjrt_plugin(
+    platform_name: str, library_path: str | None, c_api: Any | None
+) -> _Status: ...
def pjrt_plugin_loaded(plugin_name: str) -> bool: ...
def pjrt_plugin_initialized(plugin_name: str) -> bool: ...
def initialize_pjrt_plugin(platform_name: str) -> _Status: ...
@@ -634,23 +635,19 @@ def batched_copy_array_to_devices_with_sharding(
    sharding: Sequence[Any],
    array_copy_semantics: Sequence[ArrayCopySemantics],
) -> Sequence[ArrayImpl]: ...
-
def batched_block_until_ready(x: Sequence[ArrayImpl]) -> None: ...
-
def batched_device_put(
    aval: Any,
    sharding: Any,
    shards: Sequence[Any],
    devices: list[Device],
    committed: bool = True,
) -> ArrayImpl: ...
-
def reorder_shards(
    x: ArrayImpl,
    dst_sharding: Any,
    array_copy_semantics: ArrayCopySemantics,
) -> ArrayImpl: ...
-
def check_and_canonicalize_memory_kind(
    memory_kind: str | None, device_list: DeviceList
) -> str | None: ...
@@ -724,18 +721,23 @@ def dlpack_managed_tensor_to_buffer(
    tensor: Any, device: Device, stream: int | None
) -> ArrayImpl: ...
@overload
-def dlpack_managed_tensor_to_buffer( # Legacy overload
+def dlpack_managed_tensor_to_buffer(  # Legacy overload
    tensor: Any,
    cpu_backend: Client | None = ...,
    gpu_backend: Client | None = ...,
) -> ArrayImpl: ...
-
def cuda_array_interface_to_buffer(
-    cai: dict[str, (
-        str | int | None |
-        tuple[int, ...] | tuple[int, bool] |
-        list[tuple[str, str]] |
-        list[tuple[str, str, tuple[int, ...]]])
+    cai: dict[
+        str,
+        (
+            str
+            | int
+            | None
+            | tuple[int, ...]
+            | tuple[int, bool]
+            | list[tuple[str, str]]
+            | list[tuple[str, str, tuple[int, ...]]]
+        ),
    ],
    gpu_backend: Client | None = ...,
    device_id: int | None = None,
@@ -748,11 +750,13 @@ class Frame:
  function_name: str
  function_line_start: int
  line_num: int
-  def __init__(self,
-               file_name: str,
-               function_name: str,
-               function_line_start: int,
-               line_num: int): ...
+  def __init__(
+      self,
+      file_name: str,
+      function_name: str,
+      function_line_start: int,
+      line_num: int,
+  ): ...
  def __repr__(self) -> str: ...

class Traceback:
@@ -790,13 +794,19 @@ class DistributedRuntimeClient:
  def key_value_try_get_bytes(self, key: str) -> _Status: ...
  def key_value_dir_get(self, key: str) -> _Status: ...
  def key_value_dir_get_bytes(self, key: str) -> _Status: ...
-  def key_value_set(self, key: str, value: str,
-                    allow_overwrite: bool = False) -> _Status: ...
-  def key_value_set_bytes(self, key: str, value: bytes,
-                          allow_overwrite: bool = False) -> _Status: ...
+  def key_value_set(
+      self, key: str, value: str, allow_overwrite: bool = False
+  ) -> _Status: ...
+  def key_value_set_bytes(
+      self, key: str, value: bytes, allow_overwrite: bool = False
+  ) -> _Status: ...
  def key_value_delete(self, key: str) -> _Status: ...
-  def wait_at_barrier(self, barrier_id: str, timeout_in_ms: int,
-                      process_ids: list[int] | None = None) -> _Status: ...
+  def wait_at_barrier(
+      self,
+      barrier_id: str,
+      timeout_in_ms: int,
+      process_ids: list[int] | None = None,
+  ) -> _Status: ...
  def get_live_nodes(self, process_ids: list[int]) -> _Status: ...

def get_distributed_runtime_service(
@@ -970,22 +980,25 @@ def is_tsan() -> bool: ...
def is_sanitized() -> bool: ...

class TransferConnection:
-
  def address(self) -> str: ...
-
  def _pull_flat(self, uuid, backend, avals_flat) -> list[Any]: ...

class TransferServer:
  def _await_pull_flat(self, uuid, args: list[ArrayImpl]): ...
-
  def connect(self, address: str) -> TransferConnection: ...

-def start_transfer_server(client: Client, address: str = "", transport_addresses: list[str] = [], max_num_parallel_copies: int = 0, transfer_size: int = 0) -> TransferServer: ...
-
+def start_transfer_server(
+    client: Client,
+    address: str = "",
+    transport_addresses: list[str] = [],
+    max_num_parallel_copies: int = 0,
+    transfer_size: int = 0,
+) -> TransferServer: ...
def approx_top_k_reduction_output_size(
    input_size: int,
    rank: int,
    top_k: int,
    recall_target: float,
    aggregate_to_topk: bool | None = ...,
-    input_size_override: int | None = ...) -> tuple[int, int]: ...
+    input_size_override: int | None = ...,
+) -> tuple[int, int]: ...

jaxlib/xla_client.py

Lines changed: 0 additions & 22 deletions
@@ -251,28 +251,6 @@ def result_shape(self) -> Shape:
  def __repr__(self):
"""

-ShapeIndex = _xla.ShapeIndex
-ShapeIndex.__doc__ = """
-A Shape is an object defined in C++ that duck types like the following class:
-
-class ShapeIndex:
-  '''Represents an XLA ShapeIndex.
-
-  An index for specifying a particular nested subshape within a shape. Used in
-  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
-  the Shape tree where each element of ShapeIndex indexes into a tuple (or
-  nested tuple) within the shape. For a non-nested tuple, an index has a single
-  element.
-  '''
-
-  def __init__(self, List[int]) -> ShapeIndex:
-  def __eq__(self, other: Shape) -> bool:
-  def __ne__(self, other: Shape) -> bool:
-  def __hash__(self):
-  def __repr__(self):
-"""
-
-
DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.
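
The docstring deleted above is the clearest statement of what a ShapeIndex denotes: a path through the tuple tree of a shape, as consumed by ShapeUtil::GetSubshape. A minimal pure-Python sketch of that semantics follows; modeling a shape as nested Python tuples is an assumption made purely for illustration:

def get_subshape(shape, index):
    """Follow a ShapeIndex-style path into a nested-tuple shape."""
    for i in index:
        shape = shape[i]  # each index element descends one tuple level
    return shape

# A tuple shape containing a nested tuple: ((f32[2], s32[]), pred[])
shape = (("f32[2]", "s32[]"), "pred[]")

assert get_subshape(shape, [0, 1]) == "s32[]"  # element 0, then sub-element 1
assert get_subshape(shape, [1]) == "pred[]"    # non-nested: single-element index
assert get_subshape(shape, []) == shape        # empty index: the whole shape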

jaxlib/xla_compiler.cc

Lines changed: 0 additions & 13 deletions
@@ -647,19 +647,6 @@ void BuildXlaCompilerSubmodule(nb::module_& m) {
      .def("result_shape", &ProgramShape::result)
      .def("__repr__", &ProgramShape::ToString);

-  nb::class_<ShapeIndex>(m, "ShapeIndex")
-      .def("__init__",
-           [](ShapeIndex* self, const std::vector<int64_t>& v) {
-             new (self) ShapeIndex(v.begin(), v.end());
-           })
-      .def("__repr__", &ShapeIndex::ToString)
-      .def("__eq__", [](const ShapeIndex& shape_ind,
-                        const ShapeIndex& other) { return shape_ind == other; })
-      .def("__ne__", [](const ShapeIndex& shape_ind,
-                        const ShapeIndex& other) { return shape_ind != other; })
-      .def("__hash__",
-           [](const ShapeIndex& shape_ind) { return absl::HashOf(shape_ind); });
-
  // Literals
  nb::class_<Literal>(m, "Literal")
      .def(nb::init<const Shape&>())
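
The removed nanobind block wires __eq__, __ne__, and __hash__ together so the Python object behaves as an immutable value type, with hashing (absl::HashOf) kept consistent with operator==. A rough Python analogue of that contract, for illustration only (the class name is reused here, not re-exported anywhere):

class ShapeIndex:
    def __init__(self, indices):
        # Copy the path, mirroring ShapeIndex(v.begin(), v.end()).
        self._indices = tuple(indices)

    def __eq__(self, other):
        return isinstance(other, ShapeIndex) and self._indices == other._indices

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Must agree with __eq__, as absl::HashOf agrees with operator==.
        return hash(self._indices)

    def __repr__(self):
        return f"ShapeIndex({list(self._indices)})"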
