From 43517d94ad963a96a2a308b7b33d77ecd7de4b4a Mon Sep 17 00:00:00 2001 From: Mehrdad Khani Date: Wed, 20 Nov 2024 15:49:42 -0800 Subject: [PATCH 1/4] [XLA:MSA] Fixes a bug in GetInefficientAllocationSites(allocation_values). The function was previously assuming allocation_values can never be empty. PiperOrigin-RevId: 698548828 --- .../memory_space_assignment/algorithm.cc | 5 +- .../memory_space_assignment_test.cc | 60 +++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index efcd7e12153c37..b4d98b331f0d53 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -2006,7 +2006,10 @@ MsaAlgorithm::GetInefficientAllocationSites( return {}; } - int64_t size = allocation_values.at(0).size(); + int64_t size = 0; + if (!allocation_values.empty()) { + size = allocation_values.at(0).size(); + } if (VLOG_IS_ON(3)) { for (const AllocationValue& allocation_value : allocation_values) { diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 647e52baf552d3..096f6746ed4737 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/memory_space_assignment/memory_space_assignment.h" +#include + #include #include #include @@ -1309,6 +1311,64 @@ ENTRY entry { EXPECT_LT(copy_done_index, negate4_index); } +// Added for b/372277844#comment15 that was introduced when the allocation +// failed while trying to convert a sync slice to an async one, but not due to +// the conversion itself. In this case, associated buffer with the slice +// (p0_copy) is too large to fit in alternate memory. 
Hence, the +// allocation_values will be empty in retries, previously causing a crash in +// MsaAlgorithm::GetInefficientAllocationSites(). +TEST_F(MemorySpaceAssignmentTest, SyncReplacementLargeBuffers) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + p0_copy = f32[10,2,3]{2,1,0} copy(p0) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + slice = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]} + ROOT concat = f32[11,2,3] concatenate(negate5, slice), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 64; + options.max_retries = 2; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + // Force the allocation of p0_copy to fail for the concat use with only + // AllocationResult::kFailRequiresUncommit. This means that while the slice + // replacement was successful, the algorithm must retry one more time without + // sync slice conversion target, so that maybe other constraints of the + // allocation can be satisfied. + options.allocation_result_modifier_testing_fn = + [](const AllocationRequest& request, AllocationResult& result) { + if (request.allocation_value->defining_instruction()->name() == + "p0_copy" && + request.use->hlo_use.instruction->name() == "concat") { + result = AllocationResult::kFailRequiresUncommit; + } + }; + // options.inefficient_use_to_copy_ratio must be greater than 0 and the cost + // model must be set to trigger the inefficient allocation site logic. 
+ options.inefficient_use_to_copy_ratio = 1.0; + AssignMemorySpaceUsingCostAnalysis(module.get(), options); + + HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy"); + ASSERT_NE(p0_copy, nullptr); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT(concat->operand(1), op::Slice(p0_copy)); +} + // Added for b/376869021, which surfaced when we tried to convert a sync slice // that had to extend the allocation of its operand in the alternate memory. In // this test, we expect the slice0 operand (p0_copy) maintain a valid allocation From 5cc5c9c80d94d2616f6e4f910f091d3be42e76c1 Mon Sep 17 00:00:00 2001 From: Farzin Houshmand Date: Wed, 20 Nov 2024 15:59:48 -0800 Subject: [PATCH 2/4] Add a step to Range, add ConstantValue::Get, and support multiplication while recursively calculating the range of an expression. PiperOrigin-RevId: 698551804 --- xla/service/BUILD | 10 +++ xla/service/constant_value.h | 3 + xla/service/value_range.cc | 113 ++++++++++++++++++++++++++------ xla/service/value_range.h | 23 ++++++- xla/service/value_range_test.cc | 112 ++++++++++++++++++++++++++++--- 5 files changed, 229 insertions(+), 32 deletions(-) diff --git a/xla/service/BUILD b/xla/service/BUILD index ffbb6b023151fb..60ebb47d4c4619 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -325,8 +325,13 @@ cc_library( hdrs = ["value_range.h"], deps = [ ":constant_value", + "//xla:comparison_util", + "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", ], ) @@ -334,12 +339,17 @@ xla_cc_test( name = "value_range_test", srcs = ["value_range_test.cc"], deps = [ + 
":constant_value", ":hlo_module_config", ":value_range", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/constant_value.h b/xla/service/constant_value.h index f6b8477c80acc7..418fcf56ef8e2e 100644 --- a/xla/service/constant_value.h +++ b/xla/service/constant_value.h @@ -46,6 +46,9 @@ class ConstantValue { static ConstantValue GetOne(int32_t bitwidth, bool is_signed) { return ConstantValue(1, bitwidth, is_signed); } + static ConstantValue Get(int64_t value, int32_t bitwidth, bool is_signed) { + return ConstantValue(absl::bit_cast(value), bitwidth, is_signed); + } static ConstantValue GetSigned(int64_t value, int32_t bitwidth) { return ConstantValue(absl::bit_cast(value), bitwidth, /*is_signed=*/true); diff --git a/xla/service/value_range.cc b/xla/service/value_range.cc index 178ebcb33f31db..b976a403a33f8c 100644 --- a/xla/service/value_range.cc +++ b/xla/service/value_range.cc @@ -15,10 +15,19 @@ limitations under the License. 
#include "xla/service/value_range.h" +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" +#include "xla/service/constant_value.h" namespace xla { @@ -40,15 +49,34 @@ std::string Range::ToString() const { if (IsEmpty()) { return std::string("Empty"); } - return absl::StrCat("min: ", min_.ToString(), " max: ", max_.ToString()); + return absl::StrCat( + "min: ", min_.ToString(), " max: ", max_.ToString(), + " step: ", IsStepKnown() ? step_.value().ToString() : "Unknown"); +} + +std::optional FindStepForBinaryOp(const Range& lhs, + const Range& rhs) { + if (!lhs.IsStepKnown() || !rhs.IsStepKnown()) { + return std::nullopt; + } + if (lhs.IsSingleValue()) { + return rhs.step(); + } + if (rhs.IsSingleValue()) { + return lhs.step(); + } + if (lhs.step().eq(rhs.step())) { + return lhs.step(); + } + return std::nullopt; } // Identify the value ranges of a scalar HLO with a integer type. It returns // a range of values that the instruction can have. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& - predefined_ranges) { + const absl::flat_hash_map& predefined_ranges, + const HloAliasAnalysis* alias_analysis) { // Non scalar or non-integer HLO. Abort. if ((!instr->shape().IsInteger() && instr->shape().element_type() != PRED) || instr->shape().dimensions_size() != 0) { @@ -60,14 +88,26 @@ Range RecursivelyIdentifyRange( VLOG(5) << "Found range! 
" << it->second.max().GetSignedValue() << " " << it->second.min().GetSignedValue(); return it->second; + } else if (alias_analysis != nullptr) { + auto value_set = + alias_analysis->dataflow_analysis().GetFlattenedValueSet(instr); + for (const auto& value : value_set.TakeValues()) { + auto it = predefined_ranges.find(value->defining_instruction()); + if (it != predefined_ranges.end()) { + VLOG(5) << "Found range in defining instruction! " + << it->second.max().GetSignedValue() << " " + << it->second.min().GetSignedValue(); + return it->second; + } + } } switch (instr->opcode()) { case HloOpcode::kCompare: { VLOG(5) << "Handling Compare"; - Range lhs = - RecursivelyIdentifyRange(instr->operand(0), predefined_ranges); - Range rhs = - RecursivelyIdentifyRange(instr->operand(1), predefined_ranges); + Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + alias_analysis); + Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); // Only kLt supported right now. 
@@ -105,11 +145,13 @@ Range RecursivelyIdentifyRange( const int64_t value = *instr->literal().GetFirstInteger(); return Range{ConstantValue::GetSigned(value, bitwidth), ConstantValue::GetSigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), /*is_linear=*/true}; } const uint64_t value = *instr->literal().GetFirstInteger(); return Range{ConstantValue::GetUnsigned(value, bitwidth), ConstantValue::GetUnsigned(value, bitwidth), + ConstantValue::GetOne(/*bitwidth=*/1, /*is_signed=*/false), /*is_linear=*/true}; } case HloOpcode::kAdd: { @@ -117,10 +159,10 @@ Range RecursivelyIdentifyRange( return Range{}; } VLOG(5) << "Handling Add"; - Range lhs = - RecursivelyIdentifyRange(instr->operand(0), predefined_ranges); - Range rhs = - RecursivelyIdentifyRange(instr->operand(1), predefined_ranges); + Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + alias_analysis); + Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); if (lhs.IsEmpty() || rhs.IsEmpty()) { @@ -132,31 +174,61 @@ Range RecursivelyIdentifyRange( VLOG(5) << "Add wrapped"; return Range{}; } - return Range{min, max, lhs.IsLinear() && rhs.IsLinear()}; + return Range{min, max, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}; + } + case HloOpcode::kMultiply: { + if (!instr->shape().IsInteger()) { + return Range{}; + } + VLOG(5) << "Handling Multiply"; + Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + alias_analysis); + Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + alias_analysis); + VLOG(5) << "Returned Rhs: " << rhs.ToString() + << " Lhs: " << lhs.ToString(); + if (lhs.IsEmpty() || rhs.IsEmpty()) { + return Range{}; + } + // We only handle multiplication of a single value with a range. 
+ if (!lhs.IsSingleValue() && !rhs.IsSingleValue()) { + return Range{}; + } + ConstantValue single_value = lhs.IsSingleValue() ? lhs.min() : rhs.min(); + ConstantValue min = lhs.IsSingleValue() ? rhs.min().mul(single_value) + : lhs.min().mul(single_value); + ConstantValue max = lhs.IsSingleValue() ? rhs.max().mul(single_value) + : lhs.max().mul(single_value); + return Range{min, max, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}; } case HloOpcode::kSelect: { VLOG(5) << "Handling Select"; const HloInstruction* cmp = instr->operand(0); - Range cmp_range = RecursivelyIdentifyRange(cmp, predefined_ranges); + Range cmp_range = + RecursivelyIdentifyRange(cmp, predefined_ranges, alias_analysis); // Support only when the select has a constant value as condition. if (cmp_range.IsEmpty() || !cmp_range.IsSingleValue()) { VLOG(5) << "Select failed"; return Range{}; } if (cmp_range.GetSingleSignedValue() == 0) { - return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges); + return RecursivelyIdentifyRange(instr->operand(2), predefined_ranges, + alias_analysis); } - return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges); + return RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + alias_analysis); } case HloOpcode::kSubtract: { if (!instr->shape().IsInteger()) { return Range{}; } VLOG(5) << "Handling Subtract"; - Range lhs = - RecursivelyIdentifyRange(instr->operand(0), predefined_ranges); - Range rhs = - RecursivelyIdentifyRange(instr->operand(1), predefined_ranges); + Range lhs = RecursivelyIdentifyRange(instr->operand(0), predefined_ranges, + alias_analysis); + Range rhs = RecursivelyIdentifyRange(instr->operand(1), predefined_ranges, + alias_analysis); VLOG(5) << "Returned Rhs: " << rhs.ToString() << " Lhs: " << lhs.ToString(); if (lhs.IsEmpty() || rhs.IsEmpty()) { @@ -168,7 +240,8 @@ Range RecursivelyIdentifyRange( VLOG(5) << "Subtract wrapped"; return Range{}; } - return Range{min, max, lhs.IsLinear() && 
rhs.IsLinear()}; + return Range{min, max, FindStepForBinaryOp(lhs, rhs), + lhs.IsLinear() && rhs.IsLinear()}; } default: break; diff --git a/xla/service/value_range.h b/xla/service/value_range.h index daf55ce2331e75..c7cf49d9823c95 100644 --- a/xla/service/value_range.h +++ b/xla/service/value_range.h @@ -20,6 +20,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/constant_value.h" namespace xla { @@ -30,14 +32,28 @@ class Range { Range() : min_(ConstantValue::GetZero(/*bitwidth=*/64, /*is_signed=*/false)), max_(ConstantValue::GetZero(/*bitwidth=*/64, /*is_signed=*/false)), + step_(ConstantValue::GetZero(/*bitwidth=*/64, /*is_signed=*/false)), empty_(true), is_linear_(false) {} Range(const ConstantValue& min, const ConstantValue& max, bool is_linear) - : min_(min), max_(max), empty_(false), is_linear_(is_linear) {} + : min_(min), + max_(max), + step_(std::nullopt), + empty_(false), + is_linear_(is_linear) {} + Range(const ConstantValue& min, const ConstantValue& max, + std::optional step, bool is_linear) + : min_(min), + max_(max), + step_(step), + empty_(false), + is_linear_(is_linear) {} // Minimum value of the range. const ConstantValue& min() const { return min_; } // Maximum value of the range. const ConstantValue& max() const { return max_; } + // Step value of the range. + const ConstantValue& step() const { return step_.value(); } // Returns if the range is empty (no value in set). bool IsEmpty() const { return empty_; } // Only one value in set. This means the range is a constant. @@ -48,6 +64,7 @@ class Range { // causing the final value represented by the range in a monotonic way during // loop recursion. bool IsLinear() const { return is_linear_; } + bool IsStepKnown() const { return step_.has_value(); } // If this range represents a single value return that signed value. 
std::optional GetSingleSignedValue() const; // If this range represents a single value return that unsigned value. @@ -58,6 +75,7 @@ class Range { private: ConstantValue min_; ConstantValue max_; + std::optional step_; bool empty_; bool is_linear_; }; @@ -69,7 +87,8 @@ class Range { // The input HLO needs to be of scalar type and integer. Range RecursivelyIdentifyRange( const HloInstruction* instr, - const absl::flat_hash_map& predefined_ranges); + const absl::flat_hash_map& predefined_ranges, + const HloAliasAnalysis* alias_analysis = nullptr); } // namespace xla diff --git a/xla/service/value_range_test.cc b/xla/service/value_range_test.cc index 1f98c489edc373..415f08df0213f1 100644 --- a/xla/service/value_range_test.cc +++ b/xla/service/value_range_test.cc @@ -19,9 +19,14 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/constant_value.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -30,28 +35,115 @@ class ValueRangeTest : public HloTestBase {}; TEST_F(ValueRangeTest, AddedValue) { constexpr absl::string_view hlo_string = R"( -HloModule module + HloModule module -ENTRY entry { - c0 = s32[] constant(124) - p0 = s32[] parameter(0) - ROOT %a = s32[] add(p0, c0) -} -)"; + ENTRY entry { + c0 = s32[] constant(124) + p0 = s32[] parameter(0) + ROOT %a = s32[] add(p0, c0) + } + )"; auto module = ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* p0 = root->operand(0); absl::flat_hash_map fs; - fs.insert(std::make_pair( - p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), - ConstantValue::GetSigned(5, 32), 
/*is_linear=*/true})); + fs.insert( + std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), + ConstantValue::GetSigned(5, 32), + ConstantValue::GetOne(32, /*is_signed=*/false), + /*is_linear=*/true})); auto range = RecursivelyIdentifyRange(root, fs); EXPECT_FALSE(range.IsEmpty()); EXPECT_FALSE(range.IsSingleValue()); EXPECT_TRUE(range.IsLinear()); EXPECT_EQ(range.min().GetSignedValue(), 124); EXPECT_EQ(range.max().GetSignedValue(), 129); + EXPECT_EQ(range.step().GetSignedValue(), 1); +} + +TEST_F(ValueRangeTest, MultiplyValue) { + constexpr absl::string_view hlo_string = R"( + HloModule module + + ENTRY entry { + c0 = s32[] constant(1024) + p0 = s32[] parameter(0) + ROOT %a = s32[] multiply(p0, c0) + } + )"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + const HloInstruction* root = module->entry_computation()->root_instruction(); + const HloInstruction* p0 = root->operand(0); + absl::flat_hash_map fs; + fs.insert( + std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), + ConstantValue::GetSigned(5, 32), + ConstantValue::GetOne(32, /*is_signed=*/false), + /*is_linear=*/true})); + auto range = RecursivelyIdentifyRange(root, fs); + EXPECT_FALSE(range.IsEmpty()); + EXPECT_FALSE(range.IsSingleValue()); + EXPECT_TRUE(range.IsLinear()); + EXPECT_EQ(range.min().GetSignedValue(), 0); + EXPECT_EQ(range.max().GetSignedValue(), 5120); + EXPECT_EQ(range.step().GetSignedValue(), 1); +} + +TEST_F(ValueRangeTest, ConstantValueWithConditional) { + constexpr absl::string_view hlo_string = R"( + HloModule module + region1 { + region1_param = s32[] parameter(0) + region1_c0 = s32[] constant(1024) + %add = s32[] add(region1_param, region1_c0) + ROOT out = (s32[], s32[]) tuple(%add, %add) + } + region2 { + region2_param = s32[] parameter(0) + region2_c0 = s32[] constant(1024) + %mult = s32[] multiply(region2_param, region2_c0) + ROOT out = (s32[], s32[]) tuple(%mult, %mult) + } + ENTRY entry { + p0 = 
s32[] parameter(0) + branch_index = s32[] parameter(1) + ROOT conditional.1 = (s32[], s32[]) conditional(branch_index, p0, p0), branch_computations={region1, region2} + } + )"; + auto module = + ParseAndReturnUnverifiedModule(hlo_string, HloModuleConfig{}).value(); + TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis, + HloAliasAnalysis::Run(module.get())); + HloComputation* region1 = module->GetComputationWithName("region1"); + HloComputation* region2 = module->GetComputationWithName("region2"); + HloInstruction* add = region1->GetInstructionWithName("add"); + HloInstruction* mult = region2->GetInstructionWithName("mult"); + const HloInstruction* p0 = + module->entry_computation()->parameter_instruction(0); + absl::flat_hash_map fs; + fs.insert( + std::make_pair(p0, Range{ConstantValue::GetZero(32, /*is_signed=*/true), + ConstantValue::GetSigned(5, 32), + ConstantValue::GetOne(32, /*is_signed=*/false), + /*is_linear=*/true})); + + auto add_range = RecursivelyIdentifyRange(add, fs, alias_analysis.get()); + EXPECT_FALSE(add_range.IsEmpty()); + EXPECT_FALSE(add_range.IsSingleValue()); + EXPECT_TRUE(add_range.IsLinear()); + EXPECT_EQ(add_range.min().GetSignedValue(), 1024); + EXPECT_EQ(add_range.max().GetSignedValue(), 1029); + EXPECT_EQ(add_range.step().GetSignedValue(), 1); + + auto mult_range = RecursivelyIdentifyRange(mult, fs, alias_analysis.get()); + EXPECT_FALSE(mult_range.IsEmpty()); + EXPECT_FALSE(mult_range.IsSingleValue()); + EXPECT_TRUE(mult_range.IsLinear()); + EXPECT_EQ(mult_range.min().GetSignedValue(), 0); + EXPECT_EQ(mult_range.max().GetSignedValue(), 5120); + EXPECT_EQ(mult_range.step().GetSignedValue(), 1); } TEST_F(ValueRangeTest, AddedValueUnsigned) { From 5a9e8c7624ee1e95392f15f944da5dd1427fa04b Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 20 Nov 2024 16:09:14 -0800 Subject: [PATCH 3/4] [XLA:GPU] Remove RewriteReductionsPass It is unused. 
PiperOrigin-RevId: 698554807 --- xla/service/gpu/fusions/transforms/BUILD | 1 - xla/service/gpu/fusions/transforms/passes.h | 1 - xla/service/gpu/fusions/transforms/passes.td | 22 ----- .../transforms/tests/rewrite_reductions.mlir | 93 ------------------- 4 files changed, 117 deletions(-) delete mode 100644 xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir diff --git a/xla/service/gpu/fusions/transforms/BUILD b/xla/service/gpu/fusions/transforms/BUILD index 8f2adb7c38a88b..160dd7ea2c4692 100644 --- a/xla/service/gpu/fusions/transforms/BUILD +++ b/xla/service/gpu/fusions/transforms/BUILD @@ -46,7 +46,6 @@ cc_library( "optimize_loops.cc", "peel_loops.cc", "propagate_slice_indices.cc", - "rewrite_reductions.cc", "simplify_affine.cc", "simplify_arith.cc", "unswitch_loops.cc", diff --git a/xla/service/gpu/fusions/transforms/passes.h b/xla/service/gpu/fusions/transforms/passes.h index 99304ed9a1f8da..0573505c3053a7 100644 --- a/xla/service/gpu/fusions/transforms/passes.h +++ b/xla/service/gpu/fusions/transforms/passes.h @@ -54,7 +54,6 @@ std::unique_ptr CreateOptimizeLoopsPass(); std::unique_ptr CreateFuseLoopsPass(); std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); -std::unique_ptr CreateRewriteReductionsPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); diff --git a/xla/service/gpu/fusions/transforms/passes.td b/xla/service/gpu/fusions/transforms/passes.td index 450ceeeca48ae9..6f4cca8ccd7e0d 100644 --- a/xla/service/gpu/fusions/transforms/passes.td +++ b/xla/service/gpu/fusions/transforms/passes.td @@ -236,28 +236,6 @@ def LowerToLLVMPass : ]; } -def RewriteReductionsPass : Pass< - "xla-gpu-rewrite-reductions", "mlir::func::FuncOp"> { - let summary = "Rewrites reductions to pieces that can efficiently be emitted."; - - let description = [{ - This pass rewrites reductions so they can be emitted efficiently. 
- - For example, a row reduction of 1024 elements to one may be rewritten to two - reductions, the first one to 32 elements, the second one to one element. - This way, the reduction can be emitted as two warp shuffle reduces. - - A column reduction will be rewritten to a transpose followed by a row - reduction. - }]; - - let dependentDialects = [ - "xla::gpu::XlaGpuDialect", - ]; - - let constructor = "CreateRewriteReductionsPass()"; -} - def VectorizeLoadsAndStoresPass : Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> { let summary = "Vectorizes loads and stores."; diff --git a/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir b/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir deleted file mode 100644 index 5f8b9ba5413d84..00000000000000 --- a/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-rewrite-reductions | \ -// RUN: FileCheck %s - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @row_reduction(%arg0: tensor<128x1027xf32>) - -> tensor<128xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add - : tensor<128x1027xf32> to tensor<128xf32> - return %0 : tensor<128xf32> -} - -// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), -// CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 8], d2 in [0, 3], d3 in [0, 31], d1 * 128 + d2 * 32 + d3 in [0, 1026] -// CHECK-LABEL: @row_reduction -// CHECK-SAME: %[[IN:.*]]: tensor<128x1027xf32> -// CHECK: %[[C0:.*]] = arith.constant 0.00 -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at #[[$PAD_AND_RESHAPE]] default %[[C0]] -// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) 
inits(%[[C0]]) dimensions=[1] combiner=@add -// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[R1]]) inits(%[[C0]]) dimensions=[2] combiner=@add -// CHECK: %[[R3:.*]] = xla_gpu.reduce(%[[R2]]) inits(%[[C0]]) dimensions=[1] combiner=@add -// CHECK: return %[[R3]] : tensor<128xf32> - -// ----- - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<2x42x128x32x8xf32>) - -> tensor<2x128xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1, 3, 4] combiner=@add - : tensor<2x42x128x32x8xf32> to tensor<2x128xf32> - return %0 : tensor<2x128xf32> -} - -// CHECK-LABEL: @row_reduction_with_major_reduced_dim -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex -// CHECK-SAME: : tensor<2x42x128x32x8xf32> -> tensor<2x42x128x2x4x32xf32> -// CHECK: xla_gpu.reduce(%[[REINDEXED]]) -// CHECK-SAME: dimensions=[1, 3] -// CHECK-SAME: : tensor<2x42x128x2x4x32xf32> - -// ----- - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @column(%arg0: tensor<2x32x32xf32>) - -> tensor<2x32xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add - : tensor<2x32x32xf32> to tensor<2x32xf32> - return %0 : tensor<2x32xf32> -} - -// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) -// CHECK-SAME: d1 * 4 + d2 in [0, 31] -// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0, d2, d1) -// CHECK-LABEL: @column -// CHECK-SAME: %[[IN:.*]]: tensor<2x32x32xf32> -// CHECK: %[[C0:.*]] = arith.constant 0.00 -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at 
#[[$RESHAPE]] default %[[C0]] -// CHECK-SAME: -> tensor<2x8x4x32xf32> -// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) inits(%[[C0]]) dimensions=[1] -// CHECK-SAME: to tensor<2x4x32xf32> -// CHECK: %[[TRANSPOSED:.*]] = xla_gpu.reindex %[[R1]] at #[[$TRANSPOSE]] -// CHECK-SAME: -> tensor<2x32x4xf32> -// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[TRANSPOSED]]) inits(%[[C0]]) dimensions=[2] -// CHECK: return %[[R2]] : tensor<2x32xf32> From ef97cd26787d85fb90ac6ff4d59875ba3ffc8374 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Fri, 15 Nov 2024 10:55:25 -0800 Subject: [PATCH 4/4] [XLA:TPU:MSA] * Add support for overriding cross program prefetch behavior. * Add support for filtering buffer intervals based on the uses of the buffer. * Add tests for overriding cross program prefetch behavior * Add tests for expanding filtering criteria. PiperOrigin-RevId: 696938857 --- .../memory_space_assignment/algorithm.cc | 20 ++- .../buffer_interval_comparator.cc | 18 +- .../buffer_interval_comparator.h | 6 +- .../memory_space_assignment.proto | 8 + .../memory_space_assignment_test.cc | 170 +++++++++++++++++- xla/service/memory_space_assignment/options.h | 2 + xla/service/memory_space_assignment/utils.cc | 49 ++++- xla/service/memory_space_assignment/utils.h | 10 ++ 8 files changed, 263 insertions(+), 20 deletions(-) diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index b4d98b331f0d53..04326395c76321 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -282,20 +282,24 @@ std::vector FindCrossProgramPrefetchCandidates( for (const HloBuffer& buffer : alias_analysis.buffers()) { CHECK_GE(buffer.values().size(), 1); const HloValue* value = buffer.values().at(0); + MsaBufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + 
interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) { - MsaBufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + candidates.emplace_back(interval); + } else if (MemorySpaceAssignmentUtils:: + DoesCrossProgramPrefetchBufferMatchAnyFilter( + options.msa_sort_order_overrides, interval)) { candidates.emplace_back(interval); } } DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator( - hlo_live_range); + hlo_live_range, options.msa_sort_order_overrides); BufferIntervalComparator* comparator = (options.default_cross_program_prefetch_heuristic && options.buffer_interval_comparator diff --git a/xla/service/memory_space_assignment/buffer_interval_comparator.cc b/xla/service/memory_space_assignment/buffer_interval_comparator.cc index fd07e3550e693d..21a0f736dff36a 100644 --- a/xla/service/memory_space_assignment/buffer_interval_comparator.cc +++ b/xla/service/memory_space_assignment/buffer_interval_comparator.cc @@ -115,8 +115,11 @@ MemoryBoundednessBufferIntervalComparator::GetTuple( DefaultCrossProgramPrefetchBufferIntervalComparator:: DefaultCrossProgramPrefetchBufferIntervalComparator( - const HloLiveRange& hlo_live_range) - : BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {} + const HloLiveRange& hlo_live_range, + const MsaSortOrderOverrides& msa_sort_order_overrides) + : BufferIntervalComparator(), + hlo_live_range_(hlo_live_range), + msa_sort_order_overrides_(msa_sort_order_overrides) {} std::string DefaultCrossProgramPrefetchBufferIntervalComparator:: DescribeComparisonCriteria() const { @@ -138,6 +141,10 @@ bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan( 
DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( const MsaBufferInterval& buffer_interval) { + int64_t priority = + MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority( + msa_sort_order_overrides_, buffer_interval, + /*is_cross_program_prefetch=*/true); auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer); if (sort_data_it == additional_sort_data_.end()) { AdditionalSortData sort_data; @@ -155,9 +162,10 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple( .first; } - return std::make_tuple( - -1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size, - sort_data_it->second.latest_use, buffer_interval.buffer->id()); + return std::make_tuple(priority, -1 * buffer_interval.size, + -1 * sort_data_it->second.cumulative_use_size, + sort_data_it->second.latest_use, + buffer_interval.buffer->id()); } } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/buffer_interval_comparator.h b/xla/service/memory_space_assignment/buffer_interval_comparator.h index 7d6936674f0a5e..5c7a94b6ffd468 100644 --- a/xla/service/memory_space_assignment/buffer_interval_comparator.h +++ b/xla/service/memory_space_assignment/buffer_interval_comparator.h @@ -114,7 +114,8 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator : public BufferIntervalComparator { public: explicit DefaultCrossProgramPrefetchBufferIntervalComparator( - const HloLiveRange& hlo_live_range); + const HloLiveRange& hlo_live_range, + const MsaSortOrderOverrides& msa_sort_order_overrides); ~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default; @@ -128,7 +129,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator // See the value returned by DescribeComparisonCriteria() for the meaning of // each tuple element. 
using ComparisonTuple = - std::tuple; + std::tuple; struct AdditionalSortData { int64_t latest_use = 0; @@ -140,6 +141,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator absl::flat_hash_map additional_sort_data_; const HloLiveRange& hlo_live_range_; + const MsaSortOrderOverrides& msa_sort_order_overrides_; }; } // namespace memory_space_assignment diff --git a/xla/service/memory_space_assignment/memory_space_assignment.proto b/xla/service/memory_space_assignment/memory_space_assignment.proto index e15d564dac8f35..09ae1382a4031c 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -105,6 +105,11 @@ message HloOperandFilter { // If operand of an instruction is a tuple and indexing into the tuple is // required. optional TupleShapeIndex tuple_index = 5; + // Regex to match the entire instruction HLO. The HLO string is constructed + // using default HloPrintOptions. Refer to the HloPrintOptions class in + // hlo_instruction.h to know more about the format of the HLO string used for + // matching. + optional string instruction_regex = 6; } // Options to override preferred prefetch time for an operand. @@ -154,6 +159,8 @@ message HloPositionMatcher { optional int64 size_gte = 4; // Filters instructions with output size in bytes less or equal to a value. optional int64 size_lte = 5; + // Filters instructions that have a use that matches the filter. + optional HloOperandFilter hlo_use_filter = 6; } // Options to override preferred prefetch time for an operand. @@ -176,6 +183,7 @@ message MsaSortOrderOverride { optional HloPositionMatcher hlo_position_matcher = 1; optional xla.memory_space_assignment.MsaSortOrderOverrideOptions override_options = 2; + optional bool apply_to_cross_program_prefetches = 3; } // Encloses chained override configs. 
The first config has highest precedence diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 096f6746ed4737..3948bad2b7ea5f 100644 --- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -5595,6 +5595,77 @@ TEST_F(MemorySpaceAssignmentTest, EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); } +TEST_F(MemorySpaceAssignmentTest, + MemoryBoundednessOverrideSortOrderByUseAssignFirst) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[3,4]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + tanh0 = f32[3,4]{1,0} tanh(p0) + negate0 = f32[3,4]{1,0} negate(p1) + tanh1 = f32[3,4]{1,0} tanh(tanh0) + negate1 = f32[3,4]{1,0} negate(negate0) + tanh2 = f32[3,4]{1,0} tanh(tanh1) + negate2 = f32[3,4]{1,0} negate(negate1) + tanh3 = f32[3,4]{1,0} tanh(tanh2) + negate3 = f32[3,4]{1,0} negate(negate2) + tanh4 = f32[3,4]{1,0} tanh(tanh3) + negate4 = f32[3,4]{1,0} negate(negate3) + ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Override MSA sort order and try to assign all negates to alternate memory + // first. Alternate memory size is enough to fit 2 f32[4,3] tensors at a time. 
+ const std::string text_proto = R"pb( + overrides { + hlo_position_matcher { + hlo_use_filter { instruction_name_regex: "negate(.*)" } + } + override_options { assign_first: true } + })pb"; + TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides, + ParseTextProto(text_proto)); + + AssignMemorySpaceUsingCostAnalysis( + module.get(), /*memory_space_options_override=*/std::nullopt, + /*cost_analysis_options_override=*/std::nullopt, + /*hlo_cost_options_override=*/std::nullopt, + /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides); + // Parameters are in the default memory space. + const HloInstruction* p0 = FindInstruction(module.get(), "p0"); + EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* p1 = FindInstruction(module.get(), "p1"); + EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace); + // Check that all negates are in alternate memory space except negate4. + // negate4 is a program output, so it has to land in default memory. 
+ HloInstruction* negate0 = FindInstruction(module.get(), "negate0"); + EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate1 = FindInstruction(module.get(), "negate1"); + EXPECT_EQ(negate1->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate2 = FindInstruction(module.get(), "negate2"); + EXPECT_EQ(negate2->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate3 = FindInstruction(module.get(), "negate3"); + EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace); + HloInstruction* negate4 = FindInstruction(module.get(), "negate4"); + EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0"); + EXPECT_EQ(tanh0->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1"); + EXPECT_EQ(tanh1->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2"); + EXPECT_EQ(tanh2->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3"); + EXPECT_EQ(tanh3->shape().layout().memory_space(), kDefaultMemorySpace); + const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4"); + EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace); +} + TEST_F(MemorySpaceAssignmentTest, SimpleWhileTupleTest) { Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); Shape f32v1 = ShapeUtil::MakeShape(F32, {1}); @@ -10399,9 +10470,12 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - auto preset_assignments = AssignMemorySpace( - module.get(), DefaultMemorySpaceOptions(), - /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + auto options = DefaultMemorySpaceOptions(); + 
// Enough space to fit the cross-program prefetch for both p0 and p1. + options.max_size_in_bytes = 512; + auto preset_assignments = AssignMemorySpace(module.get(), options, + /*max_prefetch_interval=*/5, + /*min_prefetch_interval=*/2); auto cross_program_prefetches = module->CrossProgramPrefetches(); EXPECT_EQ(cross_program_prefetches.size(), 1); @@ -10455,6 +10529,96 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { EXPECT_TRUE(has_zero_offset_allocations); } +TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) { + // This test is for checking if the cross-program-prefetched buffer is freed + // after its last use and there is an end-of-program prefetch. + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY CrossProgramPrefetch { + p0 = f32[8,8]{1,0} parameter(0) + p1 = f32[8,2]{1,0} parameter(1) + dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT negate.9 = f32[8,2]{1,0} negate(negate.8) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto options = DefaultMemorySpaceOptions(); + const std::string text_proto = R"pb( + overrides { + hlo_position_matcher { + instruction_name_regex: "p(.*)" + instruction_regex: ".*parameter\\(0\\).*" + } + override_options { assign_first: true } + apply_to_cross_program_prefetches: true + })pb"; + TF_ASSERT_OK_AND_ASSIGN(options.msa_sort_order_overrides, + ParseTextProto(text_proto)); + options.max_size_in_bytes = 256; + auto preset_assignments = AssignMemorySpace(module.get(), options, + /*max_prefetch_interval=*/5, 
+ /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + EXPECT_EQ(cross_program_prefetches[0].parameter, 0); + EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({})); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + LOG(ERROR) << "module: " << module->ToString(); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. + auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->cross_program_prefetch_index().has_value(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(), + is_end_of_program_prefetch), + 1); + // Also verify that the copy-done for the end-of-program prefetch is the last + // instruction in schedule. + const HloInstruction* last_instruction = + module->schedule() + .sequence(module->entry_computation()) + .instructions()[module->entry_computation()->instruction_count() - 1]; + EXPECT_THAT(last_instruction, op::CopyDone()); + EXPECT_NE(last_instruction, module->entry_computation()->root_instruction()); + // Cross program prefetch would use offset 0 because that's the first + // assignment. 
Since we are freeing the cross-program prefetch buffer, we + // would also expect to see some of the intermediate computations (one of the + // negate ops) to also get 0 offset allocations. + bool has_zero_offset_allocations = false; + for (auto pos_and_chunk : preset_assignments->chunks()) { + if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate && + pos_and_chunk.second.offset == 0) { + has_zero_offset_allocations = true; + } + } + EXPECT_TRUE(has_zero_offset_allocations); +} + TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) { // This test is for checking if the cross-program-prefetched buffer is freed // after its last use and there is an end-of-program prefetch. diff --git a/xla/service/memory_space_assignment/options.h b/xla/service/memory_space_assignment/options.h index 075b6016fc9cd1..2148784c9d266c 100644 --- a/xla/service/memory_space_assignment/options.h +++ b/xla/service/memory_space_assignment/options.h @@ -283,6 +283,8 @@ struct Options { // and gives MSA more flexibility in choosing the prefetch time and how much // data to prefetch. 
bool enable_window_prefetch = false; + + MsaSortOrderOverrides msa_sort_order_overrides; }; } // namespace memory_space_assignment } // namespace xla diff --git a/xla/service/memory_space_assignment/utils.cc b/xla/service/memory_space_assignment/utils.cc index 740fdf226d3384..3025079e499df7 100644 --- a/xla/service/memory_space_assignment/utils.cc +++ b/xla/service/memory_space_assignment/utils.cc @@ -140,6 +140,11 @@ bool MemorySpaceAssignmentUtils::DoesUseMatchFilter( filter.instruction_name_regex())) { return false; } + if (filter.has_instruction_regex() && + !RE2::FullMatch(hlo_use.instruction->ToString(), + filter.instruction_regex())) { + return false; + } return true; } @@ -168,7 +173,22 @@ bool MemorySpaceAssignmentUtils::DoesPositionMatchFilter( !RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) { return false; } - return true; + return DoesBufferIntervalMatchHloUseFilter(filter, buffer_interval); +} + +bool MemorySpaceAssignmentUtils::DoesBufferIntervalMatchHloUseFilter( + const HloPositionMatcher& filter, + const MsaBufferInterval& buffer_interval) { + if (!filter.has_hlo_use_filter()) { + return true; + } + for (const HloUse& use : buffer_interval.buffer->GetUses()) { + if (DoesUseMatchFilter(filter.hlo_use_filter(), use, + buffer_interval.size)) { + return true; + } + } + return false; } absl::StatusOr @@ -273,14 +293,39 @@ MemorySpaceAssignmentUtils::GetOverriddenPreferredPrefetchTime( return static_cast>>(std::nullopt); } +bool MemorySpaceAssignmentUtils::DoesCrossProgramPrefetchBufferMatchAnyFilter( + const MsaSortOrderOverrides& sort_order_overrides, + const MsaBufferInterval& buffer_interval) { + for (const MsaSortOrderOverride& override : + sort_order_overrides.overrides()) { + if (override.has_apply_to_cross_program_prefetches() && + override.apply_to_cross_program_prefetches() && + MemorySpaceAssignmentUtils::DoesPositionMatchFilter( + override.hlo_position_matcher(), buffer_interval) && + 
override.override_options().has_assign_first() && + override.override_options().assign_first()) { + VLOG(3) << "Cross program prefetch buffer " + << buffer_interval.buffer->ToString() + << " matches sort order override " << override.DebugString(); + return true; + } + } + return false; +} + int64_t MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority( const MsaSortOrderOverrides& msa_sort_order_overrides, - const MsaBufferInterval& buffer_interval) { + const MsaBufferInterval& buffer_interval, bool is_cross_program_prefetch) { if (msa_sort_order_overrides.overrides_size() == 0) { return 0; } for (int64_t i = 0; i < msa_sort_order_overrides.overrides_size(); ++i) { const auto& override = msa_sort_order_overrides.overrides(i); + if (is_cross_program_prefetch && + (!override.has_apply_to_cross_program_prefetches() || + !override.apply_to_cross_program_prefetches())) { + continue; + } if (!MemorySpaceAssignmentUtils::DoesPositionMatchFilter( override.hlo_position_matcher(), buffer_interval)) { continue; diff --git a/xla/service/memory_space_assignment/utils.h b/xla/service/memory_space_assignment/utils.h index cac987731c183d..77a05a8cd26e20 100644 --- a/xla/service/memory_space_assignment/utils.h +++ b/xla/service/memory_space_assignment/utils.h @@ -91,10 +91,20 @@ class MemorySpaceAssignmentUtils { instruction_schedule, int64_t earliest_prefetch_time, int64_t latest_prefetch_time); + static bool DoesCrossProgramPrefetchBufferMatchAnyFilter( + const MsaSortOrderOverrides& sort_order_overrides, + const MsaBufferInterval& buffer_interval); + // Returns an integer representing the priority of a MsaBufferInterval during // assignment, a smaller number indicates a higher priority. 
static int64_t GetBufferIntervalOverridePriority( const MsaSortOrderOverrides& msa_sort_order_overrides, + const MsaBufferInterval& buffer_interval, + bool is_cross_program_prefetch = false); + + private: + static bool DoesBufferIntervalMatchHloUseFilter( + const HloPositionMatcher& filter, const MsaBufferInterval& buffer_interval); };