PR #19026: [NVIDIA GPU] LHS enhancement for multiple collective resources

Imported from GitHub PR #19026

With #17749, we can let LHS schedule multiple collective resources. There are cases, however, where two collectives cannot be overlapped: when two collectives on different streams share at least two ranks, they can form a cyclic dependency, because the execution order of NCCL kernels can differ from rank to rank. For example, if collectives A and B both span ranks 0 and 1, rank 0 may launch A's kernel before B's while rank 1 launches B's first; each collective then waits on the other's completion and neither can make progress. This PR refactors LHS to expose the comparator to the backend and enforces the above constraint for the GPU backend.
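
A minimal sketch of the resulting overlap rule (MayOverlap and its argument
types are illustrative only, not code from this PR):

  #include <set>
  #include <vector>

  // Two collectives may run on concurrent streams only if no pair of their
  // replica subgroups shares more than one rank.
  bool MayOverlap(const std::vector<std::set<int>>& a,
                  const std::vector<std::set<int>>& b) {
    for (const auto& group_a : a) {
      for (const auto& group_b : b) {
        int shared = 0;
        for (int rank : group_a) shared += group_b.count(rank);
        // Two or more shared ranks allow mismatched launch orders across
        // ranks, i.e. a potential cyclic wait between the two collectives.
        if (shared > 1) return false;
      }
    }
    return true;
  }
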
Copybara import of the project:

--
14362ea by Terry Sun <[email protected]>:

LHS deadlock avoidance

--
e027794 by Terry Sun <[email protected]>:

minor fix

--
430db3f by Terry Sun <[email protected]>:

address nit

Merging this change closes #19026

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19026 from terryysun:terryysun/overlapping_collectives 430db3f
PiperOrigin-RevId: 696020313
terryysun authored and Google-ML-Automation committed Nov 18, 2024
1 parent 3844060 commit 1a2606f
Showing 8 changed files with 217 additions and 34 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -1147,6 +1147,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_hlo_schedule.cc
@@ -518,8 +518,8 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
return GetSizeOfShape(shape, pointer_size);
};
auto scheduler_core = std::make_unique<DefaultSchedulerCore>(
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(),
config);
shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config,
GpuScheduleCrossesOverlapLimit);
pipeline.AddPass<SchedulingInstructionAnnotator>();
pipeline.AddPass<LatencyHidingScheduler>(
std::move(latency_estimator), std::move(async_tracker),
85 changes: 85 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.cc
@@ -114,6 +114,27 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) {
return IsGpuAsyncStart(from) && IsGpuAsyncDone(target);
}

// Returns the maximum number of ranks shared between any replica subgroup of
// `group` and any subgroup of `other`.
size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group,
const std::vector<ReplicaGroup>& other) {
size_t overlapping_count = 0;
for (const auto& curr_replica_group : group) {
absl::flat_hash_set<int> curr_replica_ids;
for (const auto curr_replica_id : curr_replica_group.replica_ids()) {
curr_replica_ids.insert(curr_replica_id);
}

for (const auto& replica_group : other) {
size_t subgroup_count = 0;
for (const auto replica_id : replica_group.replica_ids()) {
if (curr_replica_ids.contains(replica_id)) ++subgroup_count;
}
overlapping_count = std::max(overlapping_count, subgroup_count);
}
}
return overlapping_count;
}
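
// Example: with group = {{0,1},{2,3}} and other = {{0,2},{1,3}}, every
// subgroup pair shares at most one rank, so CountOverlappingRanks returns 1
// and the collectives may overlap; with other = {{0,1}}, the first subgroup
// shares two ranks and it returns 2, so the GPU rule below rejects overlap.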

} // namespace

int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
@@ -141,6 +162,70 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) {
}
}

bool GpuScheduleCrossesOverlapLimit(
const DefaultSchedulerCore::SchedulingState& sched_state,
const HloGraphNode* node) {
for (const auto& [resource, limit] : sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.size() == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}

if (node->GetResources().size() == 0) {
return false;
}
auto resource_type = node->GetResources().at(0).first;
  // If the candidate collective shares more than one rank with any in-flight
  // collective, the two can form a cyclic dependency and must not be
  // overlapped.
if ((resource_type - AsyncTracker::GetFirstTargetDefinedResource()) ==
static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamCollectives) &&
sched_state.resource_occupiers_in_flight.contains(resource_type) &&
sched_state.resource_occupiers_in_flight.at(resource_type).size() > 0) {
const HloInstruction& curr_hlo_inst = node->GetInstr();
if (hlo_query::IsAsyncCollectiveDoneOp(&curr_hlo_inst, true)) {
CHECK(
hlo_query::IsAsyncCollectiveStartOp(curr_hlo_inst.operand(0), true));
const HloInstruction* curr_start_inst =
curr_hlo_inst.operand(0)->async_wrapped_instruction();

      // Whether the candidate can be overlapped with every in-flight
      // collective.
bool can_overlap = true;
for (const auto occupier :
sched_state.resource_occupiers_in_flight.at(resource_type)) {
if (hlo_query::IsAsyncCollectiveStartOp(occupier, true)) {
// Number of overlapping ranks between this occupier and candidate
size_t overlapping_count = CountOverlappingRanks(
curr_start_inst->replica_groups(), occupier->replica_groups());
if (overlapping_count > 1) {
can_overlap = false;
VLOG(3) << "Collectives have " << overlapping_count
<< "overlapping ranks and cannot be overlapped. Candidate "
"collective: "
<< curr_start_inst->ToString()
<< ", in flight collective: " << occupier->ToString();
break;
}
}
}
if (!can_overlap) return true;
}
}

return false;
}

//===--------------------------------------------------------------------===//
// GpuAsyncTrackerBase
//===--------------------------------------------------------------------===//
9 changes: 9 additions & 0 deletions xla/service/gpu/gpu_latency_hiding_scheduler.h
@@ -34,6 +34,15 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo);
// Returns size of the `shape` given the `pointer_size`.
int64_t GetSizeOfShape(const Shape& shape, int pointer_size);

// GPU overlap limit rule for a scheduling candidate.
// On top of the default rule, we do not allow collectives that share more
// than one rank to overlap, because the execution order of NCCL kernels is
// not deterministic and cannot currently be controlled by launch order. With
// at least two shared ranks, the collectives can form a cyclic dependency.
bool GpuScheduleCrossesOverlapLimit(
const DefaultSchedulerCore::SchedulingState& sched_state,
const HloGraphNode* node);
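
// As wired up in gpu_hlo_schedule.cc in this same change,
// GpuScheduleCrossesOverlapLimit is passed when constructing the scheduler
// core:
//   auto scheduler_core = std::make_unique<DefaultSchedulerCore>(
//       shape_size_in_bytes, async_tracker.get(), latency_estimator.get(),
//       config, GpuScheduleCrossesOverlapLimit);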

// GPU specific resources for latency hiding scheduler.
//
// We use two different set of resources to model the scheduling of asynchronous
56 changes: 55 additions & 1 deletion xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
@@ -371,7 +371,7 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest,
std::vector<HloInstruction*> instruction_sequence =
schedule.sequence(module->entry_computation()).instructions();
// Since we allow 2 collectives in-flight, we should expect this pattern:
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> ar(rs)-done
// ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> rs(ar)-done
EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_0") <
GetIndexByName(instruction_sequence, "rs_1") &&
GetIndexByName(instruction_sequence, "rs_0") <
@@ -386,5 +386,59 @@
GetIndexByName(instruction_sequence, "rs_1"));
}

TEST_F(GpuLatencyHidingSchedulerBaseTest,
OverlappingRanksPreventOverlappingCollectives) {
absl::string_view kFdoProfile = R"pb(
costs { name: "add_0" cost_us: 100000.0 }
costs { name: "ar_0" cost_us: 10.0 }
costs { name: "rs_0" cost_us: 10.0 }
)pb";
absl::string_view kHloModule = R"(
HloModule m
reduce {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT _ = f32[] add(x, y)
}
ENTRY main {
p0 = f32[] parameter(0)
p1 = f32[2] parameter(1)
p2 = f32[2] parameter(2)
ar_0 = f32[] all-reduce-start(p0), to_apply=reduce, replica_groups={{0,1}}
ar_1 = f32[] all-reduce-done(ar_0)
rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0}, replica_groups={{0, 1}}
rs_1 = f32[1] reduce-scatter-done(rs_0)
add_0 = f32[2] add(p1, p2)
ROOT _ = (f32[], f32[1], f32[2]) tuple(ar_1, rs_1, add_0)
}
)";

auto config = GetModuleConfig(kFdoProfile);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloModule, config));

TF_EXPECT_OK(ScheduleModule(module.get(), /*num_parallel_resources=*/2));
auto schedule = module->schedule();
std::vector<HloInstruction*> instruction_sequence =
schedule.sequence(module->entry_computation()).instructions();
  // AR and RS have two ranks in common, so they cannot be overlapped. Expect
  // this pattern:
// rs(ar)-start -> add -> rs(ar)-done -> ar(rs)-start -> ar(rs)-done
EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_1") <
GetIndexByName(instruction_sequence, "rs_0") ||
GetIndexByName(instruction_sequence, "rs_1") <
GetIndexByName(instruction_sequence, "ar_0"));
EXPECT_TRUE((GetIndexByName(instruction_sequence, "ar_0") <
GetIndexByName(instruction_sequence, "add_0") &&
GetIndexByName(instruction_sequence, "add_0") <
GetIndexByName(instruction_sequence, "ar_1")) ||
(GetIndexByName(instruction_sequence, "rs_0") <
GetIndexByName(instruction_sequence, "add_0") &&
GetIndexByName(instruction_sequence, "add_0") <
GetIndexByName(instruction_sequence, "rs_1")));
}

} // namespace
} // namespace xla::gpu
16 changes: 12 additions & 4 deletions xla/service/gpu/tests/BUILD
@@ -792,7 +792,6 @@ cc_library(
tags = tf_cuda_tests_tags(),
deps = [
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/strings",
],
@@ -810,7 +809,10 @@ xla_test(
data = ["test_autotune_cache.textproto"],
env = {"XLA_FLAGS": "--xla_gpu_load_autotune_results_from=" +
"$(execpath test_autotune_cache.textproto)"},
deps = [":simple_optimization_test"],
deps = [
":simple_optimization_test",
"//xla/tests:xla_internal_test_main",
],
)

# This shows that tests can load an autotune cache using the TEST_WORKSPACE prefix.
@@ -829,7 +831,10 @@
env = {"XLA_FLAGS": "--xla_gpu_load_autotune_results_from=TEST_WORKSPACE/" +
package_name() +
"/test_autotune_cache.textproto"},
deps = [":simple_optimization_test"],
deps = [
":simple_optimization_test",
"//xla/tests:xla_internal_test_main",
],
)

# This shows that tests can dump an autotune cache into their output directory.
@@ -845,7 +850,10 @@
"TEST_UNDECLARED_OUTPUTS_DIR/autotune_cache.textproto"},
# Sharding must be disabled to correctly dump the autotune cache for all test.
shard_count = 1,
deps = [":simple_optimization_test"],
deps = [
":simple_optimization_test",
"//xla/tests:xla_internal_test_main",
],
)

xla_test(
68 changes: 43 additions & 25 deletions xla/service/latency_hiding_scheduler.cc
@@ -47,6 +47,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/map_util.h"
#include "xla/service/dump.h"
#include "xla/service/hlo_buffer.h"
@@ -1138,6 +1139,8 @@ class ReadySetLt {
const DefaultSchedulerCore::SchedulingState& sched_state_;
DefaultSchedulerCore::TargetSchedulingRule target_scheduling_rule_;
DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule_;
DefaultSchedulerCore::OverlapLimitRule
scheduling_instruction_crosses_overlap_limit_;

int ReadyIfScheduled(const HloGraphNode& gn) const {
int ready_nodes_if_scheduled = 0;
@@ -1271,9 +1274,9 @@
cand.node->GetResources());
int64_t num_conflicting_resources = 0;
for (int64_t resource : resources) {
if (!sched_state_.resources_in_flight.contains(resource)) continue;
if (!sched_state_.resource_occupiers_in_flight.count(resource)) continue;
num_conflicting_resources +=
sched_state_.resources_in_flight.at(resource);
sched_state_.resource_occupiers_in_flight.at(resource).size();
}
return num_conflicting_resources;
}
@@ -1312,26 +1315,29 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable(
}
absl::InlinedVector<std::pair<HloGraphNode*, SkipNodeReason>, 2>
skipped_nodes_and_reasons;
auto scheduling_instruction_crosses_overlap_limit =
[&sched_state](const HloInstruction& instr) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resources_in_flight.find(resource);
if (it == sched_state.resources_in_flight.end() || it->second == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was to
// be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(resource,
instr);
if (limit < num_resources_needed) {
return true;
if (!scheduling_instruction_crosses_overlap_limit_) {
scheduling_instruction_crosses_overlap_limit_ =
[](const SchedulingState& sched_state, const HloGraphNode* node) {
for (const auto& [resource, limit] :
sched_state.max_concurrent_resource) {
// No resources in flight of this kind. Continue.
auto it = sched_state.resource_occupiers_in_flight.find(resource);
if (it == sched_state.resource_occupiers_in_flight.end() ||
it->second.size() == 0) {
continue;
}
// Number of instances of 'resource' needed if this instruction was
// to be scheduled.
const int64_t num_resources_needed =
sched_state.async_tracker->GetNumResourcesPerInstruction(
resource, node->GetInstr());
if (limit < num_resources_needed) {
return true;
}
}
}
return false;
};
return false;
};
}
VLOG(2) << "Current time: " << sched_state.current_time;
ReadySetLt ready_lt{&sched_state, target_scheduling_rule_,
early_target_scheduling_rule_};
@@ -1363,8 +1369,8 @@
}
// If this node would cause the max_concurrent_resource count to go beyond
// the limit do not schedule it and pass to the next node.
if (scheduling_instruction_crosses_overlap_limit(
(*ready_node_it)->GetInstr())) {
if (scheduling_instruction_crosses_overlap_limit_(sched_state,
*ready_node_it)) {
if (ready_chosen.node == nullptr) {
skipped_nodes_and_reasons.push_back(
{*ready_node_it, SkipNodeReason::kExceedsOverlapLimit});
@@ -1902,9 +1908,21 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
++sched_state->scheduled_count;
for (auto& resource : n->GetResources()) {
if (resource.second == ResourceUsageType::kResourceRelease) {
--sched_state->resources_in_flight[resource.first];
sched_state->resource_occupiers_in_flight.at(resource.first)
.erase(&n->GetInstr());
} else if (resource.second == ResourceUsageType::kResourceOccupy) {
++sched_state->resources_in_flight[resource.first];
// For async collective done ops, save their corresponding start ops to
// the map
if (hlo_query::IsAsyncCollectiveDoneOp(&n->GetInstr(),
/*include_send_recv=*/true)) {
CHECK(hlo_query::IsAsyncCollectiveStartOp(n->GetInstr().operand(0),
true));
sched_state->resource_occupiers_in_flight[resource.first].insert(
n->GetInstr().operand(0));
} else {
sched_state->resource_occupiers_in_flight[resource.first].insert(
&n->GetInstr());
}
}
}
VLOG(10) << "Memory pressure before schedule: "
