Skip to content

Commit fba3623

Browse files
Introduce shape splitting into MSA.
PiperOrigin-RevId: 688294712
1 parent 397c3d1 commit fba3623

12 files changed

+253
-17
lines changed

xla/service/memory_space_assignment/BUILD

+6
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ xla_cc_test(
9393
":utils",
9494
"//xla:comparison_util",
9595
"//xla:literal_util",
96+
"//xla:shape_tree",
9697
"//xla:shape_util",
9798
"//xla:util",
9899
"//xla:xla_data_proto_cc",
@@ -170,6 +171,7 @@ cc_library(
170171
":memory_space_assignment",
171172
":options",
172173
":prefetch_interval_picker",
174+
":utils",
173175
"//xla:shape_util",
174176
"//xla/hlo/analysis:hlo_alias_analysis",
175177
"//xla/hlo/ir:hlo",
@@ -317,6 +319,7 @@ cc_library(
317319
":prefetch_interval_picker",
318320
":repacking",
319321
":slice",
322+
"//xla:shape_tree",
320323
"//xla:shape_util",
321324
"//xla:util",
322325
"//xla/hlo/ir:hlo",
@@ -558,6 +561,7 @@ cc_library(
558561
":tuning_utils",
559562
":utils",
560563
"//xla:debug_options_flags",
564+
"//xla:shape_tree",
561565
"//xla:shape_util",
562566
"//xla:util",
563567
"//xla:xla_data_proto_cc",
@@ -621,6 +625,8 @@ cc_library(
621625
hdrs = ["allocation_value.h"],
622626
deps = [
623627
":allocation",
628+
"//xla:shape_tree",
629+
"//xla:shape_util",
624630
"//xla/hlo/ir:hlo",
625631
"//xla/service:hlo_value",
626632
"@com_google_absl//absl/container:flat_hash_set",

xla/service/memory_space_assignment/algorithm.cc

+73-5
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ limitations under the License.
4747
#include "absl/strings/string_view.h"
4848
#include "absl/types/span.h"
4949
#include "xla/debug_options_flags.h"
50-
#include "xla/hlo/analysis/hlo_alias_analysis.h"
5150
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
5251
#include "xla/hlo/ir/hlo_computation.h"
5352
#include "xla/hlo/ir/hlo_instruction.h"
5453
#include "xla/hlo/ir/hlo_opcode.h"
5554
#include "xla/hlo/ir/hlo_schedule.h"
5655
#include "xla/hlo/utils/hlo_live_range.h"
56+
#include "xla/layout.h"
5757
#include "xla/service/buffer_value.h"
5858
#include "xla/service/call_graph.h"
5959
#include "xla/service/computation_layout.h"
@@ -74,11 +74,10 @@ limitations under the License.
7474
#include "xla/service/memory_space_assignment/utils.h"
7575
#include "xla/service/time_utils.h"
7676
#include "xla/shape.h"
77+
#include "xla/shape_tree.h"
7778
#include "xla/shape_util.h"
7879
#include "xla/util.h"
7980
#include "xla/xla_data.pb.h"
80-
#include "tsl/platform/status.h"
81-
#include "tsl/platform/statusor.h"
8281

8382
namespace xla {
8483
namespace memory_space_assignment {
@@ -1786,7 +1785,7 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
17861785
}
17871786
}
17881787
VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
1789-
for (auto& interval : sorted_buffer_intervals) {
1788+
for (MsaBufferInterval& interval : sorted_buffer_intervals) {
17901789
if (finalized_values_.contains(interval.buffer)) {
17911790
VLOG(3) << "Skip entrance interval" << interval.buffer->ToShortString()
17921791
<< " because it is already processed.";
@@ -2306,6 +2305,69 @@ void MsaAlgorithm::CreateAllocationValuesFromColocatedIntervals(
23062305
absl::c_move(new_allocation_values, std::back_inserter(allocation_values));
23072306
}
23082307

2308+
void MsaAlgorithm::MaybeSplitAllocationValues(
2309+
absl::Span<AllocationValue> allocation_values) {
2310+
if (options_.determine_split_dimension_fn == nullptr ||
2311+
options_.shape_size_fn == nullptr ||
2312+
options_.init_split_tree_fn == nullptr) {
2313+
return;
2314+
}
2315+
2316+
std::vector<std::optional<SplitConfig>> results;
2317+
2318+
for (AllocationValue& allocation_value : allocation_values) {
2319+
std::optional<SplitConfig> result = options_.determine_split_dimension_fn(
2320+
*allocation_value.value(), &instruction_to_split_dims_);
2321+
results.push_back(std::move(result));
2322+
}
2323+
for (int i = 0; i < results.size(); ++i) {
2324+
if (results[i] != results[0]) {
2325+
VLOG(3) << "Skipping splitting joint allocation values with different "
2326+
"split choices: "
2327+
<< allocation_values[0].ToShortString() << " -> "
2328+
<< (results[0].has_value() ? results[0]->ToString() : "nullopt")
2329+
<< " vs " << allocation_values[i].ToShortString() << " -> "
2330+
<< (results[i].has_value() ? results[i]->ToString() : "nullopt");
2331+
return;
2332+
}
2333+
}
2334+
2335+
for (int i = 0; i < allocation_values.size(); ++i) {
2336+
auto& allocation_value = allocation_values[i];
2337+
HloInstruction* defining_instruction =
2338+
allocation_value.value()->defining_instruction();
2339+
auto& result = results[i];
2340+
if (!instruction_to_split_dims_.contains(defining_instruction)) {
2341+
instruction_to_split_dims_[allocation_value.value()
2342+
->defining_instruction()] =
2343+
options_.init_split_tree_fn(defining_instruction, nullptr);
2344+
}
2345+
int64_t* mutable_element =
2346+
instruction_to_split_dims_[defining_instruction].mutable_element(
2347+
allocation_value.value()->defining_position().index);
2348+
if (!result.has_value()) {
2349+
*mutable_element = options_.replicated_split_dimension;
2350+
VLOG(4) << "Splitting allocation value: "
2351+
<< allocation_value.ToShortString() << ": kReplicated.";
2352+
continue;
2353+
}
2354+
// TODO(b/382592216): Delay this assignment until after the AllocationValue
2355+
// actually gets an alternate memory allocation.
2356+
*mutable_element = result->dimension();
2357+
Shape new_shape = allocation_value.value()->shape();
2358+
if (new_shape.has_layout() &&
2359+
new_shape.layout().split_configs_size() == 0) {
2360+
new_shape.mutable_layout()->add_split_configs(result.value());
2361+
}
2362+
allocation_value.set_split_shape(new_shape);
2363+
int64_t shape_size = options_.shape_size_fn(new_shape);
2364+
2365+
VLOG(4) << "Splitting allocation value: "
2366+
<< allocation_value.ToShortString() << ": " << result->ToString();
2367+
allocation_value.set_size(shape_size);
2368+
}
2369+
}
2370+
23092371
bool MsaAlgorithm::RequiresNoCopyAlternateMemAllocation(
23102372
AllocationValue& allocation_value) const {
23112373
return allocation_value.value()->shape().has_layout() &&
@@ -2425,6 +2487,8 @@ absl::StatusOr<AllocationResult> MsaAlgorithm::AllocateAllocationValues(
24252487
VLOG(3) << "all_use_times[" << i << "] = " << all_use_times[i];
24262488
}
24272489

2490+
MaybeSplitAllocationValues(allocation_values);
2491+
24282492
// Data structure to contain the preferred offset for a given computation.
24292493
// We ensure that the same offset will be allocated outside the while loop
24302494
// as well as inside the while loop.
@@ -4240,6 +4304,10 @@ void MsaAlgorithm::FinalizeAllocations(
42404304
absl::flat_hash_map<const AliasedOffset*, size_t> offset_to_index;
42414305
for (AllocationValue& allocation_value : allocation_values) {
42424306
for (auto& allocation : *allocation_value.mutable_allocation_sequence()) {
4307+
if (allocation->memory_space() == MemorySpace::kAlternate &&
4308+
allocation_value.mutable_split_shape().has_value()) {
4309+
allocation->set_split_shape(allocation_value.mutable_split_shape());
4310+
}
42434311
if ((allocation->memory_space() == MemorySpace::kAlternate) &&
42444312
(!allocation->is_scoped_allocation())) {
42454313
for (const HloUse& use : allocation->uses()) {
@@ -4902,7 +4970,7 @@ AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy(
49024970

49034971
if (request.preferred_offset) {
49044972
// If there is a preferred offset provided in the request and if it doesn't
4905-
// match the previous allocation, this request cannot be satisified.
4973+
// match the previous allocation, this request cannot be satisfied.
49064974
if (preferred_offset && request.preferred_offset != preferred_offset) {
49074975
VLOG(3) << "Cannot perform no-copy allocation due to mismatch: "
49084976
"preferred_offset = "

xla/service/memory_space_assignment/algorithm.h

+9
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ limitations under the License.
5454
#include "xla/service/memory_space_assignment/options.h"
5555
#include "xla/service/memory_space_assignment/slice.h"
5656
#include "xla/shape.h"
57+
#include "xla/shape_tree.h"
5758
#include "xla/shape_util.h"
5859
#include "xla/util.h"
5960

@@ -999,6 +1000,11 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
9991000
ShapeIndex producer_shape_index,
10001001
absl::string_view consumer_name) const;
10011002

1003+
// Takes a group of allocation values and splits them if they can be split on
1004+
// the same dimension.
1005+
void MaybeSplitAllocationValues(
1006+
absl::Span<AllocationValue> allocation_values);
1007+
10021008
AllocationSequence* allocations_;
10031009
const Options& options_;
10041010
const HloAliasAnalysis& alias_analysis_;
@@ -1079,6 +1085,9 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap<HloValue> {
10791085
failed_async_conversions_;
10801086
absl::flat_hash_set<const HloInstruction*> successful_async_conversion_set_;
10811087
std::vector<const HloInstruction*> not_finalized_async_conversions_;
1088+
// Maps from an HloValue to the dimension it is split on.
1089+
absl::flat_hash_map<const HloInstruction*, ShapeTree<int64_t>>
1090+
instruction_to_split_dims_;
10821091
// Debug strings.
10831092
std::string buffer_info_str_;
10841093
std::string allocation_info_str_;

xla/service/memory_space_assignment/allocation.cc

+26-6
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ absl::Status Allocation::UpdateUses(HloComputation* computation,
176176
producing_instruction,
177177
use.instruction->mutable_operand(use.operand_number),
178178
use.operand_index));
179-
} else if (operand_shape != producing_instruction->shape()) {
179+
} else if (!Shape::Equal().IgnoreSplitConfigInLayout()(
180+
operand_shape, producing_instruction->shape())) {
180181
// When processing allocations, we treat bitcasts as trivial positions and
181182
// do not create allocations for them. We insert bitcasts after copies, to
182183
// account for the fact that we don't have an allocation for the bitcast.
@@ -217,7 +218,8 @@ Allocation::Allocation(HloPosition defining_position, MemorySpace memory_space,
217218
start_time_(start_time),
218219
end_time_(end_time),
219220
is_scoped_allocation_(is_scoped_allocation),
220-
cross_program_prefetch_index_(cross_program_prefetch_index) {
221+
cross_program_prefetch_index_(cross_program_prefetch_index),
222+
split_shape_(std::nullopt) {
221223
CHECK(!is_scoped_allocation ||
222224
original_defining_position_.index == ShapeIndex({}));
223225
}
@@ -278,6 +280,13 @@ absl::Status PinnedAllocation::Process() {
278280
}
279281
HloInstruction* producing_instruction = AddGetTupleElements();
280282
HloComputation* computation = producing_instruction->parent();
283+
284+
if (memory_space() == MemorySpace::kAlternate &&
285+
mutable_split_shape().has_value()) {
286+
CHECK(Shape::Equal().IgnoreSplitConfigInLayout()(
287+
producing_instruction->shape(), mutable_split_shape().value()));
288+
*producing_instruction->mutable_shape() = mutable_split_shape().value();
289+
}
281290
return UpdateUses(computation, producing_instruction);
282291
}
283292

@@ -331,6 +340,10 @@ int64_t CopyAllocation::earliest_available_time() const {
331340
absl::Status CopyAllocation::Process() {
332341
// Copy allocations need to insert asynchronous copy nodes.
333342
Shape shape = defining_position().shape();
343+
if (memory_space() == MemorySpace::kAlternate && sync_mem_op_ != nullptr &&
344+
mutable_split_shape().has_value()) {
345+
*sync_mem_op_->mutable_shape() = mutable_split_shape().value();
346+
}
334347
HloInstruction* producing_instruction = AddGetTupleElements();
335348
HloComputation* computation = producing_instruction->parent();
336349
if (sync_mem_op_ != nullptr && sync_mem_op_->opcode() != HloOpcode::kCopy) {
@@ -351,12 +364,19 @@ absl::Status CopyAllocation::Process() {
351364
TF_RETURN_IF_ERROR(
352365
copy_start_->ReplaceOperandWith(0, producing_instruction));
353366
} else {
367+
Shape dest_shape;
368+
if (memory_space() == MemorySpace::kAlternate &&
369+
mutable_split_shape().has_value()) {
370+
dest_shape = mutable_split_shape().value();
371+
} else {
372+
dest_shape = shape;
373+
}
354374
copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
355375
ShapeUtil::MakeTupleShape(
356-
{shape, shape, ShapeUtil::MakeShape(U32, {})}),
376+
{dest_shape, shape, ShapeUtil::MakeShape(U32, {})}),
357377
producing_instruction, cross_program_prefetch_index()));
358-
copy_done_ = computation->AddInstruction(
359-
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
378+
copy_done_ = computation->AddInstruction(HloInstruction::CreateUnary(
379+
dest_shape, HloOpcode::kCopyDone, copy_start_));
360380
}
361381
VLOG(4) << "Created " << copy_start_->name()
362382
<< " for copy allocation: " << ToString();
@@ -527,7 +547,7 @@ absl::Status SlicedCopyAllocation::Process() {
527547

528548
// If we bitcast to an array of bytes above, the result of the concatenated
529549
// slices will also be an array of bytes. Thus, we need to cast the
530-
// concatentation back to the original shape.
550+
// concatenation back to the original shape.
531551
if (IsUniformSliceSizingEnabled(sliced_prefetch_options_)) {
532552
concat_ = concat_->parent()->AddInstruction(
533553
HloInstruction::CreateBitcast(shape, concat_));

xla/service/memory_space_assignment/allocation.h

+7
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ class Allocation {
9090
// Returns the cross-program prefetch index for this allocation.
9191
std::optional<int64_t> cross_program_prefetch_index() const;
9292

93+
void set_split_shape(const std::optional<Shape>& split_shape) {
94+
split_shape_ = split_shape;
95+
}
96+
std::optional<Shape> mutable_split_shape() { return split_shape_; }
97+
9398
// Allocation timing methods
9499
// --------------------------------------------------------------------------
95100
// TODO(cl/604356742): update all timing methods to explicitly state that
@@ -191,6 +196,8 @@ class Allocation {
191196
const bool is_scoped_allocation_;
192197
std::vector<HloUse> uses_;
193198
std::optional<int64_t> cross_program_prefetch_index_;
199+
// If present, indicates the newly split shape.
200+
std::optional<Shape> split_shape_;
194201
};
195202

196203
using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;

xla/service/memory_space_assignment/allocation_value.h

+10-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "xla/hlo/ir/hlo_instruction.h"
2727
#include "xla/service/hlo_value.h"
2828
#include "xla/service/memory_space_assignment/allocation.h"
29+
#include "xla/shape.h"
2930

3031
namespace xla {
3132
namespace memory_space_assignment {
@@ -129,13 +130,15 @@ class AllocationValue {
129130
: value_(value),
130131
defining_position_(position),
131132
size_(size),
132-
requires_contiguous_allocation_(false) {}
133+
requires_contiguous_allocation_(false),
134+
split_shape_(std::nullopt) {}
133135

134136
const HloPosition& defining_position() const { return defining_position_; }
135137
const HloInstruction* defining_instruction() const {
136138
return defining_position().instruction;
137139
}
138140
int64_t size() const { return size_; }
141+
void set_size(int64_t size) { size_ = size; }
139142
const std::vector<Use>& uses() const { return uses_; }
140143
std::vector<Use>& uses() { return uses_; }
141144
const HloValue* value() const { return value_; }
@@ -162,6 +165,9 @@ class AllocationValue {
162165
uses_.push_back({use, use_time, {}});
163166
}
164167

168+
void set_split_shape(const Shape& split_shape) { split_shape_ = split_shape; }
169+
std::optional<Shape> mutable_split_shape() { return split_shape_; }
170+
165171
std::string ToString() const;
166172
std::string ToShortString() const;
167173

@@ -174,6 +180,9 @@ class AllocationValue {
174180
bool requires_contiguous_allocation_;
175181
std::vector<Use> uses_;
176182
AllocationSequence allocation_sequence_;
183+
184+
// If present, indicates the newly split shape.
185+
std::optional<Shape> split_shape_;
177186
};
178187

179188
// A data structure we use to associate Allocation objects that are aliased

xla/service/memory_space_assignment/memory_space_assignment.cc

+10-1
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,16 @@ absl::StatusOr<std::unique_ptr<PresetAssignments>>
380380
MemorySpaceAssignment::RunMemorySpaceAssignment(
381381
const HloLiveRange& hlo_live_range,
382382
const HloAliasAnalysis& alias_analysis) {
383+
bool splitting_enabled = options_.determine_split_dimension_fn != nullptr &&
384+
options_.init_split_tree_fn != nullptr &&
385+
options_.shape_size_fn != nullptr;
386+
if (splitting_enabled) {
387+
CHECK_EQ(options_.sliced_prefetch_options.max_slices(), 0)
388+
<< "TODO(b/167392593): Support sliced prefetches for split shapes.";
389+
CHECK(!options_.enable_window_prefetch)
390+
<< "TODO(b/167392593): Support split shapes for window "
391+
"prefetches.";
392+
}
383393
TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
384394

385395
std::optional<RuntimeSimulator> runtime_simulator = std::nullopt;
@@ -503,7 +513,6 @@ absl::Status MemorySpaceAssignment::Process(
503513
CHECK(
504514
!sliced_copy_allocation.cross_program_prefetch_index().has_value());
505515
}
506-
507516
alternate_memory_assignments_.emplace_back(
508517
allocation->defining_position(), allocation->chunk());
509518
alternate_memory_size_ =

0 commit comments

Comments
 (0)