diff --git a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index cfc75be5b8e1b..6356a13bfce54 100644 --- a/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -85,7 +85,10 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( pipeline.AddPass("horizontal-loop-fusion-for-copy"); // To fuse the copy. sub_pipeline.AddPass(device_description); - sub_pipeline.AddPass(device_description, "copy_"); + // Make sure to run HorizontalLoopFusion only inside the entry computation. + // Fusing copies outside of the entry computation can break buffer assignment! + sub_pipeline.AddPass(device_description, "copy_", + /*only_entry_computation=*/true); sub_pipeline.AddPass(); pipeline.AddPass(); return pipeline; diff --git a/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 3145873e18449..a6f12834b12d0 100644 --- a/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -40,7 +40,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" #include "xla/layout_util.h" -#include "xla/primitive_util.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" @@ -737,11 +736,19 @@ absl::StatusOr HorizontalLoopFusion::Run( const absl::flat_hash_set& execution_threads) { VLOG(2) << "Run horizontal fusion."; - // Run on the entry computation is actually enough. - TF_ASSIGN_OR_RETURN(bool changed, - RunOnComputation(module->entry_computation())); + bool any_changed = false; + if (only_entry_computation_) { + TF_ASSIGN_OR_RETURN(any_changed, + RunOnComputation(module->entry_computation())); + } else { + for (HloComputation* computation : + GetFusibleComputations(*module, execution_threads)) { + TF_ASSIGN_OR_RETURN(bool changed, RunOnComputation(computation)); + any_changed |= changed; + } + } - if (changed) { + if (any_changed) { // Correctly set element_size_in_bits for any sub-byte added slice and // concatenate instructions TF_ASSIGN_OR_RETURN( @@ -750,7 +757,7 @@ absl::StatusOr HorizontalLoopFusion::Run( module)); } - return changed; + return any_changed; } } // namespace gpu diff --git a/xla/service/gpu/transforms/horizontal_loop_fusion.h b/xla/service/gpu/transforms/horizontal_loop_fusion.h index d7add7aff840d..fa858192dc2b6 100644 --- a/xla/service/gpu/transforms/horizontal_loop_fusion.h +++ b/xla/service/gpu/transforms/horizontal_loop_fusion.h @@ -126,8 +126,11 @@ namespace gpu { class HorizontalLoopFusion : public HloModulePass { public: explicit HorizontalLoopFusion(const se::DeviceDescription& device_description, - absl::string_view prefix = "") - : device_description_(device_description), prefix_(prefix) {} + absl::string_view prefix = "", + bool only_entry_computation = false) + : device_description_(device_description), + prefix_(prefix), + only_entry_computation_(only_entry_computation) {} absl::string_view name() const override { return "horizontal_loop_fusion"; } @@ -141,6 +144,7 @@ class HorizontalLoopFusion : public HloModulePass { const se::DeviceDescription& device_description_; std::string prefix_; + bool only_entry_computation_; }; } // namespace gpu diff --git a/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index c91c12dc13b01..f79c4a60b59ff 100644 --- a/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -989,23 +989,25 @@ TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) { HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } -TEST_F(HorizontalLoopFusionTest, FuseDifferentInstructionCounts) { +TEST_F(HorizontalLoopFusionTest, DoFusionInsideWhileLoop) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( -f { - p = s8[] parameter(0) - b = s8[1] bitcast(p) - } +b { + a = (s8[]) parameter(0) + b = s8[] get-tuple-element(a), index=0 + c = s8[] add(b, b) + d = s8[] multiply(b, b) + e = s8[] subtract(c, d) + t = tuple(e) +} -g { - p = s8[] parameter(0) +c { + p = (s8[]) parameter(0) + r = pred[] constant(true) } e { - p0 = s8[] parameter(0) - p1 = s8[] parameter(1) - a = s8[1] fusion(p0), kind=kLoop, calls=f - b = s8[] fusion(p1), kind=kLoop, calls=g - t = tuple(a, b) + p = (s8[]) parameter(0) + r = (s8[]) while(p), condition=c, body=b })")); EXPECT_TRUE( @@ -1094,6 +1096,52 @@ e { HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } +TEST_F(HorizontalLoopFusionTest, DontFuseCopiesInsideWhileLoops) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule module_main, entry_computation_layout={(f32[10]{0}, f32[20]{0})->(s32[], f32[10]{0}, f32[20]{0})} + +f { + param0 = f32[10]{0} parameter(0) + reverse = f32[10]{0} reverse(param0), dimensions={0} + param1 = f32[20]{0} parameter(1) + param2 = s32[] parameter(2) + dynamic_slice = f32[10]{0} dynamic-slice(param1, param2), dynamic_slice_sizes={10} + ROOT res = f32[10]{0} add(reverse, dynamic_slice) +} + +body { + p0 = (s32[], f32[10]{0}, f32[20]{0}) parameter(0) + iter = s32[] get-tuple-element(p0), index=0 + one = s32[] constant(1) + next_iter = s32[] add(iter, one) + a = f32[10]{0} get-tuple-element(p0), index=1 + b = f32[20]{0} get-tuple-element(p0), index=2 + next_a = f32[10]{0} fusion(a, b, iter), kind=kLoop, calls=f + copy.0 = f32[10]{0} copy(next_a) + next_b = f32[20]{0} reverse(b), dimensions={0} + copy.1 = f32[20]{0} copy(next_b) + ROOT r = (s32[], f32[10]{0}, f32[20]{0}) tuple(next_iter, copy.0, copy.1) +} + +cond { + p = (s32[], f32[10]{0}, f32[20]{0}) parameter(0) + i = s32[] get-tuple-element(p), index=0 + bound = s32[] constant(10) + ROOT res.1 = pred[] compare(i, bound), direction=LT +} + +ENTRY main { + zero = s32[] constant(0) + p0.1 = f32[10]{0} parameter(0) + p1.0 = f32[20]{0} parameter(1) + while_init = (s32[], f32[10]{0}, f32[20]{0}) tuple(zero, p0.1, p1.0) + ROOT while = (s32[], f32[10]{0}, f32[20]{0}) while(while_init), condition=cond, body=body +})")); + HorizontalLoopFusion loop_fusion(device_description_, /*prefix=*/"", + /*only_entry_computation=*/true); + EXPECT_FALSE(loop_fusion.Run(module.get()).value()); +} + } // namespace } // namespace gpu } // namespace xla