From 6fc61c5386ef1b8ddda616acb363a66a55f94da1 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Wed, 20 Nov 2024 09:22:44 -0800 Subject: [PATCH] [XLA:GPU] Adjust GetNumWarps heuristic in Tiled Cost Model. We need to adjust the heuristic because before our emitter had an issue that prevented Triton from doing proper layout optimizations. It was fixed in https://github.com/openxla/xla/commit/7280b9ad5fe1433baadf34ce9b59ffbaf607603f. We needed to use higher number of warps (up to 32) before to cover the lack of layout optimization, but now it can cause performance regressions, because Triton likes to insert shmem usage and barrier syncs. PiperOrigin-RevId: 698416298 --- .../model/gpu_indexing_performance_model.cc | 43 +++++++++++++++---- .../gpu_indexing_performance_model_test.cc | 10 ++--- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/xla/service/gpu/model/gpu_indexing_performance_model.cc b/xla/service/gpu/model/gpu_indexing_performance_model.cc index 01f85c0e2baa4..2118acc757c14 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -127,16 +127,41 @@ bool DoesTileFitsInRegisters(int64_t tile_size, device_info.registers_per_block_limit(); } -// Returns the number of warps to use based on the tile size. The numbers were -// originally selected from Triton SoftMax reduction row length. +// Returns the number of warps to use based on the largest tile size in the +// computation. +// +// This is a simple heuristic and we try to make minimal assumptions about the +// kernel that will be generated by a block-level emitter, but there are a few +// things we take into consideration. +// +// For smaller tile sizes, we pick less warps to make sure there is enough +// elements per thread to have vectorized loads and stores. +// +// For larger tiles, we don't know how many registers will be live at the same +// time and how much shared memory will be used, but there is a good chance that +// only one block will be able to reside on an SM at any given moment. +// +// Choosing 4 or less warps for a large tile will have the following problems: +// +// * Not all register will be utilized. On H100, for example, there are 64K +// registers available per SM in total, but there is also a limit of 255 +// registers per thread. To be able to use all available registers we +// need at least 64K / 255 = 256 threads = 8 warps. +// * Not enough parallelism to overlap compute and memory access. +// +// Choosing more than 8 warps can also cause performance regressions: +// * If layout optimizations in a block-level emitter will decide to use +// shared memory and insert barrier syncs to perform reduction or reduce +// amount of HBM traffic. +// +// These values and thresholds were empirically determined in November 2024 and +// may change in the future. // TODO(b/332714755): Make it smarter. -int64_t GetNumWarps(int64_t tile_size) { - if (tile_size <= 256) return 1; - if (tile_size <= 512) return 2; - if (tile_size <= 1024) return 4; - if (tile_size <= 2048) return 8; - if (tile_size <= 4096) return 16; - return 32; +int64_t GetNumWarps(int64_t largest_live_tile_size) { + if (largest_live_tile_size <= 256) return 1; + if (largest_live_tile_size <= 1024) return 2; + if (largest_live_tile_size <= 4096) return 4; + return 8; } } // namespace diff --git a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index f9ed7a1b355dd..bb6653e187ed1 100644 --- a/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -330,7 +330,7 @@ ENTRY main { EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes, ElementsAre(4, 911)); - EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 16); + EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4); EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead); EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes); @@ -649,8 +649,8 @@ ENTRY main { // Tile size is 9 * 9 * 9 = 729 that corresponds to 2 warps. But we estimate // the number of warps for padded tile that has size of 16 * 16 * 16 = 4096 - // and corresponds to 16 warps. - EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize()); + // and corresponds to 4 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); } TEST_F(GpuIndexingPerformanceModelTest, @@ -696,8 +696,8 @@ ENTRY main { EXPECT_EQ(launch_dimensions.num_blocks(), 1); // The largest tile size is 1 * 4096, for which our implementation recommends - // using 16 warps. - EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize()); + // using 4 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); } class FlopsPerElementTest : public GpuIndexingPerformanceModelTest {