[XLA:GPU] Mark collectives in formatting ops as pipelined.

The AppendPipelinedInstruction function was not called for formatting ops, so the pipelined collective was not marked as pipelined in its backend config.

PiperOrigin-RevId: 700339573
golechwierowicz authored and Google-ML-Automation committed Nov 26, 2024
1 parent 00b1e80 commit f595722
Showing 4 changed files with 212 additions and 10 deletions.
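To make the mechanism concrete, here is a minimal, self-contained sketch of the hook pattern this commit fixes. The Instruction struct, CloneChain, and the void-returning postprocessor are illustrative stand-ins, not the real XLA types: every instruction cloned during pipelining, formatting ops included, must be offered to the optional postprocessor so that collectives get their is_pipelined bit set.

#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Illustrative stand-in for HloInstruction plus its backend config.
struct Instruction {
  std::string name;
  bool is_collective = false;
  bool is_pipelined = false;  // mirrors CollectiveBackendConfig::is_pipelined
};

using Postprocessor = std::function<void(Instruction&)>;

// Marks collectives as pipelined and leaves everything else untouched,
// mirroring the IsCollective early-return added in this commit.
void MarkPipelined(Instruction& instr) {
  if (!instr.is_collective) return;
  instr.is_pipelined = true;
}

// Clones a chain of formatting ops plus the moved collective, invoking the
// optional postprocessor on every clone -- the call that was previously
// missing for formatting ops.
std::vector<Instruction> CloneChain(const std::vector<Instruction>& chain,
                                    std::optional<Postprocessor> postprocess) {
  std::vector<Instruction> clones;
  for (const Instruction& op : chain) {
    Instruction cloned = op;
    if (postprocess.has_value()) {
      (*postprocess)(cloned);
    }
    clones.push_back(cloned);
  }
  return clones;
}

int main() {
  const std::vector<Instruction> chain = {
      {"dynamic-slice", /*is_collective=*/false},
      {"all-reduce", /*is_collective=*/true}};
  for (const Instruction& c : CloneChain(chain, MarkPipelined)) {
    std::cout << c.name << " pipelined=" << c.is_pipelined << "\n";
  }
}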
29 changes: 19 additions & 10 deletions xla/service/collective_pipeliner.cc
@@ -698,7 +698,9 @@ template <typename Comp>
absl::StatusOr<HloInstruction*> CloneBackwardChain(
Comp& target_computation, const WhileMoveInfo& move_info,
InstructionMap& clone_map, int64_t loop_iter_idx, int64_t& next_channel_id,
-    LoopVariantParameterInfo* loop_variant_parameter_info = nullptr) {
+    LoopVariantParameterInfo* loop_variant_parameter_info = nullptr,
+    CollectivePipeliner::HloPostprocessor postprocess_pipelined_ops =
+        std::nullopt) {
std::vector<HloInstruction*> to_clone(move_info.formatting_ops.begin(),
move_info.formatting_ops.end());
to_clone.push_back(move_info.collectives_to_move[0]);
@@ -715,6 +717,9 @@ absl::StatusOr<HloInstruction*> CloneBackwardChain(
TF_RETURN_IF_ERROR(UpdateControlDependencies(chain_op, cloned, clone_map));
UpdateInstructionChannelId(cloned, next_channel_id);
clone_map[chain_op] = cloned;
+    if (postprocess_pipelined_ops.has_value()) {
+      TF_RETURN_IF_ERROR((*postprocess_pipelined_ops)(cloned));
+    }
last_cloned = cloned;
if (loop_variant_parameter_info != nullptr &&
chain_op->opcode() == HloOpcode::kGetTupleElement &&
@@ -1913,6 +1918,9 @@ absl::Status TransformLoopForward(
formatting_op->CloneWithNewOperands(formatting_op->shape(),
new_operands));
cloned_map[formatting_op] = processed;
+      if (post_processing_fn.has_value()) {
+        TF_RETURN_IF_ERROR((*post_processing_fn)(processed));
+      }
}
return processed;
};
@@ -2721,10 +2729,11 @@ static absl::Status TransformLoopBackward(
loop_analysis.GetMoveInfos()[i].collectives_to_move[0];
TF_ASSIGN_OR_RETURN(
new_init_operands[idx],
-        CloneBackwardChain(*while_loop->parent(),
-                           loop_analysis.GetMoveInfos()[i], chain_clone_map,
-                           *loop_analysis.GetLoopIterationIdx(),
-                           next_channel_id));
+        CloneBackwardChain(
+            *while_loop->parent(), loop_analysis.GetMoveInfos()[i],
+            chain_clone_map, *loop_analysis.GetLoopIterationIdx(),
+            next_channel_id, /*loop_variant_parameter_info=*/nullptr,
+            post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx]));
@@ -2774,11 +2783,11 @@ static absl::Status TransformLoopBackward(
if (it != collective_to_move_map.end()) {
TF_ASSIGN_OR_RETURN(
cloned_instr,
-          CloneBackwardChain(body_builder,
-                             loop_analysis.GetMoveInfos()[it->second],
-                             collective_to_move_clone_map,
-                             *loop_analysis.GetLoopIterationIdx(),
-                             next_channel_id, &loop_variant_parameter_info));
+          CloneBackwardChain(
+              body_builder, loop_analysis.GetMoveInfos()[it->second],
+              collective_to_move_clone_map,
+              *loop_analysis.GetLoopIterationIdx(), next_channel_id,
+              &loop_variant_parameter_info, post_processing_fn));

if (post_processing_fn.has_value()) {
TF_RETURN_IF_ERROR((*post_processing_fn)(cloned_instr));
2 changes: 2 additions & 0 deletions xla/service/gpu/BUILD
@@ -3107,6 +3107,7 @@ cc_library(
deps = [
":backend_configs_cc",
"//xla/hlo/ir:hlo",
"//xla/service:collective_ops_utils",
"//xla/service:collective_utils",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/log",
@@ -3125,6 +3126,7 @@ xla_cc_test(
":gpu_hlo_schedule",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/pass:hlo_pass_pipeline",
"//xla/hlo/transforms:hlo_dce",
"//xla/hlo/utils:hlo_query",
4 changes: 4 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/collective_utils.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/stream_executor/device_description.h"
@@ -70,6 +71,9 @@ int64_t ComputeSuggestedCombinerThreshold(
}

absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
+  if (!IsCollective(instr)) {
+    return absl::OkStatus();
+  }
TF_ASSIGN_OR_RETURN(auto config,
instr->backend_config<gpu::GpuBackendConfig>());
config.mutable_collective_backend_config()->set_is_pipelined(true);
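For context, callers opt into this marking by installing the helper as the pipeliner's postprocessor, exactly as the tests below do. A fragment only, not a complete program; `config` is the `CollectivePipeliner::Config` being assembled, as in the gpu_compiler.cc configurations the test comments reference:

// Fragment: wiring AppendPipelinedInstruction into the collective pipeliner.
config.postprocess_pipelined_ops = AppendPipelinedInstruction;
HloPassPipeline pipeline("collective-pipeliner");
pipeline.AddPass<CollectivePipeliner>(config);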
187 changes: 187 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/pass/hlo_pass_fix.h"
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
#include "xla/hlo/utils/hlo_query.h"
@@ -230,6 +231,95 @@ TEST_F(CollectiveCombinerUtilsTest,
});
}

TEST_F(CollectiveCombinerUtilsTest,
AppendPipelinedInstructionForwardFormattingOps) {
// This is a canonical IR that makes it easy to pipeline a collective
// forward -- in this example, an AllReduce.
absl::string_view kHloText = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
current-loop-index = s32[] get-tuple-element(param), index=0
output-buffer = bf16[3,8,128] get-tuple-element(param), index=1
input-buffer = bf16[3,8,128] get-tuple-element(param), index=2
constant.1 = s32[] constant(1)
next-loop-index = s32[] add(current-loop-index, constant.1)
constant.0 = s32[] constant(0)
sliced-input-buffer = bf16[1,8,128] dynamic-slice(input-buffer,
current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128}
all-reduce = bf16[1,8,128] all-reduce(sliced-input-buffer),
replica_groups={}, to_apply=add, channel_id=1
all-reduce.1 = bf16[1,8,128] all-reduce(all-reduce),
replica_groups={}, to_apply=add, channel_id=2
dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer,
all-reduce.1, current-loop-index, constant.0, constant.0)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index,
dynamic-update-slice, input-buffer)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple),
condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
// This config is taken from the gpu_compiler.cc configuration of the forward
// pipeliner.
CollectivePipeliner::Config config{
/*level_to_operate_on=*/0,
/*max_pipelining_per_loop=*/INT64_MAX,
/*last_run=*/true,
/*pipeline_use_tree=*/false,
/*process_different_sized_ops=*/true,
/*pipelining_direction=*/
CollectivePipeliner::PipeliningDirection::kForward,
/*should_process=*/HloPredicateIsOp<HloOpcode::kAllReduce>,
/*acceptable_formatting=*/HloPredicateTrue,
/*reuse_pipelined_op_buffer=*/HloPredicateFalse,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;

HloPassPipeline pipeline("collective-pipeliner");
pipeline.AddPass<CollectivePipeliner>(config);
pipeline.AddPass<HloPassFix<HloDCE>>(
/*remove_cross_partition_collective_ops=*/true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
EXPECT_TRUE(changed);

hlo_query::ForEachInstructionWithOpcode(
*module, HloOpcode::kAllReduce, [](HloInstruction* instr) {
EXPECT_TRUE(instr->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
});

hlo_query::ForEachInstructionWithPred(
*module, HloPredicateIsNotOp<HloOpcode::kAllReduce>,
[](HloInstruction* instr) {
EXPECT_FALSE(instr->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
});
}

TEST_F(CollectiveCombinerUtilsTest,
AppendPipelinedInstructionAppendsPipelinedInstructionInfoBackward) {
// This is just the simple IR which makes it easy for the pipeliner to
@@ -317,5 +407,102 @@ TEST_F(CollectiveCombinerUtilsTest,
});
}

TEST_F(CollectiveCombinerUtilsTest,
AppendPipelinedInstructionBackwardFormattingOps) {
// This is a simple IR that makes it easy for the pipeliner to pipeline a
// collective. The pipelined collective is an AllGather, so the main
// complexity comes from the fact that we have to slice it at the end of the
// loop (so that we can gather it again in the next iteration).
absl::string_view kHloText = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
current-loop-index = s32[] get-tuple-element(param), index=0
output-buffer = bf16[3,8,128] get-tuple-element(param), index=1
input-buffer = bf16[3,8,128] get-tuple-element(param), index=2
constant.1 = s32[] constant(1)
next-loop-index = s32[] add(current-loop-index, constant.1)
constant.0 = s32[] constant(0)
all-reduce = bf16[3,8,128] all-reduce(input-buffer), to_apply=add, replica_groups={}
sliced-input-buffer = bf16[1,8,128] dynamic-slice(all-reduce,
current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128}
all-gather = bf16[3,8,128] all-gather(sliced-input-buffer), dimensions={0}
dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer,
all-gather, current-loop-index, constant.0, constant.0)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index,
dynamic-update-slice, input-buffer)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple),
condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
// This config is taken from the gpu_compiler.cc configuration of the backward
// pipeliner.
CollectivePipeliner::Config config{
/*level_to_operate_on=*/0,
/*max_pipelining_per_loop=*/INT64_MAX,
/*last_run=*/true,
/*pipeline_use_tree=*/false,
/*process_different_sized_ops=*/true,
/*pipelining_direction=*/
CollectivePipeliner::PipeliningDirection::kBackward,
/*should_process=*/HloPredicateIsOp<HloOpcode::kAllGather>,
/*acceptable_formatting=*/HloPredicateTrue,
/*reuse_pipelined_op_buffer=*/HloPredicateFalse,
/*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse,
/*should_allow_control_dependencies=*/false,
/*postprocess_backward_peeled_op=*/std::nullopt,
/*postprocess_backward_rotated_op=*/std::nullopt,
/*should_add_loop_invariant_op_in_chain=*/true,
};
config.postprocess_pipelined_ops = AppendPipelinedInstruction;

HloPassPipeline pipeline("collective-pipeliner");
pipeline.AddPass<CollectivePipeliner>(config);
pipeline.AddPass<HloPassFix<HloDCE>>(
/*remove_cross_partition_collective_ops=*/true);
TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
EXPECT_TRUE(changed);

hlo_query::ForEachInstructionWithPred(
*module, HloPredicateIsOp<HloOpcode::kAllGather, HloOpcode::kAllReduce>,
[](HloInstruction* instr) {
EXPECT_TRUE(instr->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
});

hlo_query::ForEachInstructionWithPred(
*module,
HloPredicateIsNotOp<HloOpcode::kAllGather, HloOpcode::kAllReduce>,
[](HloInstruction* instr) {
EXPECT_FALSE(instr->backend_config<GpuBackendConfig>()
->collective_backend_config()
.is_pipelined());
});
}

} // namespace
} // namespace xla::gpu
