Skip to content

Commit

Permalink
[MPMD-GPU] Add is_subslice_topology to the IFRT Topology.
Browse files Browse the repository at this point in the history
This is to make IFRT Topology fields consistent with `PjRtTopologyDescriptionProto` fields.

PiperOrigin-RevId: 694647266
  • Loading branch information
changhuilin authored and Google-ML-Automation committed Nov 8, 2024
1 parent 3ed5632 commit 58a435a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xla/python/ifrt/topology.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Topology : public llvm::RTTIExtends<Topology, llvm::RTTIRoot> {

virtual PjRtPlatformId platform_id() const = 0;

virtual bool is_subslice_topology() const = 0;

// Returns an unordered list of descriptions for all devices in this topology.
// TODO(phawkins): consider introducing an IFRT-specific API here instead of
// delegating to PJRT.
Expand Down
4 changes: 4 additions & 0 deletions xla/python/pjrt_ifrt/pjrt_topology.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ PjRtPlatformId PjRtTopology::platform_id() const {
return description_->platform_id();
}

bool PjRtTopology::is_subslice_topology() const {
return description_->is_subslice_topology();
}

std::vector<std::unique_ptr<const PjRtDeviceDescription>>
PjRtTopology::DeviceDescriptions() const {
return description_->DeviceDescriptions();
Expand Down
1 change: 1 addition & 0 deletions xla/python/pjrt_ifrt/pjrt_topology.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class PjRtTopology final : public llvm::RTTIExtends<PjRtTopology, Topology> {
absl::string_view platform_name() const override;
absl::string_view platform_version() const override;
PjRtPlatformId platform_id() const override;
bool is_subslice_topology() const override;

std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
const override;
Expand Down
7 changes: 7 additions & 0 deletions xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,13 @@ NB_MODULE(xla_extension, m_nb) {
.def_prop_ro(
"platform_version",
[](ifrt::Topology& topology) { return topology.platform_version(); })
.def_prop_ro(
"platform_id",
[](ifrt::Topology& topology) { return topology.platform_id(); })
.def_prop_ro("is_subslice_topology",
[](ifrt::Topology& topology) {
return topology.is_subslice_topology();
})
.def("serialize",
[](ifrt::Topology& topology) -> nb::bytes {
std::string serialized = ValueOrThrow(topology.Serialize());
Expand Down
2 changes: 2 additions & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ class Executable:
class DeviceTopology:
platform: str
platform_version: str
platform_id: int
is_subslice_topology: bool
def _make_compile_only_devices(self) -> List[Device]: ...
def serialize(self) -> bytes: ...
def __getattr__(self, name: str) -> Any: ...
Expand Down

0 comments on commit 58a435a

Please sign in to comment.