[XLA:MSA] Fixes a bug in GetInefficientAllocationSites(allocation_values). The function previously assumed that allocation_values can never be empty.

PiperOrigin-RevId: 693424356
mehrdadkhani authored and Google-ML-Automation committed Nov 19, 2024
1 parent c6f26d4 commit c751ea0
Showing 6 changed files with 176 additions and 9 deletions.
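As context for the fix in xla/service/memory_space_assignment/algorithm.cc below: GetInefficientAllocationSites() read allocation_values.at(0) without first checking that the container is non-empty. The following is a minimal, self-contained sketch of that failure mode and of the guard the commit adds; the names and types here are illustrative, not the actual XLA code.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: reading element 0 of a possibly-empty container, guarded the
// same way as the fix in MsaAlgorithm::GetInefficientAllocationSites().
int64_t FirstSizeOrZero(const std::vector<int64_t>& allocation_value_sizes) {
  if (allocation_value_sizes.empty()) {
    return 0;  // Previously, .at(0) was called unconditionally and crashed here.
  }
  return allocation_value_sizes.at(0);
}

int main() {
  std::vector<int64_t> empty_sizes;  // e.g. a retry that produced no allocation values
  std::cout << FirstSizeOrZero(empty_sizes) << "\n";  // prints 0 instead of crashing
  return 0;
}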
4 changes: 2 additions & 2 deletions xla/pjrt/BUILD
@@ -697,15 +697,15 @@ cc_library(
],
)

# Transitional forwarding target. Use cpu:cpu_client instead.
# Transitional forwarding target. Use pjrt/plugin/xla_cpu:xla_cpu_pjrt_client instead.
cc_library(
name = "tfrt_cpu_pjrt_client",
hdrs = ["tfrt_cpu_pjrt_client.h"],
visibility = internal_visibility([
"//xla:friends",
]),
deps = [
"//xla/pjrt/cpu:cpu_client",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
],
)

2 changes: 1 addition & 1 deletion xla/pjrt/tfrt_cpu_pjrt_client.h
@@ -18,6 +18,6 @@ limitations under the License.

// Transitional forwarding header. Please include cpu/cpu_client.h directly.

#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"

#endif // XLA_PJRT_TFRT_CPU_PJRT_CLIENT_H_
9 changes: 4 additions & 5 deletions xla/service/memory_space_assignment/algorithm.cc
@@ -2130,7 +2130,10 @@ MsaAlgorithm::GetInefficientAllocationSites(
return {};
}

int64_t size = allocation_values.at(0).size();
int64_t size = 0;
if (!allocation_values.empty()) {
size = allocation_values.at(0).size();
}

if (VLOG_IS_ON(3)) {
for (const AllocationValue& allocation_value : allocation_values) {
@@ -4384,10 +4387,6 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
*use.instruction, use.operand_number, use.operand_index);
}

if (request.only_extend_existing_allocation &&
!allocation_sequence->empty()) {
allocation_sequence->back()->Extend(request.inclusive_start_time);
}
// There could be a requirement to pin this buffer to default memory either
// because it is a parameter or an output. If the buffer is a parameter, then
// we're allowed to prefetch. If the use expects the output to be in default
34 changes: 34 additions & 0 deletions xla/service/memory_space_assignment/memory_space_assignment.cc
@@ -347,6 +347,37 @@ MemorySpaceAssignment::Run(HloModule* module,
alias_analysis);
}

absl::Status MemorySpaceAssignment::VerifyAllocations() const {
BufferIntervalTree interval_tree;
// Checks that the chunks overlapping a given allocation in time do not
// overlap with the allocation's chunk in memory. If they do, we fail a CHECK;
// otherwise, we add the allocation's chunk to the interval tree and return an
// OK status.
auto add_allocation_and_verify =
[&](const Allocation* allocation) -> absl::Status {
for (const HeapSimulator::Chunk& overlapping_chunk :
interval_tree.ChunksOverlappingInTime(allocation->start_time(),
allocation->end_time() - 1)) {
CHECK(!allocation->chunk().OverlapsWith(overlapping_chunk))
<< "Chunks are overlapping at Allocation level (before fixing the "
"schedule): "
<< allocation->ToString()
<< " overlaps with allocated chunk: " << overlapping_chunk.ToString();
}
interval_tree.Add(allocation->start_time(), allocation->end_time() - 1,
allocation->chunk());
return absl::OkStatus();
};
// Verify that all alternate memory allocations are free of overlapping
// Allocations in time and space, and add them to interval_tree one by one.
for (const auto& allocation : allocations_) {
if (allocation->memory_space() == MemorySpace::kAlternate) {
TF_RETURN_IF_ERROR(add_allocation_and_verify(allocation.get()));
}
}
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::RunMemorySpaceAssignment(
const HloLiveRange& hlo_live_range,
@@ -365,6 +396,9 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
}

TF_RETURN_IF_ERROR(Process(hlo_live_range));
if (options_.verify) {
TF_RETURN_IF_ERROR(VerifyAllocations());
}
// DEBUG_LOG_ALLOCATIONS_AT
//
// Uncomment the following to log the alternate memory allocations that MSA
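The new VerifyAllocations() check is gated on options_.verify in RunMemorySpaceAssignment (see the hunk above). As a reading aid, here is a toy, self-contained sketch of the time-and-space overlap test it performs, using a flat vector in place of XLA's BufferIntervalTree; the names and types are illustrative, not the real API.

#include <cstdint>
#include <iostream>
#include <vector>

// Toy allocation: a live range in time plus a chunk (offset, size) in memory.
struct ToyAllocation {
  int64_t start_time, end_time;  // inclusive live range
  int64_t offset, size;          // chunk in alternate memory
};

bool OverlapsInTime(const ToyAllocation& a, const ToyAllocation& b) {
  return a.start_time <= b.end_time && b.start_time <= a.end_time;
}

bool OverlapsInMemory(const ToyAllocation& a, const ToyAllocation& b) {
  return a.offset < b.offset + b.size && b.offset < a.offset + a.size;
}

// Mirrors the CHECK in add_allocation_and_verify: no two committed allocations
// may overlap in both time and memory.
bool VerifyNoOverlaps(const std::vector<ToyAllocation>& allocations) {
  std::vector<ToyAllocation> committed;
  for (const ToyAllocation& allocation : allocations) {
    for (const ToyAllocation& other : committed) {
      if (OverlapsInTime(allocation, other) &&
          OverlapsInMemory(allocation, other)) {
        return false;
      }
    }
    committed.push_back(allocation);
  }
  return true;
}

int main() {
  // Reusing offset 0 after the first allocation dies is fine: prints 1.
  std::cout << VerifyNoOverlaps({{0, 5, 0, 64}, {6, 9, 0, 64}}) << "\n";
  // Overlapping in both time and memory is flagged: prints 0.
  std::cout << VerifyNoOverlaps({{0, 5, 0, 64}, {3, 9, 32, 64}}) << "\n";
  return 0;
}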
5 changes: 5 additions & 0 deletions xla/service/memory_space_assignment/memory_space_assignment.h
@@ -307,6 +307,11 @@ class MemorySpaceAssignment {
// Calculates asynchronous copy statistics.
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

// Verify that allocations_ are free of overlapping Allocations in time and
// space. This is a post-processing step called after all allocations have
// been finalized, before the async copies get scheduled.
absl::Status VerifyAllocations() const;

// Verify that the memory space assignment is free of overlapping buffers and
// export heap simulator trace to be used by buffer_assignment.
//
131 changes: 130 additions & 1 deletion xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -15,6 +15,8 @@ limitations under the License.

#include "xla/service/memory_space_assignment/memory_space_assignment.h"

#include <stdbool.h>

#include <algorithm>
#include <cstdint>
#include <functional>
@@ -162,7 +164,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase {
Options options;
options.max_size_in_bytes = 128;
options.alignment_in_bytes = 8;
options.verify = true;
options.verify = false;
options.alternate_memory_space = kAlternateMemorySpace;
options.max_outstanding_prefetches = -1;
options.max_outstanding_evictions = -1;
@@ -1309,6 +1311,133 @@ ENTRY entry {
EXPECT_LT(copy_done_index, negate4_index);
}

// Added for b/372277844#comment15, where a crash occurred when the allocation
// failed while trying to convert a sync slice to an async one, but not due to
// the conversion itself. In this case, the buffer associated with the slice
// (p0_copy) is too large to fit in alternate memory. Hence, allocation_values
// will be empty in retries, which previously caused a crash in
// MsaAlgorithm::GetInefficientAllocationSites().
TEST_F(MemorySpaceAssignmentTest, SyncReplacementLargeBuffers) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[10,2,3]{2,1,0} parameter(0)
p1 = f32[10,2,3]{2,1,0} parameter(1)
p0_copy = f32[10,2,3]{2,1,0} copy(p0)
negate0 = negate(p1)
negate1 = negate(negate0)
negate2 = negate(negate1)
negate3 = negate(negate2)
negate4 = negate(negate3)
negate5 = negate(negate4)
slice = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]}
ROOT concat = f32[11,2,3] concatenate(negate5, slice), dimensions={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
Options options = DefaultMemorySpaceOptions();
options.max_size_in_bytes = 64;
options.max_retries = 2;
options.enable_sync_copy_replacement = false;
options.enable_sync_slice_replacement = true;
options.is_async_slice_implemented_fn =
[](const HloInstruction* instruction) { return true; };
// Force the allocation of p0_copy to fail for the concat use with only
// AllocationResult::kFailRequiresUncommit. This means that even though the
// slice replacement was successful, the algorithm must retry once more without
// the sync slice conversion target, in case the other constraints of the
// allocation can then be satisfied.
options.allocation_result_modifier_testing_fn =
[](const AllocationRequest& request, AllocationResult& result) {
if (request.allocation_value->defining_instruction()->name() ==
"p0_copy" &&
request.use->hlo_use.instruction->name() == "concat") {
result = AllocationResult::kFailRequiresUncommit;
}
};
// options.inefficient_use_to_copy_ratio must be greater than 0 and the cost
// model must be set to trigger the inefficient allocation site logic.
options.inefficient_use_to_copy_ratio = 1.0;
AssignMemorySpaceUsingCostAnalysis(module.get(), options);

HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy");
ASSERT_NE(p0_copy, nullptr);
HloInstruction* concat = FindInstruction(module.get(), "concat");
ASSERT_NE(concat, nullptr);
EXPECT_THAT(concat->operand(1), op::Slice(p0_copy));
}
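The test above drives the retry through options.allocation_result_modifier_testing_fn, a hook that lets the test override an allocation result. Below is a toy, self-contained sketch of that injectable-hook pattern; the enum, struct, and function names are illustrative, not the real MSA types.

#include <functional>
#include <iostream>
#include <string>

// Toy stand-ins for the real request/result types.
enum class ToyResult { kSuccess, kFailRequiresUncommit };
struct ToyRequest {
  std::string use_name;
};
struct ToyOptions {
  // Testing hook: lets a test override the outcome of an allocation.
  std::function<void(const ToyRequest&, ToyResult&)> result_modifier;
};

ToyResult Allocate(const ToyRequest& request, const ToyOptions& options) {
  ToyResult result = ToyResult::kSuccess;  // pretend the allocation succeeded
  if (options.result_modifier) {
    options.result_modifier(request, result);  // test may force a failure path
  }
  return result;
}

int main() {
  ToyOptions options;
  options.result_modifier = [](const ToyRequest& request, ToyResult& result) {
    if (request.use_name == "concat") result = ToyResult::kFailRequiresUncommit;
  };
  ToyResult result = Allocate(ToyRequest{"concat"}, options);
  std::cout << (result == ToyResult::kFailRequiresUncommit ? "forced retry"
                                                           : "success")
            << "\n";
  return 0;
}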

// Added for b/376869021, which surfaced when we tried to convert a sync slice
// that had to extend the allocation of its operand in the alternate memory. In
// this test, we expect the slice0 operand (p0_copy) to maintain a valid
// allocation in the alternate memory until it gets transferred by the async
// replacement of slice0. We stress-test that validity by delaying the
// allocation of slice0 by 3 steps.
TEST_F(MemorySpaceAssignmentTest, SyncReplacementAllocationExtensionBug) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[2,2,3]{2,1,0} parameter(0)
p1 = f32[4,2,3]{2,1,0} parameter(1)
p0_copy = f32[2,2,3]{2,1,0} copy(p0)
negate0 = negate(p1)
negate1 = negate(negate0)
negate2 = negate(negate1)
p0_copy0_negate = negate(p0_copy)
copy_negate2 = copy(negate2)
slice0 = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]}
negate3 = negate(copy_negate2)
negate4 = negate(negate3)
negate5 = negate(negate4)
negate6 = negate(negate5)
negate7 = negate(negate6)
neg_slice0 = negate(slice0)
ROOT tuple = tuple(negate7, neg_slice0)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
Options options = DefaultMemorySpaceOptions();
options.enable_sync_copy_replacement = false;
options.enable_sync_slice_replacement = true;
options.verify = true;
options.is_async_slice_implemented_fn =
[](const HloInstruction* instruction) { return true; };
options.max_size_in_bytes = 96;
options.is_position_allowed_in_alternate_mem_fn =
[](const HloPosition& position) {
return position.instruction->name() != "p0_copy";
};
// Delay the allocation of slice0 by 3 steps to allow copy_negate2 to be
// allocated in alternate memory.
options.allocation_request_modifier_testing_fn =
[](AllocationRequest& request) {
if (request.only_extend_existing_allocation) {
request.inclusive_start_time += 3;
request.end_time += 3;
}
};
const std::string text_proto = R"pb(
overrides {
hlo_position_matcher { instruction_name_regex: "copy_negate2|p0_copy" }
override_options { assign_first: true }
})pb";
TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
ParseTextProto<MsaSortOrderOverrides>(text_proto));
auto preset_assignments = AssignMemorySpaceUsingCostAnalysis(
module.get(), options,
/*cost_analysis_options_override=*/std::nullopt,
/*hlo_cost_options_override=*/std::nullopt,
/*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);
HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy");
ASSERT_NE(p0_copy, nullptr);
HloInstruction* neg_slice0 = FindInstruction(module.get(), "neg_slice0");
ASSERT_NE(neg_slice0, nullptr);
EXPECT_THAT(neg_slice0->operand(0), op::AsyncDone(op::AsyncStart(p0_copy)));
}

TEST_F(MemorySpaceAssignmentTest, AlwaysSpillJitPrefetchTest) {
// The negate chain is long enough for asynchronous copy to be inserted
// between p1 and add.