Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,18 @@ absl::Status ExecuteThunks(

TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params));

return MaybeSyncAndProfile(run_options, execution_timer.get(),
block_host_until_done ? main_stream : nullptr);
auto status =
MaybeSyncAndProfile(run_options, execution_timer.get(),
block_host_until_done ? main_stream : nullptr);

Thunk::CleanupParams cleanup_params{
executor,
&collective_params,
&collective_cliques,
};
TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params));

return status;
}

namespace {
Expand Down
59 changes: 45 additions & 14 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,49 @@ absl::Status NcclCollectivePermuteStartThunk::Initialize(
if (p2p_memcpy_enabled_) {
TF_ASSIGN_OR_RETURN(const int64_t current_id,
GetCurrentId(params.collective_params, config_));
absl::MutexLock lock(&barrier_mutex_);
if (barrier_flags_.find(current_id) == barrier_flags_.end()) {
if (!params.stream->parent()->HostMemoryRegister(
&barrier_flags_[current_id], sizeof(uint8_t))) {
LOG(ERROR) << "Registering barrier flag failed.";
}
}

TF_ASSIGN_OR_RETURN(
std::vector<DeviceBufferPair> device_buffers,
ConvertToDeviceBuffers(params.buffer_allocations, {buffer_},
config_.config.operand_element_type));
TF_RET_CHECK(device_buffers.size() == 1) << "Expected one buffer pair.";
DeviceBufferPair& buffer = device_buffers[0];
const NcclP2PConfig::SourceTargetMapEntry source_target =
NcclP2PConfig::GetSourceTarget(config_.id_to_source_target, current_id);

const std::optional<int64_t> source_id = source_target.source;
se::DeviceMemoryBase dest_addr = buffer.destination_buffer;

TF_RETURN_IF_ERROR(recv_ptr_map_.InitializeId(current_id));

if (source_id) {
TF_RETURN_IF_ERROR(
recv_ptr_map_.PutRecvPtr(current_id, dest_addr.opaque()));
}
}

return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::Cleanup(
    const CleanupParams& params) {
  // Releases the host-registered barrier flag that Initialize() created for
  // this device's id (only present when p2p memcpy is enabled).
  TF_ASSIGN_OR_RETURN(const int64_t current_id,
                      GetCurrentId(params.collective_params, config_));

  absl::MutexLock lock(&barrier_mutex_);
  auto it = barrier_flags_.find(current_id);
  if (it == barrier_flags_.end()) {
    // Nothing was registered for this id (e.g. p2p memcpy is disabled).
    // Using operator[] here would default-insert an unregistered byte and
    // then fail to unregister it, so bail out instead.
    return absl::OkStatus();
  }
  if (!params.executor->HostMemoryUnregister(&it->second)) {
    LOG(ERROR) << "Unregistering barrier flag failed.";
  }
  // Erase the entry so a subsequent Initialize() registers a fresh flag
  // instead of finding a stale, already-unregistered one and skipping
  // registration.
  barrier_flags_.erase(it);
  return absl::OkStatus();
}

absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
const ExecuteParams& params, se::Stream& stream,
CommunicatorHandle comm_handle) {
Expand All @@ -190,6 +226,14 @@ absl::Status NcclCollectivePermuteStartThunk::RunNcclCollective(
p2p_memcpy_enabled_;

TF_ASSIGN_OR_RETURN(GpuCollectives * collectives, GetGpuCollectives(params));
if (use_memcpy) {
se::DeviceMemoryBase sync_var_address =
se::DeviceMemoryBase((void*)(&barrier_flags_[current_id]));
TF_RETURN_IF_ERROR(comm_handle.comm->AllReduce(
sync_var_address, sync_var_address, PrimitiveType::U8, 1,
ReductionKind::MIN, GpuCollectives::On(stream)));
}

return ::xla::gpu::RunCollectivePermute(
collectives, source_target, device_buffers[0], stream, comm_handle.comm,
device_string, current_id, use_memcpy, recv_ptr_map_);
Expand Down Expand Up @@ -241,16 +285,7 @@ absl::Status RunCollectivePermute(
device_string, current_id,
source_id.value_or(-1), target_id.value_or(-1));

// If all peers are local, only get/send device pointer values and invoke
// memcpy.
if (use_memcpy) {
// If sending to another peer, get the pointer value of the src addr.
// Only change the pointer value when it's different from stored one.
if (source_id) {
TF_RETURN_IF_ERROR(
recv_ptr_map.PutRecvPtr(current_id, dest_addr.opaque()));
}
} else {
if (!use_memcpy) {
// GroupStart/End API is needed only if we will issue both send & recv
// calls.
const bool is_nccl_group_needed = (target_id && source_id);
Expand Down Expand Up @@ -284,10 +319,6 @@ absl::Status RunCollectivePermute(
}
if (use_memcpy && target_id) {
TF_ASSIGN_OR_RETURN(auto recv_ptr, recv_ptr_map.GetRecvPtr(*target_id));
if (recv_ptr.IsUnavailable()) {
// TODO make BlockUntilReady support AsyncValueRef directly.
BlockUntilReady(recv_ptr.GetAsyncValue());
}

VLOG(3) << "Using memcpy, received target pointer: " << recv_ptr.get()
<< " current_id " << current_id << " target_id: " << *target_id;
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/runtime/nccl_collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {

absl::Status InitializeId(int64_t current_id) {
  // Unconditionally (re)create an unconstructed AsyncValueRef for this id so
  // that every execution starts from a fresh, unfulfilled recv pointer and a
  // stale pointer published by a previous run can never be observed.
  // (The previous "only if absent" guard allowed exactly that staleness.)
  absl::MutexLock lock(&mutex_);
  recv_ptrs_[current_id] = tsl::MakeUnconstructedAsyncValueRef<void*>();
  return absl::OkStatus();
}

Expand Down Expand Up @@ -102,6 +100,7 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
int64_t partition_count, const Buffer& buffer,
bool p2p_memcpy_enabled);
absl::Status Initialize(const InitializeParams& params) override;
absl::Status Cleanup(const CleanupParams& params) override;

static const char* GetHloOpName() { return "collective-permute-start"; }

Expand All @@ -115,6 +114,8 @@ class NcclCollectivePermuteStartThunk : public NcclCollectiveThunk {
const NcclP2PConfig config_;
const Buffer buffer_;
RecvPtrMap recv_ptr_map_;
absl::Mutex barrier_mutex_;
std::unordered_map<int64_t, uint8_t> barrier_flags_;
bool p2p_memcpy_enabled_ = false;
int64_t device_count_;
};
Expand Down
118 changes: 118 additions & 0 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1801,5 +1801,123 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest,
::testing::Bool());

// End-to-end correctness test for collective-permute lowered to local
// peer-to-peer memcpy: runs the same module twice, once with
// xla_gpu_use_memcpy_local_p2p enabled and once disabled, and checks that
// the per-partition results match within a small error bound.
TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
// 4-partition module: a 3-iteration while loop whose body collective-permutes
// one operand around the ring {0,1},{1,2},{2,3},{3,0} each iteration.
absl::string_view hlo_string = R"(
HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4

None.4 {
Arg_1.6 = bf16[32,96]{1,0} parameter(1)
Arg_0.5 = bf16[32,96]{1,0} parameter(0)
collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
constant.7 = bf16[] constant(2)
broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
} // None.4

region_0.12 {
arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
constant.17 = s32[] constant(1)
add.21 = s32[] add(get-tuple-element.14, constant.17)
get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
} // region_0.12

region_1.23 {
arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
constant.28 = s32[] constant(3)
ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
} // region_1.23

shmap_body.30 {
constant.32 = s32[] constant(0)
Arg_0.31 = bf16[32,96]{1,0} parameter(0)
constant.33 = bf16[] constant(0)
broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
} // shmap_body.30

ENTRY main.49 {
Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
} // main.49
)";

// One replica spread across four partitions; skip on smaller machines.
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

// Test run: memcpy-based local P2P enabled.
HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_use_memcpy_local_p2p(true);
config.set_debug_options(opts);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string, config));
auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
std::vector<Literal*> fake_ptrs(fake_arguments.size());
for (int i = 0; i < fake_arguments.size(); ++i) {
fake_ptrs[i] = &fake_arguments[i];
}

// Linear device assignment: partition i -> device i.
DeviceAssignment assn(/*replica_count=*/kNumReplicas,
/*computation_count=*/kNumPartitions);
for (int64_t i = 0; i < kNumPartitions; ++i) {
assn(0, i) = i;
}

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
HloTestBase::ExecuteReplicated(
std::move(module), fake_ptrs, kNumPartitions, &assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
ASSERT_EQ(results.size(), kNumPartitions);

// Reference run: same module, memcpy-based local P2P disabled (NCCL path).
HloModuleConfig ref_config =
GetModuleConfigForTest(kNumReplicas, kNumPartitions);
auto ref_opts = GetDebugOptionsForTest();
ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
ref_config.set_debug_options(ref_opts);
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_string, ref_config));
auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
std::vector<Literal*> ref_fake_ptrs(fake_ref_arguments.size());
for (int i = 0; i < fake_ref_arguments.size(); ++i) {
ref_fake_ptrs[i] = &fake_ref_arguments[i];
}

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> ref_results,
HloTestBase::ExecuteReplicated(
std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
ASSERT_EQ(ref_results.size(), kNumPartitions);
ErrorSpec error_spec{1e-5, 1e-5};
// Expect the same per-partition results with and without memcpy-based
// local peer-to-peer collective-permute.
for (int i = 0; i < kNumPartitions; ++i) {
EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
}
}
} // namespace
} // namespace xla