Skip to content

Commit

Permalink
Remove support for PjRtNamedValue without bool type.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613721254
  • Loading branch information
Jieying Luo authored and copybara-github committed Mar 7, 2024
1 parent 9a0e202 commit 1165d28
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 34 deletions.
1 change: 1 addition & 0 deletions xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cc_library(
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/profiler/lib:connected_traceme",
"@tsl//tsl/profiler/lib:context_types_hdrs",
],
Expand Down
23 changes: 9 additions & 14 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,7 @@ TEST(PjrtCApiGpuKVStoreTest, CreateClientWithKVCallback) {
{"num_nodes", static_cast<int64_t>(num_nodes)},
{"node_id", static_cast<int64_t>(i)}};
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(
options, /*api_minor_version=*/30));
::pjrt::ConvertToPjRtNamedValueList(options));
TF_ASSERT_OK_AND_ASSIGN(
PJRT_Client_Create_Args create_arg,
BuildCreateArg(kv_callback_data.get(), c_options));
Expand Down Expand Up @@ -248,9 +247,8 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) {
if (allocator_option == "cuda_async") {
options["preallocate"] = true;
}
TF_ASSERT_OK_AND_ASSIGN(
std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30));
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options));
PJRT_Client_Create_Args create_arg;
create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
create_arg.extension_start = nullptr;
Expand All @@ -277,9 +275,8 @@ TEST(PjrtCApiGpuAllocatorTest, InvalidAllocatorOptionsParsing) {
{"memory_fraction", 0.5f},
{"preallocate", true},
};
TF_ASSERT_OK_AND_ASSIGN(
std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30));
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options));
PJRT_Client_Create_Args create_arg;
create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
create_arg.extension_start = nullptr;
Expand Down Expand Up @@ -312,9 +309,8 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) {
{"allocator", static_cast<std::string>("default")},
{"visible_devices", xla::PjRtValueType(std::vector<int64_t>{0, 1})},
};
TF_ASSERT_OK_AND_ASSIGN(
std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30));
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options));
PJRT_Client_Create_Args create_arg;
create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
create_arg.extension_start = nullptr;
Expand Down Expand Up @@ -354,9 +350,8 @@ TEST(PjrtCApiPlatformNameTest, UnavailablePlatformName) {
{"allocator", static_cast<std::string>("default")},
{"visible_devices", xla::PjRtValueType(std::vector<int64_t>{0, 1})},
};
TF_ASSERT_OK_AND_ASSIGN(
std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options, /*api_minor_version=*/30));
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_options,
::pjrt::ConvertToPjRtNamedValueList(options));
PJRT_Client_Create_Args create_arg;
create_arg.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
create_arg.extension_start = nullptr;
Expand Down
12 changes: 5 additions & 7 deletions xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ limitations under the License.
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/connected_traceme.h"
#include "tsl/profiler/lib/context_types.h"

Expand Down Expand Up @@ -425,8 +426,7 @@ xla::PjRtFuture<absl::Status> ConvertCEventToCppFuture(PJRT_Event* c_event,
}

static absl::StatusOr<PJRT_NamedValue> ConvertToPjRtNamedValue(
const std::string& name, const xla::PjRtValueType& value,
int api_minor_version) {
const std::string& name, const xla::PjRtValueType& value) {
PJRT_NamedValue c_value;
c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE;
c_value.extension_start = nullptr;
Expand Down Expand Up @@ -465,14 +465,12 @@ static absl::StatusOr<PJRT_NamedValue> ConvertToPjRtNamedValue(
}

absl::StatusOr<std::vector<PJRT_NamedValue>> ConvertToPjRtNamedValueList(
const absl::flat_hash_map<std::string, xla::PjRtValueType>& cpp_value_map,
int api_minor_version) {
const absl::flat_hash_map<std::string, xla::PjRtValueType>& cpp_value_map) {
std::vector<PJRT_NamedValue> c_value_list;
c_value_list.reserve(cpp_value_map.size());
for (const auto& [name, value] : cpp_value_map) {
TF_ASSIGN_OR_RETURN(
PJRT_NamedValue c_value,
ConvertToPjRtNamedValue(name, value, api_minor_version));
TF_ASSIGN_OR_RETURN(PJRT_NamedValue c_value,
ConvertToPjRtNamedValue(name, value));
c_value_list.push_back(c_value);
}
return c_value_list;
Expand Down
3 changes: 1 addition & 2 deletions xla/pjrt/c/pjrt_c_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ xla::PjRtFuture<xla::Status> ConvertCEventToCppFuture(PJRT_Event* c_event,
// `cpp_value_map`, so `cpp_value_map` must outlive the returned list. It will
// raise errors for unsupported PjRtValueType.
absl::StatusOr<std::vector<PJRT_NamedValue>> ConvertToPjRtNamedValueList(
const absl::flat_hash_map<std::string, xla::PjRtValueType>& cpp_value_map,
int api_minor_version);
const absl::flat_hash_map<std::string, xla::PjRtValueType>& cpp_value_map);

absl::flat_hash_map<std::string, xla::PjRtValueType>
ConvertFromPjRtNamedValueList(const PJRT_NamedValue* c_value_list,
Expand Down
5 changes: 2 additions & 3 deletions xla/pjrt/c/pjrt_c_api_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ TEST(PjRtCApiHelperTest, ConvertValidPjRtValueType) {
{"int64_list", int64_list},
{"float", static_cast<float>(1.0)}};

TF_ASSERT_OK_AND_ASSIGN(
std::vector<PJRT_NamedValue> c_map,
ConvertToPjRtNamedValueList(original_cpp_map, /*api_minor_version=*/30));
TF_ASSERT_OK_AND_ASSIGN(std::vector<PJRT_NamedValue> c_map,
ConvertToPjRtNamedValueList(original_cpp_map));
auto converted_back_cpp_map =
ConvertFromPjRtNamedValueList(c_map.data(), c_map.size());

Expand Down
12 changes: 4 additions & 8 deletions xla/pjrt/pjrt_c_api_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2204,10 +2204,8 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCApiClient(
PJRT_Client_Create_Args init_args;
init_args.struct_size = PJRT_Client_Create_Args_STRUCT_SIZE;
init_args.extension_start = nullptr;
TF_ASSIGN_OR_RETURN(
std::vector<PJRT_NamedValue> c_options,
pjrt::ConvertToPjRtNamedValueList(create_options,
c_api->pjrt_api_version.minor_version));
TF_ASSIGN_OR_RETURN(std::vector<PJRT_NamedValue> c_options,
pjrt::ConvertToPjRtNamedValueList(create_options));
init_args.create_options = c_options.data();
init_args.num_options = c_options.size();

Expand Down Expand Up @@ -2243,10 +2241,8 @@ absl::StatusOr<std::unique_ptr<PjRtTopologyDescription>> GetCApiTopology(
PJRT_TopologyDescription_Create_Args init_args;
init_args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE;
init_args.extension_start = nullptr;
TF_ASSIGN_OR_RETURN(
std::vector<PJRT_NamedValue> c_options,
pjrt::ConvertToPjRtNamedValueList(create_options,
c_api->pjrt_api_version.minor_version));
TF_ASSIGN_OR_RETURN(std::vector<PJRT_NamedValue> c_options,
pjrt::ConvertToPjRtNamedValueList(create_options));
init_args.create_options = c_options.data();
init_args.num_options = c_options.size();
init_args.topology_name = topology_name.data();
Expand Down

0 comments on commit 1165d28

Please sign in to comment.