Skip to content

Commit

Permalink
Limit the number of additional ExecutionStreamIds to 4.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663524714
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Aug 16, 2024
1 parent c4657b6 commit 705f572
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
13 changes: 9 additions & 4 deletions xla/service/gpu/execution_stream_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ limitations under the License.

namespace xla::gpu {

ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) {
ExecutionStreamAssignment::ExecutionStreamAssignment(
const HloModule* module, ExecutionStreamAssignmentOptions options) {
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);

// We'll walk the `CallGraph` starting from the entrypoint. The instructions
Expand Down Expand Up @@ -88,14 +89,18 @@ ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) {
// Asynchronous calls will result in a new `ExecutionStreamId` being
// dispensed for the called computations.
CHECK_EQ(callsite.instruction()->opcode(), HloOpcode::kAsyncStart);
const ExecutionStreamId async_stream_id = next_stream_id++;
enqueue_called_computations(callsite, async_stream_id);
enqueue_called_computations(callsite, next_stream_id);

AsyncExecutionStreamIds streams;
streams.source_stream_id = pending.stream_id;
streams.destination_stream_id = async_stream_id;
streams.destination_stream_id = next_stream_id;
CHECK(async_instructions_.try_emplace(callsite.instruction(), streams)
.second);

next_stream_id++;
if (next_stream_id.value() > options.number_of_execution_streams) {
next_stream_id = ExecutionStreamId(1);
}
} else {
// Synchronous calls will result in the called computations being
// invoked using the same `ExecutionStreamId`.
Expand Down
9 changes: 8 additions & 1 deletion xla/service/gpu/execution_stream_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ limitations under the License.

namespace xla::gpu {

struct ExecutionStreamAssignmentOptions {
// The `ExecutionStreamAssignment` will round-robin across this many
// `ExecutionStreams`.
int number_of_execution_streams = 4;
};

// `ExecutionStreamAssignments` represent a mapping from `HloInstructions` to
// `ExecutionStreamIds`. Asynchronous calls (`async-start`, `async-update`, and
// `async-done`) result in the target computations being assigned new
Expand All @@ -37,7 +43,8 @@ class ExecutionStreamAssignment {
// pass the module through the `FlattenCallGraph` pass.
//
// The ExecutionStreamAssignment does not take ownership of the `HloModule`.
explicit ExecutionStreamAssignment(const HloModule* module);
explicit ExecutionStreamAssignment(
const HloModule* module, ExecutionStreamAssignmentOptions options = {});

// Returns the `ExecutionStreamId` for the given instruction, which *must* be
// synchronous. Returns an error if the instruction is either not reachable
Expand Down
26 changes: 23 additions & 3 deletions xla/service/gpu/execution_stream_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) {
p0 = f32[2,2] parameter(0)
ROOT add = f32[2,2] add(p0, p0)
}
leaf3 {
p0 = f32[2,2] parameter(0)
ROOT add = f32[2,2] add(p0, p0)
}
// Entry computation that calls each of the leaves asynchronously.
ENTRY entry {
Expand All @@ -77,21 +81,30 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) {
kind=kLoop, calls=leaf1
start2 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
kind=kLoop, calls=leaf2
start3 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0),
kind=kLoop, calls=leaf3
update1 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start1)
update2 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start2)
update3 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start3)
done1 = f32[2,2] fusion-done(update1)
done2 = f32[2,2] fusion-done(update2)
ROOT done = f32[2,2] add(done1, done2)
done3 = f32[2,2] fusion-done(update3)
ROOT done = f32[2,2] custom-call(done1, done2, done3),
custom_call_target="target"
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kModuleStr));

ExecutionStreamAssignment assignment(module.get());
ExecutionStreamAssignment assignment(
module.get(),
ExecutionStreamAssignmentOptions{/*number_of_execution_streams=*/2});

// The outermost computation should run on `ExecutionStreamId(0)`. The two
// asynchronous branches should be launched on `ExecutionStreamId(1)` and
// `ExecutionStreamId(2)`, respectively.
// `ExecutionStreamId(2)`, respectively. The third asynchronous branch should
// reuse `ExecutionStreamId(1)` because we set `number_of_execution_streams`
// to `2`.
ExpectExecutionStreamForSyncInstructions(
assignment, FindComputation(module.get(), "entry"), ExecutionStreamId(0));
for (std::string_view instruction : {"start1", "update1", "done1"}) {
Expand All @@ -108,6 +121,13 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) {
/*source_stream_id=*/ExecutionStreamId(0),
/*destination_stream_id=*/ExecutionStreamId(2)}));
}
for (std::string_view instruction : {"start3", "update3", "done3"}) {
EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast<HloAsyncInstruction>(
FindInstruction(module.get(), instruction))),
IsOkAndHolds(AsyncExecutionStreamIds{
/*source_stream_id=*/ExecutionStreamId(0),
/*destination_stream_id=*/ExecutionStreamId(1)}));
}

// Leaf computations should run on the respective asynchronous
// `ExecutionStreamIds`.
Expand Down

0 comments on commit 705f572

Please sign in to comment.