Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:TPU:MSA] #19491

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,20 +282,24 @@ std::vector<MsaBufferInterval> FindCrossProgramPrefetchCandidates(
for (const HloBuffer& buffer : alias_analysis.buffers()) {
CHECK_GE(buffer.values().size(), 1);
const HloValue* value = buffer.values().at(0);
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
if (IsCrossProgramPrefetchCandidate(*value, alias_analysis, options)) {
MsaBufferInterval interval;
interval.buffer = value;
interval.size = options.size_fn(*value);
interval.start = 0;
interval.end = hlo_live_range.schedule_end_time();
interval.need_allocation = true;
interval.colocations = {++buffer.values().begin(), buffer.values().end()};
candidates.emplace_back(interval);
} else if (MemorySpaceAssignmentUtils::
DoesCrossProgramPrefetchBufferMatchAnyFilter(
options.msa_sort_order_overrides, interval)) {
candidates.emplace_back(interval);
}
}

DefaultCrossProgramPrefetchBufferIntervalComparator default_comparator(
hlo_live_range);
hlo_live_range, options.msa_sort_order_overrides);
BufferIntervalComparator* comparator =
(options.default_cross_program_prefetch_heuristic &&
options.buffer_interval_comparator
Expand Down
18 changes: 13 additions & 5 deletions xla/service/memory_space_assignment/buffer_interval_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,11 @@ MemoryBoundednessBufferIntervalComparator::GetTuple(

DefaultCrossProgramPrefetchBufferIntervalComparator::
DefaultCrossProgramPrefetchBufferIntervalComparator(
const HloLiveRange& hlo_live_range)
: BufferIntervalComparator(), hlo_live_range_(hlo_live_range) {}
const HloLiveRange& hlo_live_range,
const MsaSortOrderOverrides& msa_sort_order_overrides)
: BufferIntervalComparator(),
hlo_live_range_(hlo_live_range),
msa_sort_order_overrides_(msa_sort_order_overrides) {}

std::string DefaultCrossProgramPrefetchBufferIntervalComparator::
DescribeComparisonCriteria() const {
Expand All @@ -138,6 +141,10 @@ bool DefaultCrossProgramPrefetchBufferIntervalComparator::LessThan(
DefaultCrossProgramPrefetchBufferIntervalComparator::ComparisonTuple
DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
const MsaBufferInterval& buffer_interval) {
int64_t priority =
MemorySpaceAssignmentUtils::GetBufferIntervalOverridePriority(
msa_sort_order_overrides_, buffer_interval,
/*is_cross_program_prefetch=*/true);
auto sort_data_it = additional_sort_data_.find(buffer_interval.buffer);
if (sort_data_it == additional_sort_data_.end()) {
AdditionalSortData sort_data;
Expand All @@ -155,9 +162,10 @@ DefaultCrossProgramPrefetchBufferIntervalComparator::GetTuple(
.first;
}

return std::make_tuple(
-1 * buffer_interval.size, -1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use, buffer_interval.buffer->id());
return std::make_tuple(priority, -1 * buffer_interval.size,
-1 * sort_data_it->second.cumulative_use_size,
sort_data_it->second.latest_use,
buffer_interval.buffer->id());
}

} // namespace memory_space_assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
: public BufferIntervalComparator {
public:
explicit DefaultCrossProgramPrefetchBufferIntervalComparator(
const HloLiveRange& hlo_live_range);
const HloLiveRange& hlo_live_range,
const MsaSortOrderOverrides& msa_sort_order_overrides);

~DefaultCrossProgramPrefetchBufferIntervalComparator() override = default;

Expand All @@ -128,7 +129,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
// See the value returned by DescribeComparisonCriteria() for the meaning of
// each tuple element.
using ComparisonTuple =
std::tuple<int64_t, int64_t, int64_t, BufferValue::Id>;
std::tuple<int64_t, int64_t, int64_t, int64_t, BufferValue::Id>;

struct AdditionalSortData {
int64_t latest_use = 0;
Expand All @@ -140,6 +141,7 @@ class DefaultCrossProgramPrefetchBufferIntervalComparator
absl::flat_hash_map<const HloValue*, AdditionalSortData>
additional_sort_data_;
const HloLiveRange& hlo_live_range_;
const MsaSortOrderOverrides& msa_sort_order_overrides_;
};

} // namespace memory_space_assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ message HloOperandFilter {
// If operand of an instruction is a tuple and indexing into the tuple is
// required.
optional TupleShapeIndex tuple_index = 5;
// Regex to match the entire instruction HLO. The HLO string is constructed
// using default HloPrintOptions. Refer to the HloPrintOptions class in
// hlo_instruction.h to know more about the format of the HLO string used for
// matching.
optional string instruction_regex = 6;
}

// Options to override preferred prefetch time for an operand.
Expand Down Expand Up @@ -154,6 +159,8 @@ message HloPositionMatcher {
optional int64 size_gte = 4;
// Filters instructions with output size in bytes less or equal to a value.
optional int64 size_lte = 5;
// Filters instructions that have a use that matches the filter.
optional HloOperandFilter hlo_use_filter = 6;
}

// Options to override preferred prefetch time for an operand.
Expand All @@ -176,6 +183,7 @@ message MsaSortOrderOverride {
optional HloPositionMatcher hlo_position_matcher = 1;
optional xla.memory_space_assignment.MsaSortOrderOverrideOptions
override_options = 2;
optional bool apply_to_cross_program_prefetches = 3;
}

// Encloses chained override configs. The first config has highest precedence
Expand Down
170 changes: 167 additions & 3 deletions xla/service/memory_space_assignment/memory_space_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5595,6 +5595,77 @@ TEST_F(MemorySpaceAssignmentTest,
EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace);
}

// Exercises an MSA sort-order override whose position matcher selects buffers
// through an `hlo_use_filter` (buffers used by instructions named "negate*")
// and marks them `assign_first`. The test then checks which memory space each
// instruction's output layout ends up in.
TEST_F(MemorySpaceAssignmentTest,
       MemoryBoundednessOverrideSortOrderByUseAssignFirst) {
  absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[3,4]{1,0} parameter(0)
p1 = f32[3,4]{1,0} parameter(1)
tanh0 = f32[3,4]{1,0} tanh(p0)
negate0 = f32[3,4]{1,0} negate(p1)
tanh1 = f32[3,4]{1,0} tanh(tanh0)
negate1 = f32[3,4]{1,0} negate(negate0)
tanh2 = f32[3,4]{1,0} tanh(tanh1)
negate2 = f32[3,4]{1,0} negate(negate1)
tanh3 = f32[3,4]{1,0} tanh(tanh2)
negate3 = f32[3,4]{1,0} negate(negate2)
tanh4 = f32[3,4]{1,0} tanh(tanh3)
negate4 = f32[3,4]{1,0} negate(negate3)
ROOT tuple = (f32[3,4]{1,0}, f32[3,4]{1,0}) tuple(tanh4, negate4)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Override the MSA sort order so that the negate chain is considered for
  // alternate memory first. Alternate memory is sized to hold two f32[3,4]
  // tensors at a time.
  const std::string text_proto = R"pb(
overrides {
hlo_position_matcher {
hlo_use_filter { instruction_name_regex: "negate(.*)" }
}
override_options { assign_first: true }
})pb";
  TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
                          ParseTextProto<MsaSortOrderOverrides>(text_proto));

  AssignMemorySpaceUsingCostAnalysis(
      module.get(), /*memory_space_options_override=*/std::nullopt,
      /*cost_analysis_options_override=*/std::nullopt,
      /*hlo_cost_options_override=*/std::nullopt,
      /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);

  // Looks up an instruction by name and returns the memory space recorded in
  // the layout of its output shape.
  const auto memory_space_of = [&](const char* name) {
    return FindInstruction(module.get(), name)
        ->shape()
        .layout()
        .memory_space();
  };

  // Parameters are expected to stay in the default memory space.
  for (const char* name : {"p0", "p1"}) {
    EXPECT_EQ(memory_space_of(name), kDefaultMemorySpace) << name;
  }

  // Every negate except the last one is expected in alternate memory.
  for (const char* name : {"negate0", "negate1", "negate2", "negate3"}) {
    EXPECT_EQ(memory_space_of(name), kAlternateMemorySpace) << name;
  }

  // negate4 is a program output, so it has to land in default memory. The
  // tanh chain is expected to remain in default memory as well.
  for (const char* name :
       {"negate4", "tanh0", "tanh1", "tanh2", "tanh3", "tanh4"}) {
    EXPECT_EQ(memory_space_of(name), kDefaultMemorySpace) << name;
  }
}

TEST_F(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
Expand Down Expand Up @@ -10399,9 +10470,12 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto preset_assignments = AssignMemorySpace(
module.get(), DefaultMemorySpaceOptions(),
/*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
auto options = DefaultMemorySpaceOptions();
// Enough space to fit the cross-program prefetch for both p0 and p1.
options.max_size_in_bytes = 512;
auto preset_assignments = AssignMemorySpace(module.get(), options,
/*max_prefetch_interval=*/5,
/*min_prefetch_interval=*/2);

auto cross_program_prefetches = module->CrossProgramPrefetches();
EXPECT_EQ(cross_program_prefetches.size(), 1);
Expand Down Expand Up @@ -10455,6 +10529,96 @@ TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
EXPECT_TRUE(has_zero_offset_allocations);
}

// Checks cross-program prefetching when an MSA sort-order override with
// `apply_to_cross_program_prefetches: true` is supplied. The override matches
// instructions named "p*" whose HLO text contains "parameter(0)" and marks
// them `assign_first`; the test expects parameter 0 to be the (single)
// cross-program prefetch. As in the non-override variant of this test, it
// also checks that the cross-program-prefetched buffer is freed after its last
// use and that there is an end-of-program prefetch.
TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchWithOverrideNoReuse) {
  absl::string_view hlo_string = R"(
HloModule cross_program_prefetch, is_scheduled=true

ENTRY CrossProgramPrefetch {
p0 = f32[8,8]{1,0} parameter(0)
p1 = f32[8,2]{1,0} parameter(1)
dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
negate.1 = f32[8,2]{1,0} negate(dot)
negate.2 = f32[8,2]{1,0} negate(negate.1)
negate.3 = f32[8,2]{1,0} negate(negate.2)
negate.4 = f32[8,2]{1,0} negate(negate.3)
negate.5 = f32[8,2]{1,0} negate(negate.4)
negate.6 = f32[8,2]{1,0} negate(negate.5)
negate.7 = f32[8,2]{1,0} negate(negate.6)
negate.8 = f32[8,2]{1,0} negate(negate.7)
ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto options = DefaultMemorySpaceOptions();
  const std::string text_proto = R"pb(
overrides {
hlo_position_matcher {
instruction_name_regex: "p(.*)"
instruction_regex: ".*parameter\\(0\\).*"
}
override_options { assign_first: true }
apply_to_cross_program_prefetches: true
})pb";
  TF_ASSERT_OK_AND_ASSIGN(options.msa_sort_order_overrides,
                          ParseTextProto<MsaSortOrderOverrides>(text_proto));
  // NOTE(review): 256 bytes equals the unpadded size of p0 (f32[8,8]);
  // presumably chosen so that p0 just fits -- confirm against the size
  // function used by DefaultMemorySpaceOptions().
  options.max_size_in_bytes = 256;
  auto preset_assignments = AssignMemorySpace(module.get(), options,
                                              /*max_prefetch_interval=*/5,
                                              /*min_prefetch_interval=*/2);

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  // ASSERT (not EXPECT): element 0 is dereferenced right below, so the test
  // must stop here if no cross-program prefetch was produced.
  ASSERT_EQ(cross_program_prefetches.size(), 1);
  EXPECT_EQ(cross_program_prefetches[0].parameter, 0);
  EXPECT_EQ(cross_program_prefetches[0].index, ShapeIndex({}));

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  // Module dump is for debugging only; keep it out of the error log.
  VLOG(1) << "module: " << module->ToString();
  const HloValue& cross_program_prefetched_value =
      dataflow_analysis->GetValueDefinedAt(
          module->entry_computation()->parameter_instruction(0), {});
  // Expect that there are two prefetches that use this value, one is the
  // cross-program prefetch, the other is the end-of-program prefetch.
  auto is_cross_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           use.instruction->cross_program_prefetch_index().has_value();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_cross_program_prefetch),
            1);
  auto is_end_of_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           !use.instruction->cross_program_prefetch_index().has_value();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_end_of_program_prefetch),
            1);
  // Also verify that the copy-done for the end-of-program prefetch is the last
  // instruction in schedule.
  const HloInstruction* last_instruction =
      module->schedule()
          .sequence(module->entry_computation())
          .instructions()[module->entry_computation()->instruction_count() - 1];
  EXPECT_THAT(last_instruction, op::CopyDone());
  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
  // Cross program prefetch would use offset 0 because that's the first
  // assignment. Since we are freeing the cross-program prefetch buffer, we
  // would also expect to see some of the intermediate computations (one of the
  // negate ops) to also get 0 offset allocations.
  bool has_zero_offset_allocations = false;
  // Bind by const reference to avoid copying each (position, chunk) pair.
  for (const auto& pos_and_chunk : preset_assignments->chunks()) {
    if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
        pos_and_chunk.second.offset == 0) {
      has_zero_offset_allocations = true;
      break;  // One matching allocation is enough.
    }
  }
  EXPECT_TRUE(has_zero_offset_allocations);
}

TEST_F(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) {
// This test is for checking if the cross-program-prefetched buffer is freed
// after its last use and there is an end-of-program prefetch.
Expand Down
2 changes: 2 additions & 0 deletions xla/service/memory_space_assignment/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ struct Options {
// and gives MSA more flexibility in choosing the prefetch time and how much
// data to prefetch.
bool enable_window_prefetch = false;

MsaSortOrderOverrides msa_sort_order_overrides;
};
} // namespace memory_space_assignment
} // namespace xla
Expand Down
Loading
Loading