Skip to content

Commit

Permalink
[XLA:TPU:MSA]
Browse files Browse the repository at this point in the history
* Add support for overriding cross program prefetch behavior.
* Add support for filtering buffer intervals based on the uses of the buffer.
* Add tests for overriding cross program prefetch behavior.
* Add tests for expanding filtering criteria.

PiperOrigin-RevId: 698574108
  • Loading branch information
subhankarshah authored and Google-ML-Automation committed Nov 21, 2024
1 parent fe35dee commit ee77f15
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 20 deletions.
20 changes: 12 additions & 8 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,24 @@ std::vector<MsaBufferInterval> FindCrossProgramPrefetchCandidates(
for (const HloBuffer& buffer : alias_analysis.buffers()) {
CHECK_GE(buffer.values().size(), 1);
const HloValue* value = buffer.values().at(0);
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) {
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
candidates.emplace_back(interval);
} else if (MemorySpaceAssignmentUtils::
DoesCrossProgramPrefetchBufferMatchAnyFilter(
options.msa_sort_order_overrides, interval)) {
candidates.emplace_back(interval);
}
}

DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator(
hlo_live_range);
hlo_live_range, options.msa_sort_order_overrides);
BufferIntervalComparator* comparator =
(options.default_cross_program_prefetch_heuristic &&
options.buffer_interval_comparator
Expand Down
18 changes: 13 additions & 5 deletions xla/service/memory_space_assignment/buffer_interval_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ MemoryBoundednessBufferIntervalComparator::GetTuple(

DefaultCrossProgramPrefetchBufferIntervalComparator::
DefaultCrossProgramPrefetchBufferIntervalComparator(
const HloLiveRange& hlo_live_range)
: BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {}
const HloLiveRange& hlo_live_range,
const MsaSortOrderOverrides& msa_sort_order_overrides)
: BufferIntervalComparator(),
hlo_live_range_(hlo_live_range),
msa_sort_order_overrides_(msa_sort_order_overrides) {}

std::string DefaultCrossProgramPrefetchBufferIntervalComparator::
DescribeComparisonCriteria() const {
Expand All @@ -138,6 +141,10 @@ bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan(
DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple
DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
const MsaBufferInterval& buffer_interval) {
int64_t priority =
MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority(
msa_sort_order_overrides_, buffer_interval,
/*is_cross_program_prefetch=*/true);
auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer);
if (sort_data_it == additional_sort_data_.end()) {
AdditionalSortData sort_data;
Expand All @@ -155,9 +162,10 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
.first;
}

return std::make_tuple(
-1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use, buffer_interval.buffer->id());
return std::make_tuple(priority, -1 * buffer_interval.size,
-1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use,
buffer_interval.buffer->id());
}

} // namespace memory_space_assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
: public BufferIntervalComparator {
public:
explicit DefaultCrossProgramPrefetchBufferIntervalComparator(
const HloLiveRange& hlo_live_range);
const HloLiveRange& hlo_live_range,
const MsaSortOrderOverrides& msa_sort_order_overrides);

~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default;

Expand All @@ -128,7 +129,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
// See the value returned by DescribeComparisonCriteria() for the meaning of
// each tuple element.
using ComparisonTuple =
std::tuple<int64_t, int64_t, int64_t, BufferValue::Id>;
std::tuple<int64_t, int64_t, int64_t, int64_t, BufferValue::Id>;

struct AdditionalSortData {
int64_t latest_use = 0;
Expand All @@ -140,6 +141,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
absl::flat_hash_map<const HloValue*, AdditionalSortData>
additional_sort_data_;
const HloLiveRange& hlo_live_range_;
const MsaSortOrderOverrides& msa_sort_order_overrides_;
};

} // namespace memory_space_assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ message HloOperandFilter {
// If operand of an instruction is a tuple and indexing into the tuple is
// required.
optional TupleShapeIndex tuple_index = 5;
// Regex to match the entire instruction HLO. The HLO string is constructed
// using default HloPrintOptions. Refer to the HloPrintOptions class in
// hlo_instruction.h to know more about the format of the HLO string used for
// matching.
optional string instruction_regex = 6;
}

// Options to override preferred prefetch time for an operand.
Expand Down Expand Up @@ -154,6 +159,8 @@ message HloPositionMatcher {
optional int64 size_gte = 4;
// Filters instructions with output size in bytes less or equal to a value.
optional int64 size_lte = 5;
// Filters instructions that have a use that matches the filter.
optional HloOperandFilter hlo_use_filter = 6;
}

// Options to override preferred prefetch time for an operand.
Expand All @@ -176,6 +183,7 @@ message MsaSortOrderOverride {
optional HloPositionMatcher hlo_position_matcher = 1;
optional xla.memory_space_assignment.MsaSortOrderOverrideOptions
override_options = 2;
optional bool apply_to_cross_program_prefetches = 3;
}

// Encloses chained override configs. The first config has highest precedence
Expand Down
170 changes: 167 additions & 3 deletions xla/service/memory_space_assignment/memory_space_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5595,6 +5595,77 @@ TEST_F(MemorySpaceAssignmentTest,
EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace);
}

// Exercises an MSA sort-order override whose position matcher selects buffers
// through an `hlo_use_filter` (here: a use whose instruction name matches
// "negate(.*)") combined with `assign_first`, and checks the resulting memory
// space of every parameter, negate and tanh in two interleaved chains.
TEST_F(MemorySpaceAssignmentTest,
       MemoryBoundednessOverrideSortOrderByUseAssignFirst) {
  absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[3,4]{1,0} parameter(0)
p1 = f32[3,4]{1,0} parameter(1)
tanh0 = f32[3,4]{1,0} tanh(p0)
negate0 = f32[3,4]{1,0} negate(p1)
tanh1 = f32[3,4]{1,0} tanh(tanh0)
negate1 = f32[3,4]{1,0} negate(negate0)
tanh2 = f32[3,4]{1,0} tanh(tanh1)
negate2 = f32[3,4]{1,0} negate(negate1)
tanh3 = f32[3,4]{1,0} tanh(tanh2)
negate3 = f32[3,4]{1,0} negate(negate2)
tanh4 = f32[3,4]{1,0} tanh(tanh3)
negate4 = f32[3,4]{1,0} negate(negate3)
ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Ask MSA to place the negate chain into alternate memory ahead of
  // everything else. Alternate memory can hold two f32[3,4] tensors at a time.
  const std::string override_config = R"pb(
overrides {
hlo_position_matcher {
hlo_use_filter { instruction_name_regex: "negate(.*)" }
}
override_options { assign_first: true }
})pb";
  TF_ASSERT_OK_AND_ASSIGN(
      MsaSortOrderOverrides overrides,
      ParseTextProto<MsaSortOrderOverrides>(override_config));

  AssignMemorySpaceUsingCostAnalysis(
      hlo_module.get(), /*memory_space_options_override=*/std::nullopt,
      /*cost_analysis_options_override=*/std::nullopt,
      /*hlo_cost_options_override=*/std::nullopt,
      /*optional_msa_sort_order_overrides=*/overrides);

  // Looks up an instruction by name and reports the memory space recorded in
  // its output layout.
  auto memory_space_of = [&](absl::string_view instruction_name) {
    const HloInstruction* instruction =
        FindInstruction(hlo_module.get(), instruction_name);
    return instruction->shape().layout().memory_space();
  };

  // Both parameters stay in the default memory space.
  EXPECT_EQ(memory_space_of("p0"), kDefaultMemorySpace);
  EXPECT_EQ(memory_space_of("p1"), kDefaultMemorySpace);

  // Every negate lands in alternate memory, with the exception of negate4:
  // it is a program output and therefore has to live in default memory.
  EXPECT_EQ(memory_space_of("negate0"), kAlternateMemorySpace);
  EXPECT_EQ(memory_space_of("negate1"), kAlternateMemorySpace);
  EXPECT_EQ(memory_space_of("negate2"), kAlternateMemorySpace);
  EXPECT_EQ(memory_space_of("negate3"), kAlternateMemorySpace);
  EXPECT_EQ(memory_space_of("negate4"), kDefaultMemorySpace);

  // The whole tanh chain is expected to remain in default memory.
  EXPECT_EQ(memory_space_of("tanh0"), kDefaultMemorySpace);
  EXPECT_EQ(memory_space_of("tanh1"), kDefaultMemorySpace);
  EXPECT_EQ(memory_space_of("tanh2"), kDefaultMemorySpace);
  EXPECT_EQ(memory_space_of("tanh3"), kDefaultMemorySpace);
  EXPECT_EQ(memory_space_of("tanh4"), kDefaultMemorySpace);
}

TEST_F(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
Expand Down Expand Up @@ -10399,9 +10470,12 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto preset_assignments = AssignMemorySpace(
module.get(), DefaultMemorySpaceOptions(),
/*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
auto options = DefaultMemorySpaceOptions();
// Enough space to fit the cross-program prefetch for both p0 and p1.
options.max_size_in_bytes = 512;
auto preset_assignments = AssignMemorySpace(module.get(), options,
/*max_prefetch_interval=*/5,
/*min_prefetch_interval=*/2);

auto cross_program_prefetches = module->CrossProgramPrefetches();
EXPECT_EQ(cross_program_prefetches.size(), 1);
Expand Down Expand Up @@ -10455,6 +10529,96 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
EXPECT_TRUE(has_zero_offset_allocations);
}

TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) {
  // This test checks that, with a sort-order override that is applied to
  // cross-program prefetches, the cross-program-prefetched buffer is freed
  // after its last use and there is an end-of-program prefetch.
  absl::string_view hlo_string = R"(
HloModule cross_program_prefetch, is_scheduled=true

ENTRY CrossProgramPrefetch {
p0 = f32[8,8]{1,0} parameter(0)
p1 = f32[8,2]{1,0} parameter(1)
dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
negate.1 = f32[8,2]{1,0} negate(dot)
negate.2 = f32[8,2]{1,0} negate(negate.1)
negate.3 = f32[8,2]{1,0} negate(negate.2)
negate.4 = f32[8,2]{1,0} negate(negate.3)
negate.5 = f32[8,2]{1,0} negate(negate.4)
negate.6 = f32[8,2]{1,0} negate(negate.5)
negate.7 = f32[8,2]{1,0} negate(negate.6)
negate.8 = f32[8,2]{1,0} negate(negate.7)
ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto options = DefaultMemorySpaceOptions();
  // The override targets parameter 0 (name regex "p(.*)" plus an instruction
  // regex on "parameter(0)"), asks for it to be assigned first, and opts in to
  // being applied to cross-program prefetches.
  const std::string text_proto = R"pb(
overrides {
hlo_position_matcher {
instruction_name_regex: "p(.*)"
instruction_regex: ".*parameter\\(0\\).*"
}
override_options { assign_first: true }
apply_to_cross_program_prefetches: true
})pb";
  TF_ASSERT_OK_AND_ASSIGN(options.msa_sort_order_overrides,
                          ParseTextProto<MsaSortOrderOverrides>(text_proto));
  // 256 bytes equals the raw data size of p0 (8 * 8 f32 values).
  options.max_size_in_bytes = 256;
  auto preset_assignments = AssignMemorySpace(module.get(), options,
                                              /*max_prefetch_interval=*/5,
                                              /*min_prefetch_interval=*/2);

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  EXPECT_EQ(cross_program_prefetches[0].parameter, 0);
  EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({}));

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  // Opt-in diagnostic only: dumping the whole module at ERROR severity would
  // pollute the log of every (passing) run.
  VLOG(1) << "module: " << module->ToString();
  const HloValue& cross_program_prefetched_value =
      dataflow_analysis->GetValueDefinedAt(
          module->entry_computation()->parameter_instruction(0), {});
  // Expect that there are two prefetches that use this value, one is the
  // cross-program prefetch, the other is the end-of-program prefetch.
  auto is_cross_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           use.instruction->cross_program_prefetch_index().has_value();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_cross_program_prefetch),
            1);
  auto is_end_of_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           !use.instruction->cross_program_prefetch_index().has_value();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_end_of_program_prefetch),
            1);
  // Also verify that the copy-done for the end-of-program prefetch is the last
  // instruction in schedule. Take the last element of the schedule sequence
  // itself rather than indexing it with the computation's instruction count.
  const HloInstruction* last_instruction =
      module->schedule()
          .sequence(module->entry_computation())
          .instructions()
          .back();
  EXPECT_THAT(last_instruction, op::CopyDone());
  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
  // Cross program prefetch would use offset 0 because that's the first
  // assignment. Since we are freeing the cross-program prefetch buffer, we
  // would also expect to see some of the intermediate computations (one of the
  // negate ops) to also get 0 offset allocations.
  bool has_zero_offset_allocations = false;
  for (const auto& [position, chunk] : preset_assignments->chunks()) {
    if (position.instruction->opcode() == HloOpcode::kNegate &&
        chunk.offset == 0) {
      has_zero_offset_allocations = true;
      break;  // One matching negate is all the expectation needs.
    }
  }
  EXPECT_TRUE(has_zero_offset_allocations);
}

TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) {
// This test is for checking if the cross-program-prefetched buffer is freed
// after its last use and there is an end-of-program prefetch.
Expand Down
2 changes: 2 additions & 0 deletions xla/service/memory_space_assignment/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ struct Options {
// and gives MSA more flexibility in choosing the prefetch time and how much
// data to prefetch.
bool enable_window_prefetch = false;

MsaSortOrderOverrides msa_sort_order_overrides;
};
} // namespace memory_space_assignment
} // namespace xla
Expand Down
Loading

0 comments on commit ee77f15

Please sign in to comment.