Skip to content

Commit

Permalink
Added e2e test for memcpy P2P inside a while loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Tixxx committed Dec 9, 2024
1 parent 5ec93af commit 9d3a8a4
Showing 1 changed file with 119 additions and 0 deletions.
119 changes: 119 additions & 0 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1526,5 +1526,124 @@ ENTRY entry {
EXPECT_TRUE(executable->has_module());
}

// Verifies that a collective-permute executed inside a while loop produces
// the same results whether or not it is lowered to the memcpy-based local
// P2P path (xla_gpu_use_memcpy_local_p2p). The module runs three loop
// iterations of a ring collective-permute across 4 partitions under manual
// sharding, then compares against a reference run with the flag disabled.
TEST_F(CollectiveOpsTestE2E, MemcpyP2pWhileLoopCorrectness) {
  absl::string_view hlo_string = R"(
HloModule MemcpyP2pWhileLoopCorrectness, entry_computation_layout={(bf16[128,96]{1,0})->(bf16[32,384]{1,0}, bf16[32,384]{1,0})}, allow_spmd_sharding_propagation_to_output={true,true}, num_partitions=4
None.4 {
Arg_1.6 = bf16[32,96]{1,0} parameter(1)
Arg_0.5 = bf16[32,96]{1,0} parameter(0)
collective-permute.9 = bf16[32,96]{1,0} collective-permute(Arg_0.5), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,0}}
constant.7 = bf16[] constant(2)
broadcast.8 = bf16[32,96]{1,0} broadcast(constant.7), dimensions={}
multiply.10 = bf16[32,96]{1,0} multiply(Arg_0.5, broadcast.8)
ROOT tuple.11 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(collective-permute.9, multiply.10)
} // None.4
region_0.12 {
arg_tuple.13 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0
constant.17 = s32[] constant(1)
add.21 = s32[] add(get-tuple-element.14, constant.17)
get-tuple-element.15 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=1
get-tuple-element.16 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.13), index=2
call.18 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(get-tuple-element.15, get-tuple-element.16), to_apply=None.4
get-tuple-element.19 = bf16[32,96]{1,0} get-tuple-element(call.18), index=0
get-tuple-element.20 = bf16[32,96]{1,0} get-tuple-element(call.18), index=1
ROOT tuple.22 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(add.21, get-tuple-element.19, get-tuple-element.20)
} // region_0.12
region_1.23 {
arg_tuple.24 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) parameter(0)
get-tuple-element.26 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=1
get-tuple-element.27 = bf16[32,96]{1,0} get-tuple-element(arg_tuple.24), index=2
get-tuple-element.25 = s32[] get-tuple-element(arg_tuple.24), index=0
constant.28 = s32[] constant(3)
ROOT compare.29 = pred[] compare(get-tuple-element.25, constant.28), direction=LT
} // region_1.23
shmap_body.30 {
constant.32 = s32[] constant(0)
Arg_0.31 = bf16[32,96]{1,0} parameter(0)
constant.33 = bf16[] constant(0)
broadcast.34 = bf16[32,96]{1,0} broadcast(constant.33), dimensions={}
tuple.35 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(constant.32, Arg_0.31, broadcast.34)
while.36 = (s32[], bf16[32,96]{1,0}, bf16[32,96]{1,0}) while(tuple.35), condition=region_1.23, body=region_0.12
get-tuple-element.37 = s32[] get-tuple-element(while.36), index=0
get-tuple-element.38 = bf16[32,96]{1,0} get-tuple-element(while.36), index=1
get-tuple-element.39 = bf16[32,96]{1,0} get-tuple-element(while.36), index=2
ROOT tuple.40 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) tuple(get-tuple-element.38, get-tuple-element.39)
} // shmap_body.30
ENTRY main.49 {
Arg_0.1 = bf16[128,96]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}
custom-call.2 = bf16[128,96]{1,0} custom-call(Arg_0.1), custom_call_target="Sharding", sharding={devices=[4,1]<=[4]}
custom-call.3 = bf16[32,96]{1,0} custom-call(custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}
call.41 = (bf16[32,96]{1,0}, bf16[32,96]{1,0}) call(custom-call.3), to_apply=shmap_body.30
get-tuple-element.42 = bf16[32,96]{1,0} get-tuple-element(call.41), index=0
custom-call.44 = bf16[32,96]{1,0} custom-call(get-tuple-element.42), custom_call_target="Sharding", sharding={manual}
custom-call.45 = bf16[32,384]{1,0} custom-call(custom-call.44), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
get-tuple-element.43 = bf16[32,96]{1,0} get-tuple-element(call.41), index=1
custom-call.46 = bf16[32,96]{1,0} custom-call(get-tuple-element.43), custom_call_target="Sharding", sharding={manual}
custom-call.47 = bf16[32,384]{1,0} custom-call(custom-call.46), custom_call_target="SPMDShardToFullShape", sharding={devices=[1,4]<=[4]}
ROOT tuple.48 = (bf16[32,384]{1,0}, bf16[32,384]{1,0}) tuple(custom-call.45, custom-call.47)
} // main.49
)";

  const int64_t kNumReplicas = 1;
  const int64_t kNumPartitions = 4;
  SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions);

  // Test run: memcpy-based local P2P enabled.
  HloModuleConfig config = GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto opts = GetDebugOptionsForTest();
  opts.set_xla_gpu_use_memcpy_local_p2p(true);
  config.set_debug_options(opts);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string, config));
  // NOTE(review): both runs call MakeFakeArguments independently; this relies
  // on its default determinism so the two executions see identical inputs.
  auto fake_arguments = xla::MakeFakeArguments(module.get()).value();
  std::vector<Literal*> fake_ptrs;
  fake_ptrs.reserve(fake_arguments.size());
  for (auto& argument : fake_arguments) {
    fake_ptrs.push_back(&argument);
  }

  // One replica, devices 0..3 mapped to partitions 0..3 in order.
  DeviceAssignment assn(/*replica_count=*/kNumReplicas,
                        /*computation_count=*/kNumPartitions);
  for (int64_t i = 0; i < kNumPartitions; ++i) {
    assn(0, i) = i;
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      HloTestBase::ExecuteReplicated(
          std::move(module), fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumPartitions);

  // Reference run: memcpy-based local P2P disabled (default transport).
  HloModuleConfig ref_config =
      GetModuleConfigForTest(kNumReplicas, kNumPartitions);
  auto ref_opts = GetDebugOptionsForTest();
  ref_opts.set_xla_gpu_use_memcpy_local_p2p(false);
  ref_config.set_debug_options(ref_opts);
  TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
                          ParseAndReturnVerifiedModule(hlo_string, ref_config));
  auto fake_ref_arguments = xla::MakeFakeArguments(ref_module.get()).value();
  std::vector<Literal*> ref_fake_ptrs;
  ref_fake_ptrs.reserve(fake_ref_arguments.size());
  for (auto& argument : fake_ref_arguments) {
    ref_fake_ptrs.push_back(&argument);
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> ref_results,
      HloTestBase::ExecuteReplicated(
          std::move(ref_module), ref_fake_ptrs, kNumPartitions, &assn,
          /*run_hlo_passes=*/true, /*use_threads=*/true));
  ASSERT_EQ(ref_results.size(), kNumPartitions);

  // Expect identical per-partition results with and without memcpy-based
  // local P2P.
  ErrorSpec error_spec{1e-5, 1e-5};
  for (int64_t i = 0; i < kNumPartitions; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Near(ref_results[i], results[i], error_spec));
  }
}

} // namespace
} // namespace xla

0 comments on commit 9d3a8a4

Please sign in to comment.