@@ -47,13 +47,13 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
- #include "xla/hlo/analysis/hlo_alias_analysis.h"
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_live_range.h"
+ #include "xla/layout.h"
#include "xla/service/buffer_value.h"
#include "xla/service/call_graph.h"
#include "xla/service/computation_layout.h"
@@ -74,11 +74,10 @@ limitations under the License.
#include "xla/service/memory_space_assignment/utils.h"
#include "xla/service/time_utils.h"
#include "xla/shape.h"
+ #include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
- #include "tsl/platform/status.h"
- #include "tsl/platform/statusor.h"

namespace xla {
namespace memory_space_assignment {
@@ -1786,7 +1785,7 @@ absl::StatusOr<HeapSimulator::Result<HloValue>> MsaAlgorithm::Finish() {
    }
  }
  VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
-   for (auto& interval : sorted_buffer_intervals) {
+   for (MsaBufferInterval& interval : sorted_buffer_intervals) {
    if (finalized_values_.contains(interval.buffer)) {
      VLOG(3) << "Skip entrance interval" << interval.buffer->ToShortString()
              << " because it is already processed.";
@@ -2306,6 +2305,69 @@ void MsaAlgorithm::CreateAllocationValuesFromColocatedIntervals(
  absl::c_move(new_allocation_values, std::back_inserter(allocation_values));
}

+ void MsaAlgorithm::MaybeSplitAllocationValues(
+     absl::Span<AllocationValue> allocation_values) {
+   if (options_.determine_split_dimension_fn == nullptr ||
+       options_.shape_size_fn == nullptr ||
+       options_.init_split_tree_fn == nullptr) {
+     return;
+   }
+
+   std::vector<std::optional<SplitConfig>> results;
+
+   for (AllocationValue& allocation_value : allocation_values) {
+     std::optional<SplitConfig> result = options_.determine_split_dimension_fn(
+         *allocation_value.value(), &instruction_to_split_dims_);
+     results.push_back(std::move(result));
+   }
+   for (int i = 0; i < results.size(); ++i) {
+     if (results[i] != results[0]) {
+       VLOG(3) << "Skipping splitting joint allocation values with different "
+                  "split choices: "
+               << allocation_values[0].ToShortString() << " -> "
+               << (results[0].has_value() ? results[0]->ToString() : "nullopt")
+               << " vs " << allocation_values[i].ToShortString() << " -> "
+               << (results[i].has_value() ? results[i]->ToString() : "nullopt");
+       return;
+     }
+   }
+
+   for (int i = 0; i < allocation_values.size(); ++i) {
+     auto& allocation_value = allocation_values[i];
+     HloInstruction* defining_instruction =
+         allocation_value.value()->defining_instruction();
+     auto& result = results[i];
+     if (!instruction_to_split_dims_.contains(defining_instruction)) {
+       instruction_to_split_dims_[allocation_value.value()
+                                      ->defining_instruction()] =
+           options_.init_split_tree_fn(defining_instruction, nullptr);
+     }
+     int64_t* mutable_element =
+         instruction_to_split_dims_[defining_instruction].mutable_element(
+             allocation_value.value()->defining_position().index);
+     if (!result.has_value()) {
+       *mutable_element = options_.replicated_split_dimension;
+       VLOG(4) << "Splitting allocation value: "
+               << allocation_value.ToShortString() << ": kReplicated.";
+       continue;
+     }
+     // TODO(b/382592216): Delay this assignment until after the AllocationValue
+     // actually gets an alternate memory allocation.
+     *mutable_element = result->dimension();
+     Shape new_shape = allocation_value.value()->shape();
+     if (new_shape.has_layout() &&
+         new_shape.layout().split_configs_size() == 0) {
+       new_shape.mutable_layout()->add_split_configs(result.value());
+     }
+     allocation_value.set_split_shape(new_shape);
+     int64_t shape_size = options_.shape_size_fn(new_shape);
+
+     VLOG(4) << "Splitting allocation value: "
+             << allocation_value.ToShortString() << ": " << result->ToString();
+     allocation_value.set_size(shape_size);
+   }
+ }
+

bool MsaAlgorithm::RequiresNoCopyAlternateMemAllocation(
    AllocationValue& allocation_value) const {
  return allocation_value.value()->shape().has_layout() &&
@@ -2425,6 +2487,8 @@ absl::StatusOr<AllocationResult> MsaAlgorithm::AllocateAllocationValues(
    VLOG(3) << "all_use_times[" << i << "] = " << all_use_times[i];
  }

+   MaybeSplitAllocationValues(allocation_values);
+
  // Data structure to contain the preferred offset for a given computation.
  // We ensure that the same offset will be allocated outside the while loop
  // as well as inside the while loop.
@@ -4240,6 +4304,10 @@ void MsaAlgorithm::FinalizeAllocations(
  absl::flat_hash_map<const AliasedOffset*, size_t> offset_to_index;
  for (AllocationValue& allocation_value : allocation_values) {
    for (auto& allocation : *allocation_value.mutable_allocation_sequence()) {
+       if (allocation->memory_space() == MemorySpace::kAlternate &&
+           allocation_value.mutable_split_shape().has_value()) {
+         allocation->set_split_shape(allocation_value.mutable_split_shape());
+       }
      if ((allocation->memory_space() == MemorySpace::kAlternate) &&
          (!allocation->is_scoped_allocation())) {
        for (const HloUse& use : allocation->uses()) {
@@ -4902,7 +4970,7 @@ AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy(

  if (request.preferred_offset) {
    // If there is a preferred offset provided in the request and if it doesn't
-     // match the previous allocation, this request cannot be satisified.
+     // match the previous allocation, this request cannot be satisfied.
    if (preferred_offset && request.preferred_offset != preferred_offset) {
      VLOG(3) << "Cannot perform no-copy allocation due to mismatch: "
                 "preferred_offset = "