Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IFRT proxy: Add profiler spans to all entrypoints at the client. #20325

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
IFRT proxy: Add profiler spans to all entrypoints at the client.
PiperOrigin-RevId: 704444588
  • Loading branch information
Google-ML-Automation committed Dec 9, 2024
commit c179e1c792236145c819193bbde6b84d78e7b0c7
4 changes: 4 additions & 0 deletions xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -255,6 +256,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -331,6 +333,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@tsl//tsl/platform:status_to_from_proto",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -405,6 +408,7 @@ cc_library(
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:status_to_from_proto",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:traceme",
],
)

Expand Down
33 changes: 32 additions & 1 deletion xla/python/ifrt_proxy/client/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -113,6 +114,7 @@ Array::MakeArrayFromHostBuffer(
return absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9");
}
tsl::profiler::TraceMe traceme("IfrtProxySerializeStringHostBuffer");
TF_ASSIGN_OR_RETURN(
std::shared_ptr<std::string> owned_data,
SerializeStringHostBuffer(absl::MakeConstSpan(
Expand All @@ -127,6 +129,12 @@ Array::MakeArrayFromHostBuffer(
}
};
}
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[s = mem_region.size(), semantics]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointMakeArrayFromHostBuffer",
{{"size", s}, {"semantics", static_cast<int>(semantics)}});
});

const uint64_t host_buffer_handle = rpc_helper->NextHandle();

Expand Down Expand Up @@ -226,6 +234,9 @@ void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) {
}

Future<> Array::GetReadyFuture() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointArrayGetReadyFuture");

auto req = std::make_unique<CheckValueReadyRequest>();
req->add_value_handles(handle_.handle);

Expand Down Expand Up @@ -260,6 +271,8 @@ Future<> Array::Delete() {
}

bool Array::IsDeleted() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointIsDeleted");
if (GetGlobalClientFlags()->array_is_deleted_hack) {
return false;
}
Expand Down Expand Up @@ -287,6 +300,14 @@ Array::AssembleArrayFromSingleDeviceArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[n_arrays = arrays.size(), single_device_shard_semantics]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointAssembleArrayFromSingleDeviceArrays",
{{"n_arrays", n_arrays},
{"sds_semantics",
static_cast<int>(single_device_shard_semantics)}});
});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper->version().protocol_version() < 8) {
Expand Down Expand Up @@ -338,6 +359,10 @@ Array::RemapArrays(xla::ifrt::Client* client,
std::shared_ptr<RpcHelper> rpc_helper, const RemapPlan& plan,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointRemapArrays",
{{"n_arrays", n_arrays}});
});
auto req = std::make_unique<RemapArraysRequest>();
TF_RET_CHECK(!arrays.empty());
TF_ASSIGN_OR_RETURN(*req->mutable_plan(), plan.ToProto());
Expand Down Expand Up @@ -393,6 +418,8 @@ absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Array::DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointDisassembleIntoSingleDeviceArrays");
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper_->version().protocol_version() < 8) {
Expand Down Expand Up @@ -446,6 +473,8 @@ Array::DisassembleIntoSingleDeviceArrays(

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Array::FullyReplicatedShard(
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointFullyReplicatedShard");
auto req = std::make_unique<FullyReplicatedShardRequest>();
req->set_array_handle(handle_.handle);
req->set_copy_semantics(ToArrayCopySemanticsProto(semantics));
Expand Down Expand Up @@ -481,6 +510,8 @@ absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Array::FullyReplicatedShard(
Future<> Array::CopyToStringHostBuffer(
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointCopyToStringHostBuffer");
if (rpc_helper_->version().protocol_version() < 9) {
return Future<>(absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9"));
Expand Down Expand Up @@ -540,7 +571,7 @@ Future<> Array::CopyToHostBuffer(
if (dtype_.kind() == DType::kString) {
return CopyToStringHostBuffer(data, byte_strides, semantics);
}

tsl::profiler::TraceMe traceme("IfrtProxyEntrypointCopyToHostBuffer");
const auto mem_region = ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype_, shape_, byte_strides);
if (!mem_region.ok()) {
Expand Down
12 changes: 12 additions & 0 deletions xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand All @@ -66,6 +67,7 @@ char Client::ID = 0;

absl::StatusOr<std::unique_ptr<Client>> Client::Create(
std::shared_ptr<RpcHelper> rpc_helper, InitResponse init_response) {
tsl::profiler::TraceMe traceme("IfrtProxyEntrypointClientCreate");
absl::flat_hash_set<int> addressable_device_ids(
init_response.addressable_device_ids().begin(),
init_response.addressable_device_ids().end());
Expand Down Expand Up @@ -254,6 +256,10 @@ Client::CopyArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
std::optional<tsl::RCReference<xla::ifrt::DeviceList>> devices,
std::optional<MemoryKind> memory_kind, ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointCopyArrays",
{{"n_arrays", n_arrays}});
});
if (arrays.empty()) {
return std::vector<tsl::RCReference<xla::ifrt::Array>>();
}
Expand Down Expand Up @@ -334,6 +340,10 @@ Client::RemapArrays(const RemapPlan& plan,

xla::ifrt::Future<> Client::GetReadyFuture(
absl::Span<const tsl::RCReference<xla::ifrt::Value>> values) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_values = values.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointGetReadyFuture",
{{"n_values", n_values}});
});
absl::InlinedVector<Future<>, 1> futures;

auto req = std::make_unique<CheckValueReadyRequest>();
Expand Down Expand Up @@ -364,6 +374,8 @@ absl::Span<xla::ifrt::Device* const> Client::GetAllDevices() const {

absl::StatusOr<DeviceAssignment> Client::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointGetDefaultDeviceAssignment");
auto req = std::make_unique<GetDefaultDeviceAssignmentRequest>();
req->set_num_replicas(num_replicas);
req->set_num_partitions(num_partitions);
Expand Down
13 changes: 11 additions & 2 deletions xla/python/ifrt_proxy/client/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/status_to_from_proto.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand All @@ -57,8 +58,16 @@ absl::StatusOr<std::unique_ptr<xla::ifrt::LoadedExecutable>> Compiler::Compile(
std::unique_ptr<Program> program,
std::unique_ptr<xla::ifrt::CompileOptions> options) {
auto request = std::make_unique<CompileRequest>();
TF_ASSIGN_OR_RETURN(*request->mutable_program(),
Serialize(*program, /*options=*/nullptr));
{
tsl::profiler::TraceMe traceme("IfrtProxyProgramSerialize");
TF_ASSIGN_OR_RETURN(*request->mutable_program(),
Serialize(*program, /*options=*/nullptr));
}
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[prog_size = request->program().data().size()]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointCompilerCompile", {{"prog_size", prog_size}});
});

// Extract host callbacks from the XLA compile options. `XlaCompileOptions`'s
// SerDes fails when it contains host callbacks, so the following
Expand Down
23 changes: 23 additions & 0 deletions xla/python/ifrt_proxy/client/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
#include "tsl/platform/status_to_from_proto.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -272,6 +273,9 @@ LoadedExecutable::LoadedExecutable(
}
}

tsl::profiler::TraceMe traceme_ifrt_entrypoint(

"IfrtProxyEntrypointLoadedExecutableCreate");
// Asynchronously fetch shardings. Since users of `LoadedExecutable` typically
// require sharding information to invoke the executable, it is beneficial to
// eagerly schedule this fetch since, in some implementations, it may take a
Expand Down Expand Up @@ -362,6 +366,9 @@ LoadedExecutable::LoadedExecutable(
}

LoadedExecutable::~LoadedExecutable() {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableDestruct");

auto req = std::make_unique<LoadedExecutableDestructRequest>();
req->set_loaded_executable_handle(handle_);

Expand Down Expand Up @@ -406,6 +413,8 @@ absl::StatusOr<CompiledMemoryStats> LoadedExecutable::GetCompiledMemoryStats()

std::optional<std::vector<OpSharding>> LoadedExecutable::GetParameterShardings()
const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetParameterShardings");
auto info = metadata_future_.Await();
if (!info.ok()) {
return std::nullopt;
Expand All @@ -415,6 +424,8 @@ std::optional<std::vector<OpSharding>> LoadedExecutable::GetParameterShardings()

std::optional<std::vector<OpSharding>> LoadedExecutable::GetOutputShardings()
const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputShardings");
auto info = metadata_future_.Await();
if (!info.ok()) {
return std::nullopt;
Expand All @@ -424,6 +435,8 @@ std::optional<std::vector<OpSharding>> LoadedExecutable::GetOutputShardings()

absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
LoadedExecutable::GetParameterLayouts() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetParameterLayouts");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
TF_RETURN_IF_ERROR(info->parameter_layouts.status());

Expand All @@ -437,6 +450,8 @@ LoadedExecutable::GetParameterLayouts() const {

absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
LoadedExecutable::GetOutputLayouts() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputLayouts");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
TF_RETURN_IF_ERROR(info->output_layouts.status());

Expand All @@ -450,6 +465,8 @@ LoadedExecutable::GetOutputLayouts() const {

absl::StatusOr<std::vector<std::vector<absl::string_view>>>
LoadedExecutable::GetOutputMemoryKinds() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputMemoryKinds");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
return info->output_memory_kinds;
}
Expand All @@ -471,6 +488,8 @@ LoadedExecutable::Execute(
absl::Span<tsl::RCReference<xla::ifrt::Array>> args,
const ExecuteOptions& options,
std::optional<tsl::RCReference<xla::ifrt::DeviceList>> devices) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableExecute");
auto req = std::make_unique<LoadedExecutableExecuteRequest>();
req->set_loaded_executable_handle(handle_);
for (const auto& arg : args) {
Expand Down Expand Up @@ -557,6 +576,8 @@ LoadedExecutable::Execute(
}

Future<> LoadedExecutable::Delete() {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableDelete");
auto req = std::make_unique<LoadedExecutableDeleteRequest>();
req->set_loaded_executable_handle(handle_);

Expand All @@ -580,6 +601,8 @@ Future<> LoadedExecutable::Delete() {
}

bool LoadedExecutable::IsDeleted() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableIsDeleted");
auto req = std::make_unique<LoadedExecutableIsDeletedRequest>();
req->set_loaded_executable_handle(handle_);

Expand Down
Loading