diff --git a/xla/python/ifrt/topology.h b/xla/python/ifrt/topology.h index 8d1104aca01f3..d6b60a0dfaf0b 100644 --- a/xla/python/ifrt/topology.h +++ b/xla/python/ifrt/topology.h @@ -44,6 +44,8 @@ class Topology : public llvm::RTTIExtends { virtual PjRtPlatformId platform_id() const = 0; + virtual bool is_subslice_topology() const = 0; + // Returns an unordered list of descriptions for all devices in this topology. // TODO(phawkins): consider introducing an IFRT-specific API here instead of // delegating to PJRT. diff --git a/xla/python/pjrt_ifrt/pjrt_topology.cc b/xla/python/pjrt_ifrt/pjrt_topology.cc index d27097145828a..a494afcd2fc48 100644 --- a/xla/python/pjrt_ifrt/pjrt_topology.cc +++ b/xla/python/pjrt_ifrt/pjrt_topology.cc @@ -51,6 +51,10 @@ PjRtPlatformId PjRtTopology::platform_id() const { return description_->platform_id(); } +bool PjRtTopology::is_subslice_topology() const { + return description_->is_subslice_topology(); +} + std::vector> PjRtTopology::DeviceDescriptions() const { return description_->DeviceDescriptions(); diff --git a/xla/python/pjrt_ifrt/pjrt_topology.h b/xla/python/pjrt_ifrt/pjrt_topology.h index 82fc59c8005c0..856149eb05615 100644 --- a/xla/python/pjrt_ifrt/pjrt_topology.h +++ b/xla/python/pjrt_ifrt/pjrt_topology.h @@ -46,6 +46,7 @@ class PjRtTopology final : public llvm::RTTIExtends { absl::string_view platform_name() const override; absl::string_view platform_version() const override; PjRtPlatformId platform_id() const override; + bool is_subslice_topology() const override; std::vector> DeviceDescriptions() const override; diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 7945a0dbd9cf3..26c80deb9eec3 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -858,6 +858,13 @@ NB_MODULE(xla_extension, m_nb) { .def_prop_ro( "platform_version", [](ifrt::Topology& topology) { return topology.platform_version(); }) + .def_prop_ro( + "platform_id", + [](ifrt::Topology& topology) { return topology.platform_id(); }) + .def_prop_ro("is_subslice_topology", + [](ifrt::Topology& topology) { + return topology.is_subslice_topology(); + }) .def("serialize", [](ifrt::Topology& topology) -> nb::bytes { std::string serialized = ValueOrThrow(topology.Serialize()); diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi index cd6311ad06fa8..1dbe17444d768 100644 --- a/xla/python/xla_extension/__init__.pyi +++ b/xla/python/xla_extension/__init__.pyi @@ -738,6 +738,8 @@ class Executable: class DeviceTopology: platform: str platform_version: str + platform_id: int + is_subslice_topology: bool def _make_compile_only_devices(self) -> List[Device]: ... def serialize(self) -> bytes: ... def __getattr__(self, name: str) -> Any: ...