1616from __future__ import annotations
1717
1818import builtins
19+ from collections .abc import Callable , Iterator , Mapping , Sequence
1920import enum
2021import inspect
2122import types
2223from typing import Any , ClassVar , TypeVar , overload
23- from collections .abc import Callable , Mapping , Iterator , Sequence
2424
2525import numpy as np
2626
@@ -94,9 +94,12 @@ class Layout:
9494 @overload
9595 def __init__ (self , minor_to_major : tuple [int , ...]): ...
9696 @overload
97- def __init__ (self , minor_to_major : tuple [int , ...],
98- tiling : tuple [tuple [int , ...], ...],
99- element_size_in_bits : int ): ...
97+ def __init__ (
98+ self ,
99+ minor_to_major : tuple [int , ...],
100+ tiling : tuple [tuple [int , ...], ...],
101+ element_size_in_bits : int ,
102+ ): ...
100103 def minor_to_major (self ) -> tuple [int , ...]: ...
101104 def tiling (self ) -> Sequence [tuple [int , ...]]: ...
102105 def element_size_in_bits (self ) -> int : ...
@@ -148,13 +151,6 @@ class ProgramShape:
148151 def result_shape (self ) -> Shape : ...
149152 def __repr__ (self ) -> str : ...
150153
151- class ShapeIndex :
152- def __init__ (self , indices : list [int ]) -> None : ...
153- def __eq__ (self , other : Any ) -> bool : ...
154- def __ne__ (self , other : Any ) -> bool : ...
155- def __hash__ (self ) -> int : ...
156- def __repr__ (self ) -> str : ...
157-
158154class Literal :
159155 def __init__ (self , shape : Shape ) -> None : ...
160156 def __repr__ (self ) -> str : ...
@@ -253,7 +249,10 @@ class CompileOptions:
253249 env_option_overrides : list [tuple [str , str ]]
254250
255251def register_custom_call_target (
256- fn_name : str , capsule : Any , platform : str , api_version : int = ...,
252+ fn_name : str ,
253+ capsule : Any ,
254+ platform : str ,
255+ api_version : int = ...,
257256) -> _Status : ...
258257def register_custom_call_partitioner (
259258 name : str ,
@@ -268,7 +267,6 @@ def register_custom_call_as_batch_partitionable(
268267 target_name : str ,
269268 c_api : Any | None = ...,
270269) -> None : ...
271-
272270def register_custom_type_id (type_name : str , type_id : Any ) -> None : ...
273271
274272class AutotuneCacheMode (enum .IntEnum ):
@@ -346,7 +344,9 @@ class ExecutableBuildOptions:
346344 auto_spmd_partitioning_mesh_shape : list [int ]
347345 auto_spmd_partitioning_mesh_ids : list [int ]
348346 use_shardy_partitioner : bool
349- def compilation_environments_from_serialized_proto (self , serialized_proto : bytes ) -> None : ...
347+ def compilation_environments_from_serialized_proto (
348+ self , serialized_proto : bytes
349+ ) -> None : ...
350350
351351class OpSharding_Type (enum .IntEnum ):
352352 REPLICATED = ...
@@ -402,8 +402,8 @@ class HloSharding:
402402 def unknown () -> HloSharding : ...
403403 @staticmethod
404404 def subgroup_with_device_ordering (
405- tile_assignment : np .ndarray ,
406- subgroup_types : Sequence [ OpSharding_Type ] ) -> HloSharding : ...
405+ tile_assignment : np .ndarray , subgroup_types : Sequence [ OpSharding_Type ]
406+ ) -> HloSharding : ...
407407 def __eq__ (self , other : Any ) -> bool : ...
408408 def __hash__ (self ) -> int : ...
409409 def __repr__ (self ) -> str : ...
@@ -549,7 +549,6 @@ class MpiCollectives(CpuCollectives):
549549 def Finalize (self ): ...
550550
551551def make_mpi_collectives () -> MpiCollectives : ...
552-
553552def get_tfrt_cpu_client (
554553 asynchronous : bool = ...,
555554 distributed_client : DistributedRuntimeClient | None = ...,
@@ -593,7 +592,9 @@ def get_c_api_topology(
593592 options : dict [str , str | int | list [int ] | float ],
594593) -> DeviceTopology : ...
595594def get_topology_for_devices (devices : list [Device ]) -> DeviceTopology : ...
596- def load_pjrt_plugin (platform_name : str , library_path : str | None , c_api : Any | None ) -> _Status : ...
595+ def load_pjrt_plugin (
596+ platform_name : str , library_path : str | None , c_api : Any | None
597+ ) -> _Status : ...
597598def pjrt_plugin_loaded (plugin_name : str ) -> bool : ...
598599def pjrt_plugin_initialized (plugin_name : str ) -> bool : ...
599600def initialize_pjrt_plugin (platform_name : str ) -> _Status : ...
@@ -634,23 +635,19 @@ def batched_copy_array_to_devices_with_sharding(
634635 sharding : Sequence [Any ],
635636 array_copy_semantics : Sequence [ArrayCopySemantics ],
636637) -> Sequence [ArrayImpl ]: ...
637-
638638def batched_block_until_ready (x : Sequence [ArrayImpl ]) -> None : ...
639-
640639def batched_device_put (
641640 aval : Any ,
642641 sharding : Any ,
643642 shards : Sequence [Any ],
644643 devices : list [Device ],
645644 committed : bool = True ,
646645) -> ArrayImpl : ...
647-
648646def reorder_shards (
649647 x : ArrayImpl ,
650648 dst_sharding : Any ,
651649 array_copy_semantics : ArrayCopySemantics ,
652650) -> ArrayImpl : ...
653-
654651def check_and_canonicalize_memory_kind (
655652 memory_kind : str | None , device_list : DeviceList
656653) -> str | None : ...
@@ -724,18 +721,23 @@ def dlpack_managed_tensor_to_buffer(
724721 tensor : Any , device : Device , stream : int | None
725722) -> ArrayImpl : ...
726723@overload
727- def dlpack_managed_tensor_to_buffer ( # Legacy overload
724+ def dlpack_managed_tensor_to_buffer ( # Legacy overload
728725 tensor : Any ,
729726 cpu_backend : Client | None = ...,
730727 gpu_backend : Client | None = ...,
731728) -> ArrayImpl : ...
732-
733729def cuda_array_interface_to_buffer (
734- cai : dict [str , (
735- str | int | None |
736- tuple [int , ...] | tuple [int , bool ] |
737- list [tuple [str , str ]] |
738- list [tuple [str , str , tuple [int , ...]]])
730+ cai : dict [
731+ str ,
732+ (
733+ str
734+ | int
735+ | None
736+ | tuple [int , ...]
737+ | tuple [int , bool ]
738+ | list [tuple [str , str ]]
739+ | list [tuple [str , str , tuple [int , ...]]]
740+ ),
739741 ],
740742 gpu_backend : Client | None = ...,
741743 device_id : int | None = None ,
@@ -748,11 +750,13 @@ class Frame:
748750 function_name : str
749751 function_line_start : int
750752 line_num : int
751- def __init__ (self ,
752- file_name : str ,
753- function_name : str ,
754- function_line_start : int ,
755- line_num : int ): ...
753+ def __init__ (
754+ self ,
755+ file_name : str ,
756+ function_name : str ,
757+ function_line_start : int ,
758+ line_num : int ,
759+ ): ...
756760 def __repr__ (self ) -> str : ...
757761
758762class Traceback :
@@ -790,13 +794,19 @@ class DistributedRuntimeClient:
790794 def key_value_try_get_bytes (self , key : str ) -> _Status : ...
791795 def key_value_dir_get (self , key : str ) -> _Status : ...
792796 def key_value_dir_get_bytes (self , key : str ) -> _Status : ...
793- def key_value_set (self , key : str , value : str ,
794- allow_overwrite : bool = False ) -> _Status : ...
795- def key_value_set_bytes (self , key : str , value : bytes ,
796- allow_overwrite : bool = False ) -> _Status : ...
797+ def key_value_set (
798+ self , key : str , value : str , allow_overwrite : bool = False
799+ ) -> _Status : ...
800+ def key_value_set_bytes (
801+ self , key : str , value : bytes , allow_overwrite : bool = False
802+ ) -> _Status : ...
797803 def key_value_delete (self , key : str ) -> _Status : ...
798- def wait_at_barrier (self , barrier_id : str , timeout_in_ms : int ,
799- process_ids : list [int ] | None = None ) -> _Status : ...
804+ def wait_at_barrier (
805+ self ,
806+ barrier_id : str ,
807+ timeout_in_ms : int ,
808+ process_ids : list [int ] | None = None ,
809+ ) -> _Status : ...
800810 def get_live_nodes (self , process_ids : list [int ]) -> _Status : ...
801811
802812def get_distributed_runtime_service (
@@ -970,22 +980,25 @@ def is_tsan() -> bool: ...
970980def is_sanitized () -> bool : ...
971981
972982class TransferConnection :
973-
974983 def address (self ) -> str : ...
975-
976984 def _pull_flat (self , uuid , backend , avals_flat ) -> list [Any ]: ...
977985
978986class TransferServer :
979987 def _await_pull_flat (self , uuid , args : list [ArrayImpl ]): ...
980-
981988 def connect (self , address : str ) -> TransferConnection : ...
982989
983- def start_transfer_server (client : Client , address : str = "" , transport_addresses : list [str ] = [], max_num_parallel_copies : int = 0 , transfer_size : int = 0 ) -> TransferServer : ...
984-
990+ def start_transfer_server (
991+ client : Client ,
992+ address : str = "" ,
993+ transport_addresses : list [str ] = [],
994+ max_num_parallel_copies : int = 0 ,
995+ transfer_size : int = 0 ,
996+ ) -> TransferServer : ...
985997def approx_top_k_reduction_output_size (
986998 input_size : int ,
987999 rank : int ,
9881000 top_k : int ,
9891001 recall_target : float ,
9901002 aggregate_to_topk : bool | None = ...,
991- input_size_override : int | None = ...) -> tuple [int , int ]: ...
1003+ input_size_override : int | None = ...,
1004+ ) -> tuple [int , int ]: ...
0 commit comments