Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA GPU] Fix mem p2p init in collective permute thunk #20086

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added e2e test for memcpy p2p in a loop
  • Loading branch information
Tixxx committed Dec 12, 2024
commit 050bc59c02732da728fe43bd6c4c12702d070c2c
118 changes: 118 additions & 0 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1801,5 +1801,123 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_8GPUs) {
INSTANTIATE_TEST_SUITE_P(RaggedAllToAllTest, RaggedAllToAllTest,
::testing::Bool());

TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
  // Verifies that running a collective-permute inside a while loop with the
  // memcpy-based local P2P path enabled (xla_gpu_use_memcpy_local_p2p=true)
  // produces the same results as the default path with the flag disabled.
  // Requires kNumReplicas * kNumPartitions devices; skipped otherwise.
  absl::string_view hlo_string = R"(
HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4

None.4 {
Arg_1.6 = bf16[32,96]{1,0} parameter(1)
Arg_0.5 = bf16[32,96]{1,0} parameter(0)
collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
constant.7 = bf16[] constant(2)
broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
} // None.4

region_0.12 {
arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
constant.17 = s32[] constant(1)
add.21 = s32[] add(get-tuple-element.14, constant.17)
get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
} // region_0.12

region_1.23 {
arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
constant.28 = s32[] constant(3)
ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
} // region_1.23

shmap_body.30 {
constant.32 = s32[] constant(0)
Arg_0.31 = bf16[32,96]{1,0} parameter(0)
constant.33 = bf16[] constant(0)
broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
} // shmap_body.30

ENTRY main.49 {
Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
} // main.49
)";

  const int64_t kNumReplicas = 1;
  const int64_t kNumPartitions = 4;
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

  // Builds the vector of non-owning Literal pointers that ExecuteReplicated
  // expects. Avoids the signed/unsigned index loop duplicated in the original.
  auto to_ptrs = [](std::vector<Literal>& literals) {
    std::vector<Literal*> ptrs;
    ptrs.reserve(literals.size());
    for (Literal& literal : literals) {
      ptrs.push_back(&literal);
    }
    return ptrs;
  };

  // Test arm: memcpy-based local P2P enabled.
  HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto opts = GetDebugOptionsForTest();
  opts.set_xla_gpu_use_memcpy_local_p2p(true);
  config.set_debug_options(opts);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string, config));
  auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_ptrs = to_ptrs(fake_arguments);

  // Single-replica assignment: partition i runs on device i.
  DeviceAssignment assn(/*replica_count=*/kNumReplicas,
                        /*computation_count=*/kNumPartitions);
  for (int64_t i = 0; i < kNumPartitions; ++i) {
    assn(0, i) = i;
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      HloTestBase::ExecuteReplicated(
          std::move(module), fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumPartitions);

  // Reference arm: same module with the memcpy local-P2P path disabled.
  HloModuleConfig ref_config =
      GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto ref_opts = GetDebugOptionsForTest();
  ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
  ref_config.set_debug_options(ref_opts);
  TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
                          ParseAndReturnVerifiedModule(hlo_string, ref_config));
  auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
  std::vector<Literal*> ref_fake_ptrs = to_ptrs(fake_ref_arguments);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> ref_results,
      HloTestBase::ExecuteReplicated(
          std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(ref_results.size(), kNumPartitions);

  ErrorSpec error_spec{1e-5, 1e-5};
  // Expect per-partition results to match with and without the memcpy-based
  // local P2P implementation. (Original comment incorrectly said
  // "pipelining of collectives"; the flag under test is
  // xla_gpu_use_memcpy_local_p2p.)
  for (int64_t i = 0; i < kNumPartitions; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
  }
}
} // namespace
} // namespace xla