[XLA:GPU] Adjust GetNumWarps heuristic in Tiled Cost Model. #19540

Merged (1 commit) on Nov 20, 2024
[XLA:GPU] Adjust GetNumWarps heuristic in Tiled Cost Model.
We need to adjust the heuristic because our emitter previously had an issue that prevented Triton from performing proper layout optimizations. It was fixed in 7280b9a.

Before that fix, we needed a higher number of warps (up to 32) to compensate for the lack of layout optimization. Now that many warps can cause performance regressions, because Triton tends to insert shared memory (shmem) usage and barrier syncs.

PiperOrigin-RevId: 698416298
olegshyshkov authored and Google-ML-Automation committed Nov 20, 2024
commit 6fc61c5386ef1b8ddda616acb363a66a55f94da1
43 changes: 34 additions & 9 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
@@ -127,16 +127,41 @@ bool DoesTileFitsInRegisters(int64_t tile_size,
          device_info.registers_per_block_limit();
 }

-// Returns the number of warps to use based on the tile size. The numbers were
-// originally selected from Triton SoftMax reduction row length.
+// Returns the number of warps to use based on the largest tile size in the
+// computation.
+//
+// This is a simple heuristic and we try to make minimal assumptions about the
+// kernel that will be generated by a block-level emitter, but there are a few
+// things we take into consideration.
+//
+// For smaller tile sizes, we pick fewer warps to make sure there are enough
+// elements per thread to have vectorized loads and stores.
+//
+// For larger tiles, we don't know how many registers will be live at the same
+// time and how much shared memory will be used, but there is a good chance that
+// only one block will be able to reside on an SM at any given moment.
+//
+// Choosing 4 or fewer warps for a large tile has the following problems:
+//
+//   * Not all registers will be utilized. On H100, for example, there are 64K
+//     registers available per SM in total, but there is also a limit of 255
+//     registers per thread. To be able to use all available registers we
+//     need roughly 64K / 255 ~= 257 threads, i.e. about 8 warps.
+//   * Not enough parallelism to overlap compute and memory access.
+//
+// Choosing more than 8 warps can also cause performance regressions:
+//   * Layout optimizations in a block-level emitter may decide to use
+//     shared memory and insert barrier syncs to perform a reduction or to
+//     reduce the amount of HBM traffic.
+//
+// These values and thresholds were empirically determined in November 2024 and
+// may change in the future.
 // TODO(b/332714755): Make it smarter.
-int64_t GetNumWarps(int64_t tile_size) {
-  if (tile_size <= 256) return 1;
-  if (tile_size <= 512) return 2;
-  if (tile_size <= 1024) return 4;
-  if (tile_size <= 2048) return 8;
-  if (tile_size <= 4096) return 16;
-  return 32;
+int64_t GetNumWarps(int64_t largest_live_tile_size) {
+  if (largest_live_tile_size <= 256) return 1;
+  if (largest_live_tile_size <= 1024) return 2;
+  if (largest_live_tile_size <= 4096) return 4;
+  return 8;
 }

 }  // namespace
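To make the effect of the new thresholds easier to see, the two versions of the heuristic can be placed side by side. This is only an illustrative sketch: the names `GetNumWarpsOld` and `GetNumWarpsNew` are hypothetical (the real function is called `GetNumWarps` in both versions), and the functions are marked `constexpr` here purely so the expected values can be checked at compile time.

```cpp
#include <cassert>
#include <cstdint>

// Thresholds before this change, copied from the removed lines.
// The name GetNumWarpsOld is hypothetical, used only for comparison.
constexpr int64_t GetNumWarpsOld(int64_t tile_size) {
  if (tile_size <= 256) return 1;
  if (tile_size <= 512) return 2;
  if (tile_size <= 1024) return 4;
  if (tile_size <= 2048) return 8;
  if (tile_size <= 4096) return 16;
  return 32;
}

// Thresholds after this change, copied from the added lines.
// The name GetNumWarpsNew is hypothetical, used only for comparison.
constexpr int64_t GetNumWarpsNew(int64_t largest_live_tile_size) {
  if (largest_live_tile_size <= 256) return 1;
  if (largest_live_tile_size <= 1024) return 2;
  if (largest_live_tile_size <= 4096) return 4;
  return 8;
}

// A tile of 4096 elements drops from 16 warps to 4, which matches the
// updated test expectations in this change.
static_assert(GetNumWarpsOld(4096) == 16, "old heuristic");
static_assert(GetNumWarpsNew(4096) == 4, "new heuristic");

// The new heuristic never returns more than 8 warps.
static_assert(GetNumWarpsOld(1 << 20) == 32, "old upper bound");
static_assert(GetNumWarpsNew(1 << 20) == 8, "new upper bound");
```

The two versions agree only for tiles of at most 256 elements; above that, the new heuristic always picks fewer warps.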
10 changes: 5 additions & 5 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
@@ -330,7 +330,7 @@ ENTRY main {

   EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes,
               ElementsAre(4, 911));
-  EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 16);
+  EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4);

   EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead);
   EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes);
@@ -649,8 +649,8 @@ ENTRY main {

   // Tile size is 9 * 9 * 9 = 729, which corresponds to 2 warps. But we estimate
   // the number of warps for the padded tile, which has size 16 * 16 * 16 = 4096
-  // and corresponds to 16 warps.
-  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize());
+  // and corresponds to 4 warps.
+  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
 }

 TEST_F(GpuIndexingPerformanceModelTest,
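The padded size quoted in the test comment above (9 * 9 * 9 padded to 16 * 16 * 16 = 4096) is consistent with rounding each tile dimension up to the next power of two. The sketch below reproduces that arithmetic under that assumption. `NextPowerOfTwo` and `PaddedTileSize` are hypothetical helpers written for this illustration; they are not functions from the XLA code base.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Hypothetical helper: smallest power of two that is >= n, for n >= 1.
constexpr int64_t NextPowerOfTwo(int64_t n) {
  int64_t p = 1;
  while (p < n) p *= 2;
  return p;
}

// Hypothetical helper: number of elements in a tile after each dimension
// is padded to a power of two (an assumption inferred from the 9 -> 16
// example in the test comment).
constexpr int64_t PaddedTileSize(std::initializer_list<int64_t> dims) {
  int64_t size = 1;
  for (int64_t d : dims) size *= NextPowerOfTwo(d);
  return size;
}

// 9 * 9 * 9 = 729 elements, but the padded tile has 16 * 16 * 16 = 4096.
static_assert(PaddedTileSize({9, 9, 9}) == 4096, "padded 9x9x9 tile");

// Under the same assumption, the 4 x 911 output tile from the earlier
// test pads to 4 * 1024 = 4096 elements.
static_assert(PaddedTileSize({4, 911}) == 4096, "padded 4x911 tile");
```

Under this assumption, both test tiles land exactly on the 4096 boundary, which is why both expectations change from 16 warps to 4.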
@@ -696,8 +696,8 @@ ENTRY main {
   EXPECT_EQ(launch_dimensions.num_blocks(), 1);

   // The largest tile size is 1 * 4096, for which our implementation recommends
-  // using 16 warps.
-  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize());
+  // using 4 warps.
+  EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
 }

 class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
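The register argument in the new `GetNumWarps` comment (64K registers per SM, at most 255 registers per thread) can be checked with a few lines of arithmetic. The sketch below assumes a warp size of 32 threads and a single resident block per SM, as the comment suggests for large tiles. The constants and the `MaxUsableRegisters` helper are hypothetical names introduced for this illustration only.

```cpp
#include <cassert>
#include <cstdint>

// Figures quoted in the GetNumWarps comment for H100.
constexpr int64_t kRegistersPerSm = 64 * 1024;
constexpr int64_t kMaxRegistersPerThread = 255;
// Assumption: 32 threads per warp.
constexpr int64_t kWarpSize = 32;

// Hypothetical helper: upper bound on the registers that one block of
// `num_warps` warps can use, assuming it is the only block on the SM.
constexpr int64_t MaxUsableRegisters(int64_t num_warps) {
  const int64_t by_threads = num_warps * kWarpSize * kMaxRegistersPerThread;
  return by_threads < kRegistersPerSm ? by_threads : kRegistersPerSm;
}

// 4 warps = 128 threads can use at most 128 * 255 = 32640 registers,
// roughly half of the register file.
static_assert(MaxUsableRegisters(4) == 32640, "4 warps");

// 8 warps = 256 threads can use 256 * 255 = 65280 registers, which is
// nearly all of the 65536 available.
static_assert(MaxUsableRegisters(8) == 65280, "8 warps");
```

With these figures, going beyond 8 warps gains almost no additional register capacity, which fits the choice of 8 as the upper bound in the new heuristic.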