PR #20086: [NVIDIA GPU] Fix mem p2p init in collective permute thunk
Imported from GitHub PR #20086

Move pointer initialization to the thunk initialization stage instead of runtime to get rid of the runtime blocking wait.
Add a device sync point, implemented as an NCCL allreduce, before doing the memcpy to make sure all GPUs have arrived at the same stage. Otherwise, data corruption is possible when the receiving rank has not yet reached the memcpy.
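
In NCCL, a collective such as allreduce can only complete once every rank in the communicator has enqueued it, so a one-element allreduce on a scratch flag works as a device-side rendezvous. The sketch below mirrors that barrier-then-copy pattern against the plain NCCL/CUDA C APIs rather than XLA's thunk interfaces; `SyncThenCopy`, `sync_flag`, `peer_dst_ptr`, `src_ptr`, `num_bytes`, `comm`, and `stream` are illustrative names, not code from this change.

```cpp
// Sketch only: a host-registered one-byte flag is all-reduced across ranks
// before the device-to-device copy, so no rank starts the memcpy before its
// peer has reached the same step.
#include <cstdint>

#include <cuda_runtime.h>
#include <nccl.h>

ncclResult_t SyncThenCopy(uint8_t* sync_flag, void* peer_dst_ptr,
                          const void* src_ptr, size_t num_bytes,
                          ncclComm_t comm, cudaStream_t stream) {
  // In the thunk this registration happens once, at Initialize().
  cudaHostRegister(sync_flag, sizeof(uint8_t), cudaHostRegisterDefault);

  // A one-element all-reduce acts as a device-side barrier: it can only
  // complete once every rank in the communicator has enqueued it.
  ncclResult_t status = ncclAllReduce(sync_flag, sync_flag, /*count=*/1,
                                      ncclUint8, ncclMin, comm, stream);
  if (status != ncclSuccess) return status;

  // Safe now: the receiving rank has published its buffer and arrived here.
  cudaMemcpyAsync(peer_dst_ptr, src_ptr, num_bytes, cudaMemcpyDeviceToDevice,
                  stream);
  return ncclSuccess;
}
```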
Copybara import of the project:

--
ba4ad04 by TJ Xu <[email protected]>:

Moved pointer init to the thunk init stage and added a sync point before
doing memcpy to ensure data consistency across ranks

--
050bc59 by TJ Xu <[email protected]>:

Added an e2e test for memcpy p2p in a loop

Merging this change closes #20086

FUTURE_COPYBARA_INTEGRATE_REVIEW=#20086 from Tixxx:tixxx/memcpy_p2p_fix 050bc59
PiperOrigin-RevId: 705647424
Tixxx authored and Google-ML-Automation committed Dec 16, 2024
1 parent 6833ecf commit 0693378
Showing 4 changed files with 178 additions and 19 deletions.
14 changes: 12 additions & 2 deletions xla/service/gpu/gpu_executable.cc
@@ -541,8 +541,18 @@ absl::Status ExecuteThunks(

  TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params));

  return MaybeSyncAndProfile(run_options, execution_timer.get(),
                             block_host_until_done ? main_stream : nullptr);
  auto status =
      MaybeSyncAndProfile(run_options, execution_timer.get(),
                          block_host_until_done ? main_stream : nullptr);

  Thunk::CleanupParams cleanup_params{
      executor,
      &collective_params,
      &collective_cliques,
  };
  TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params));

  return status;
}

namespace {
Expand Down
58 changes: 44 additions & 14 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
@@ -161,13 +161,48 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
  if (p2p_memcpy_enabled_) {
    TF_ASSIGN_OR_RETURN(const int64_t current_id,
                        GetCurrentId(params.collective_params, config_));
    absl::MutexLock lock(&barrier_mutex_);
    if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
      if (!params.stream->parent()->HostMemoryRegister(
              &barrier_flags_[current_id], sizeof(uint8_t))) {
        LOG(ERROR) << "Registering barrier flag failed.";
      }
    }

    TF_ASSIGN_OR_RETURN(
        std::vector<DeviceBufferPair> device_buffers,
        ConvertToDeviceBuffers(params.buffer_allocations, {buffer_},
                               config_.config.operand_element_type));
    TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair.";
    DeviceBufferPair& buffer = device_buffers[0];
    const NcclP2PConfig::SourceTargetMapEntry source_target =
        NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id);

    const std::optional<int64_t> source_id = source_target.source;
    se::DeviceMemoryBase dest_addr = buffer.destination_buffer;

    TF_RETURN_IF_ERROR(recv_ptr_map_.InitializeId(current_id));

    if (source_id) {
      TF_RETURN_IF_ERROR(
          recv_ptr_map_.PutRecvPtr(current_id, dest_addr.opaque()));
    }
  }

  return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::Cleanup(
    const CleanupParams& params) {
  TF_ASSIGN_OR_RETURN(const int64_t current_id,
                      GetCurrentId(params.collective_params, config_));

  absl::MutexLock lock(&barrier_mutex_);
  if (!params.executor->HostMemoryUnregister(&barrier_flags_[current_id])) {
    LOG(ERROR) << "Unregistering barrier flag failed.";
  }
  return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
    const ExecuteParams& params, se::Stream& stream,
    CommunicatorHandle comm_handle) {
@@ -190,6 +225,14 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
                    p2p_memcpy_enabled_;

  TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
  if (use_memcpy) {
    se::DeviceMemoryBase sync_var_address =
        se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
    TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
        sync_var_address, sync_var_address, PrimitiveType::U8, 1,
        ReductionKind::MIN, GpuCollectives::On(stream)));
  }

  return ::xla::gpu::RunCollectivePermute(
      collectives, source_target, device_buffers[0], stream, comm_handle.comm,
      device_string, current_id, use_memcpy, recv_ptr_map_);
@@ -241,16 +284,7 @@ absl::Status RunCollectivePermute(
                          device_string, current_id,
                          source_id.value_or(-1), target_id.value_or(-1));

  // If all peers are local, only get/send device pointer values and invoke
  // memcpy.
  if (use_memcpy) {
    // If sending to another peer, get the pointer value of the src addr.
    // Only change the pointer value when it's different from stored one.
    if (source_id) {
      TF_RETURN_IF_ERROR(
          recv_ptr_map.PutRecvPtr(current_id, dest_addr.opaque()));
    }
  } else {
  if (!use_memcpy) {
    // GroupStart/End API is needed only if we will issue both send & recv
    // calls.
    const bool is_nccl_group_needed = (target_id && source_id);
@@ -284,10 +318,6 @@ }
  }
  if (use_memcpy && target_id) {
    TF_ASSIGN_OR_RETURN(auto recv_ptr, recv_ptr_map.GetRecvPtr(*target_id));
    if (recv_ptr.IsUnavailable()) {
      // TODO make BlockUntilReady support AsyncValueRef directly.
      BlockUntilReady(recv_ptr.GetAsyncValue());
    }

    VLOG(3) << "Using memcpy, received target pointer: " << recv_ptr.get()
            << " current_id " << current_id << " target_id: " << *target_id;
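
The blocking wait removed above is no longer needed because every rank now publishes its destination pointer during thunk initialization, before any rank can reach the memcpy. A minimal sketch of that flow with tsl's AsyncValueRef, collapsed into one function for illustration; in the thunk, PutRecvPtr runs at initialization under the rank's own id and GetRecvPtr runs later under the target's id, and the include path and function name below are assumptions, not XLA's actual layout.

```cpp
#include "xla/tsl/concurrency/async_value_ref.h"  // path may differ by version

void RecvPtrFlowSketch(void* dest_addr) {
  // Initialize(): create an empty (unconstructed) slot for the peer pointer...
  auto recv_ptr = tsl::MakeUnconstructedAsyncValueRef<void*>();
  // ...and immediately publish the local destination address into it
  // (PutRecvPtr in the thunk). After this the value is concrete.
  recv_ptr.emplace(dest_addr);

  // RunCollectivePermute(): the lookup already holds a value, so no
  // BlockUntilReady() is needed on the hot path.
  if (!recv_ptr.IsUnavailable()) {
    void* target_ptr = recv_ptr.get();
    (void)target_ptr;  // would be handed to the device-to-device memcpy
  }
}
```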
7 changes: 4 additions & 3 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.h
@@ -52,9 +52,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {

    absl::Status InitializeId(int64_t current_id) {
      absl::MutexLock lock(&mutex_);
      if (recv_ptrs_.find(current_id) == recv_ptrs_.end()) {
        recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
      }
      recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
      return absl::OkStatus();
    }

@@ -102,6 +100,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
                                  int64_t partition_count, const Buffer& buffer,
                                  bool p2p_memcpy_enabled);
  absl::Status Initialize(const InitializeParams& params) override;
  absl::Status Cleanup(const CleanupParams& params) override;

  static const char* GetHloOpName() { return "collective-permute-start"; }

@@ -115,6 +114,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
  const NcclP2PConfig config_;
  const Buffer buffer_;
  RecvPtrMap recv_ptr_map_;
  absl::Mutex barrier_mutex_;
  std::unordered_map<int64_t, uint8_t> barrier_flags_;
  bool p2p_memcpy_enabled_ = false;
  int64_t device_count_;
};
118 changes: 118 additions & 0 deletions xla/tests/collective_ops_e2e_test.cc
@@ -1801,5 +1801,123 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest,
                         ::testing::Bool());

TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
  absl::string_view hlo_string = R"(
HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4
None.4 {
Arg_1.6 = bf16[32,96]{1,0} parameter(1)
Arg_0.5 = bf16[32,96]{1,0} parameter(0)
collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
constant.7 = bf16[] constant(2)
broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
} // None.4
region_0.12 {
arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
constant.17 = s32[] constant(1)
add.21 = s32[] add(get-tuple-element.14, constant.17)
get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
} // region_0.12
region_1.23 {
arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
constant.28 = s32[] constant(3)
ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
} // region_1.23
shmap_body.30 {
constant.32 = s32[] constant(0)
Arg_0.31 = bf16[32,96]{1,0} parameter(0)
constant.33 = bf16[] constant(0)
broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
} // shmap_body.30
ENTRY main.49 {
Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
} // main.49
)";

  const int64_t kNumReplicas = 1;
  const int64_t kNumPartitions = 4;
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

  HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto opts = GetDebugOptionsForTest();
  opts.set_xla_gpu_use_memcpy_local_p2p(true);
  config.set_debug_options(opts);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string, config));
  auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_ptrs(fake_arguments.size());
  for (int i = 0; i < fake_arguments.size(); ++i) {
    fake_ptrs[i] = &fake_arguments[i];
  }

  DeviceAssignment assn(/*replica_count=*/kNumReplicas,
                        /*computation_count=*/kNumPartitions);
  for (int64_t i = 0; i < kNumPartitions; ++i) {
    assn(0, i) = i;
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      HloTestBase::ExecuteReplicated(
          std::move(module), fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumPartitions);

  HloModuleConfig ref_config =
      GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto ref_opts = GetDebugOptionsForTest();
  ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
  ref_config.set_debug_options(ref_opts);
  TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
                          ParseAndReturnVerifiedModule(hlo_string, ref_config));
  auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
  std::vector<Literal*> ref_fake_ptrs(fake_ref_arguments.size());
  for (int i = 0; i < fake_ref_arguments.size(); ++i) {
    ref_fake_ptrs[i] = &fake_ref_arguments[i];
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> ref_results,
      HloTestBase::ExecuteReplicated(
          std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(ref_results.size(), kNumPartitions);
  ErrorSpec error_spec{1e-5, 1e-5};
  // Expect the same results with and without memcpy-based local P2P.
  for (int i = 0; i < kNumPartitions; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
  }
}
} // namespace
} // namespace xla
