[NVIDIA GPU] Fix mem p2p init in collective permute thunk #20086

Closed
wants to merge 3 commits
Changes from 2 commits
14 changes: 12 additions & 2 deletions xla/service/gpu/gpu_executable.cc
@@ -485,8 +485,18 @@ absl::Status ExecuteThunks(

TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params));

return MaybeSyncAndProfile(run_options, execution_timer.get(),
block_host_until_done ? main_stream : nullptr);
auto status =
MaybeSyncAndProfile(run_options, execution_timer.get(),
block_host_until_done ? main_stream : nullptr);

Thunk::CleanupParams cleanup_params{
executor,
&collective_params,
&collective_cliques,
};
TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params));

return status;
}

namespace {
58 changes: 44 additions & 14 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -161,13 +161,48 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
if (p2p_memcpy_enabled_) {
TF_ASSIGN_OR_RETURN(const int64_t current_id,
GetCurrentId(params.collective_params, config_));
absl::MutexLock lock(&barrier_mutex_);
if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
if (!params.stream->parent()->HostMemoryRegister(
&barrier_flags_[current_id], sizeof(uint8_t))) {
LOG(ERROR) << "Registering barrier flag failed.";
}
}

TF_ASSIGN_OR_RETURN(
std::vector<DeviceBufferPair> device_buffers,
ConvertToDeviceBuffers(params.buffer_allocations, {buffer_},
config_.config.operand_element_type));
TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair.";
DeviceBufferPair& buffer = device_buffers[0];
const NcclP2PConfig::SourceTargetMapEntry source_target =
NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id);

const std::optional<int64_t> source_id = source_target.source;
se::DeviceMemoryBase dest_addr = buffer.destination_buffer;

TF_RETURN_IF_ERROR(recv_ptr_map_.InitializeId(current_id));

if (source_id) {
TF_RETURN_IF_ERROR(
recv_ptr_map_.PutRecvPtr(current_id, dest_addr.opaque()));
}
}

return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::Cleanup(
const CleanupParams& params) {
TF_ASSIGN_OR_RETURN(const int64_t current_id,
GetCurrentId(params.collective_params, config_));

absl::MutexLock lock(&barrier_mutex_);
if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
LOG(ERROR) << "Unregistering barrier flag failed.";
}
}
Member

Several TensorFlow tests are failing with:
error: non-void function does not return a value in all control paths [-Werror,-Wreturn-type]
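
For readers unfamiliar with the diagnostic: clang emits it whenever a function declared with a non-void return type has a control path that can reach the closing brace without returning a value, which is what the new `Cleanup()` in the hunk above does. A minimal illustration of the pattern (hypothetical function, not from this PR):

```cpp
// Minimal illustration of -Wreturn-type: the function is declared to return
// int, but when flag == false control falls off the closing brace.
int MissingReturn(bool flag) {
  if (flag) {
    return 1;
  }
}  // error: non-void function does not return a value in all control paths
```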

Contributor Author

Added a return status to the cleanup function.
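
A minimal sketch of what that fix amounts to (the third commit is not shown in this "Changes from 2 commits" view; the sketch simply appends a status return to the `Cleanup()` from the hunk above):

```cpp
// Sketch of the fixed cleanup, assuming the follow-up commit only adds the
// missing return: unregister the host barrier flag, then report success.
absl::Status NcclCollectivePermuteStartThunk::Cleanup(
    const CleanupParams& params) {
  TF_ASSIGN_OR_RETURN(const int64_t current_id,
                      GetCurrentId(params.collective_params, config_));

  absl::MutexLock lock(&barrier_mutex_);
  if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
    LOG(ERROR) << "Unregistering barrier flag failed.";
  }
  return absl::OkStatus();  // Ensures every control path returns a value.
}
```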


absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
const ExecuteParams& params, se::Stream& stream,
CommunicatorHandle comm_handle) {
@@ -190,6 +225,14 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
p2p_memcpy_enabled_;

TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
if (use_memcpy) {
se::DeviceMemoryBase sync_var_address =
se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
sync_var_address, sync_var_address, PrimitiveType::U8, 1,
ReductionKind::MIN, GpuCollectives::On(stream)));
}

return ::xla::gpu::RunCollectivePermute(
collectives, source_target, device_buffers[0], stream, comm_handle.comm,
device_string, current_id, use_memcpy, recv_ptr_map_);
@@ -241,16 +284,7 @@ absl::Status RunCollectivePermute(
device_string, current_id,
source_id.value_or(-1), target_id.value_or(-1));

// If all peers are local, only get/send device pointer values and invoke
// memcpy.
if (use_memcpy) {
// If sending to another peer, get the pointer value of the src addr.
// Only change the pointer value when it's different from stored one.
if (source_id) {
TF_RETURN_IF_ERROR(
recv_ptr_map.PutRecvPtr(current_id, dest_addr.opaque()));
}
} else {
if (!use_memcpy) {
// GroupStart/End API is needed only if we will issue both send & recv
// calls.
const bool is_nccl_group_needed = (target_id && source_id);
@@ -284,10 +318,6 @@ absl::Status RunCollectivePermute(
}
if (use_memcpy && target_id) {
TF_ASSIGN_OR_RETURN(auto recv_ptr, recv_ptr_map.GetRecvPtr(*target_id));
if (recv_ptr.IsUnavailable()) {
// TODO make BlockUntilReady support AsyncValueRef directly.
BlockUntilReady(recv_ptr.GetAsyncValue());
}

VLOG(3) << "Using memcpy, received target pointer: " << recv_ptr.get()
<< " current_id " << current_id << " target_id: " << *target_id;
7 changes: 4 additions & 3 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.h
@@ -52,9 +52,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {

absl::Status InitializeId(int64_t current_id) {
absl::MutexLock lock(&mutex_);
if (recv_ptrs_.find(current_id) == recv_ptrs_.end()) {
recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
}
recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
return absl::OkStatus();
}

@@ -102,6 +100,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
int64_t partition_count, const Buffer& buffer,
bool p2p_memcpy_enabled);
absl::Status Initialize(const InitializeParams& params) override;
absl::Status Cleanup(const CleanupParams& params) override;

static const char* GetHloOpName() { return "collective-permute-start"; }

@@ -115,6 +114,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
const NcclP2PConfig config_;
const Buffer buffer_;
RecvPtrMap recv_ptr_map_;
absl::Mutex barrier_mutex_;
std::unordered_map<int64_t, uint8_t> barrier_flags_;
bool p2p_memcpy_enabled_ = false;
int64_t device_count_;
};
118 changes: 118 additions & 0 deletions xla/tests/collective_ops_e2e_test.cc
@@ -1801,5 +1801,123 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest,
::testing::Bool());

TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
absl::string_view hlo_string = R"(
HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4

None.4 {
Arg_1.6 = bf16[32,96]{1,0} parameter(1)
Arg_0.5 = bf16[32,96]{1,0} parameter(0)
collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
constant.7 = bf16[] constant(2)
broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
} // None.4

region_0.12 {
arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
constant.17 = s32[] constant(1)
add.21 = s32[] add(get-tuple-element.14, constant.17)
get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
} // region_0.12

region_1.23 {
arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
constant.28 = s32[] constant(3)
ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
} // region_1.23

shmap_body.30 {
constant.32 = s32[] constant(0)
Arg_0.31 = bf16[32,96]{1,0} parameter(0)
constant.33 = bf16[] constant(0)
broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
} // shmap_body.30

ENTRY main.49 {
Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
} // main.49
)";

const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_use_memcpy_local_p2p(true);
config.set_debug_options(opts);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string, config));
auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_ptrs(fake_arguments.size());
for (int i = 0; i < fake_arguments.size(); ++i) {
fake_ptrs[i] = &fake_arguments[i];
}

DeviceAssignment assn(/*replica_count=*/kNumReplicas,
/*computation_count=*/kNumPartitions);
for (int64_t i = 0; i < kNumPartitions; ++i) {
assn(0, i) = i;
}

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
HloTestBase::ExecuteReplicated(
std::move(module), fake_ptrs, kNumPartitions, &assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
ASSERT_EQ(results.size(), kNumPartitions);

HloModuleConfig ref_config =
GetModuleConfigForTest(kNumReplicas, kNumPartitions);
auto ref_opts = GetDebugOptionsForTest();
ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
ref_config.set_debug_options(ref_opts);
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_string, ref_config));
auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
std::vector<Literal*> ref_fake_ptrs(fake_ref_arguments.size());
for (int i = 0; i < fake_ref_arguments.size(); ++i) {
ref_fake_ptrs[i] = &fake_ref_arguments[i];
}

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> ref_results,
HloTestBase::ExecuteReplicated(
std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
ASSERT_EQ(ref_results.size(), kNumPartitions);
ErrorSpec error_spec{1e-5, 1e-5};
// Expect the same results with and without memcpy-based local P2P.
for (int i = 0; i < kNumPartitions; ++i) {
EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
}
}
} // namespace
} // namespace xla