Skip to content

Commit

Permalink
[XLA:GPU] Adjust GetNumWarps heuristic in Tiled Cost Model.
Browse files Browse the repository at this point in the history
We need to adjust the heuristic because before our emitter had an issue that prevented Triton from doing proper layout optimizations. It was fixed in 7280b9a.

We needed to use higher number of warps (up to 32) before to cover the lack of layout optimization, but now it can cause performance regressions, because Triton likes to insert shmem usage and barrier syncs.

PiperOrigin-RevId: 698416298
  • Loading branch information
olegshyshkov authored and Google-ML-Automation committed Nov 20, 2024
1 parent 85e8f75 commit 6fc61c5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
43 changes: 34 additions & 9 deletions xla/service/gpu/model/gpu_indexing_performance_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,41 @@ bool DoesTileFitsInRegisters(int64_t tile_size,
device_info.registers_per_block_limit();
}

// Returns the number of warps to use based on the tile size. The numbers were
// originally selected from Triton SoftMax reduction row length.
// Returns the number of warps to use based on the largest tile size in the
// computation.
//
// This is a simple heuristic and we try to make minimal assumptions about the
// kernel that will be generated by a block-level emitter, but there are a few
// things we take into consideration.
//
// For smaller tile sizes, we pick less warps to make sure there is enough
// elements per thread to have vectorized loads and stores.
//
// For larger tiles, we don't know how many registers will be live at the same
// time and how much shared memory will be used, but there is a good chance that
// only one block will be able to reside on an SM at any given moment.
//
// Choosing 4 or less warps for a large tile will have the following problems:
//
// * Not all register will be utilized. On H100, for example, there are 64K
// registers available per SM in total, but there is also a limit of 255
// registers per thread. To be able to use all available registers we
// need at least 64K / 255 = 256 threads = 8 warps.
// * Not enough parallelism to overlap compute and memory access.
//
// Choosing more than 8 warps can also cause performance regressions:
// * If layout optimizations in a block-level emitter will decide to use
// shared memory and insert barrier syncs to perform reduction or reduce
// amount of HBM traffic.
//
// These values and thresholds were empirically determined in November 2024 and
// may change in the future.
// TODO(b/332714755): Make it smarter.
int64_t GetNumWarps(int64_t tile_size) {
if (tile_size <= 256) return 1;
if (tile_size <= 512) return 2;
if (tile_size <= 1024) return 4;
if (tile_size <= 2048) return 8;
if (tile_size <= 4096) return 16;
return 32;
int64_t GetNumWarps(int64_t largest_live_tile_size) {
if (largest_live_tile_size <= 256) return 1;
if (largest_live_tile_size <= 1024) return 2;
if (largest_live_tile_size <= 4096) return 4;
return 8;
}

} // namespace
Expand Down
10 changes: 5 additions & 5 deletions xla/service/gpu/model/gpu_indexing_performance_model_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ ENTRY main {

EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes,
ElementsAre(4, 911));
EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 16);
EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4);

EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead);
EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes);
Expand Down Expand Up @@ -649,8 +649,8 @@ ENTRY main {

// Tile size is 9 * 9 * 9 = 729 that corresponds to 2 warps. But we estimate
// the number of warps for padded tile that has size of 16 * 16 * 16 = 4096
// and corresponds to 16 warps.
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize());
// and corresponds to 4 warps.
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
}

TEST_F(GpuIndexingPerformanceModelTest,
Expand Down Expand Up @@ -696,8 +696,8 @@ ENTRY main {
EXPECT_EQ(launch_dimensions.num_blocks(), 1);

// The largest tile size is 1 * 4096, for which our implementation recommends
// using 16 warps.
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize());
// using 4 warps.
EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize());
}

class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {
Expand Down

0 comments on commit 6fc61c5

Please sign in to comment.