Skip to content

Commit ff5f61c

Browse files
beckerheGoogle-ML-Automation
authored andcommitted
Replace gpu_asm_extra_flags string option by individual flags
`gpu_asm_extra_flags` allowed the user to pass arbitrary command line flags to `ptxas`. This was working well enough when calling ptxas was our only PTX compilation option. With library compilation and support for parallel compilation a wrong flag can easily mess up the PTX compilation with weird error. It also doesn't support flags for `nvlink` which might be needed when using parallel compilation. So I'm replacing the list of opaque flags by individual boolean options for the two main use cases (debug compile and preserving line info). The new compilation providers will take those booleans options and generate the right flags or API calls. As a temporary shim the function `PtxOptsFromDebugOptions` will still generate the command line flags for users of the legacy PTX compilation functions. PiperOrigin-RevId: 701850432
1 parent 41e12cc commit ff5f61c

File tree

5 files changed

+87
-15
lines changed

5 files changed

+87
-15
lines changed

xla/debug_options_flags.cc

+12-6
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
5858
opts.set_xla_gpu_autotune_max_solutions(0);
5959
opts.set_xla_cpu_multi_thread_eigen(true);
6060
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
61-
opts.set_xla_gpu_asm_extra_flags("");
61+
opts.set_xla_gpu_generate_debug_info(false);
62+
opts.set_xla_gpu_generate_line_info(false);
6263

6364
// As of cudnn 8.9.0, runtime fusion creates convolutions that take about 7s
6465
// seconds to run the first time we call them, at least on Ampere. In
@@ -968,11 +969,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
968969
bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
969970
debug_options->xla_gpu_disable_gpuasm_optimizations(),
970971
"In XLA:GPU run ptxas in -O0 (default is -O3)."));
971-
flag_list->push_back(tsl::Flag(
972-
"xla_gpu_asm_extra_flags",
973-
string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "",
974-
"Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). "
975-
"If multiple parameters, separate them by comma."));
972+
flag_list->push_back(
973+
tsl::Flag("xla_gpu_generate_debug_info",
974+
bool_setter_for(&DebugOptions::set_xla_gpu_generate_debug_info),
975+
debug_options->xla_gpu_generate_debug_info(),
976+
"Generate debug info for codegened CUDA kernels."));
977+
flag_list->push_back(
978+
tsl::Flag("xla_gpu_generate_line_info",
979+
bool_setter_for(&DebugOptions::set_xla_gpu_generate_line_info),
980+
debug_options->xla_gpu_generate_line_info(),
981+
"Generate line info for codegened CUDA kernels."));
976982
flag_list->push_back(tsl::Flag(
977983
"xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
978984
"Sets compiler fuel, useful for bisecting bugs in passes. Format "

xla/service/gpu/BUILD

+12-2
Original file line numberDiff line numberDiff line change
@@ -2406,11 +2406,21 @@ cc_library(
24062406
srcs = ["gpu_asm_opts_util.cc"],
24072407
hdrs = ["gpu_asm_opts_util.h"],
24082408
compatible_with = get_compatible_with_portable(),
2409-
copts = tsl_copts(),
24102409
deps = [
24112410
"//xla:xla_proto_cc",
24122411
"//xla/stream_executor/gpu:gpu_asm_opts",
2413-
"@com_google_absl//absl/strings",
2412+
],
2413+
)
2414+
2415+
xla_cc_test(
2416+
name = "gpu_asm_opts_util_test",
2417+
srcs = ["gpu_asm_opts_util_test.cc"],
2418+
deps = [
2419+
":gpu_asm_opts_util",
2420+
"//xla:xla_proto_cc",
2421+
"//xla/tests:xla_internal_test_main",
2422+
"@com_google_googletest//:gtest_main",
2423+
"@tsl//tsl/platform:test",
24142424
],
24152425
)
24162426

xla/service/gpu/gpu_asm_opts_util.cc

+9-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ limitations under the License.
1616
#include "xla/service/gpu/gpu_asm_opts_util.h"
1717

1818
#include <string>
19+
#include <utility>
1920
#include <vector>
2021

21-
#include "absl/strings/str_split.h"
2222
#include "xla/stream_executor/gpu/gpu_asm_opts.h"
2323
#include "xla/xla.pb.h"
2424

@@ -27,11 +27,16 @@ namespace gpu {
2727

2828
stream_executor::GpuAsmOpts PtxOptsFromDebugOptions(
2929
const DebugOptions& debug_options) {
30-
std::vector<std::string> extra_flags = absl::StrSplit(
31-
debug_options.xla_gpu_asm_extra_flags(), ',', absl::SkipEmpty());
30+
std::vector<std::string> extra_flags;
31+
if (debug_options.xla_gpu_generate_line_info()) {
32+
extra_flags.emplace_back("--generate-line-info");
33+
}
34+
if (debug_options.xla_gpu_generate_debug_info()) {
35+
extra_flags.emplace_back("--device-debug");
36+
}
3237
return stream_executor::GpuAsmOpts(
3338
debug_options.xla_gpu_disable_gpuasm_optimizations(),
34-
debug_options.xla_gpu_cuda_data_dir(), extra_flags);
39+
debug_options.xla_gpu_cuda_data_dir(), std::move(extra_flags));
3540
}
3641

3742
} // namespace gpu
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/service/gpu/gpu_asm_opts_util.h"
17+
18+
#include <gmock/gmock.h>
19+
#include <gtest/gtest.h>
20+
#include "xla/xla.pb.h"
21+
#include "tsl/platform/test.h"
22+
23+
namespace xla::gpu {
24+
namespace {
25+
using ::testing::Contains;
26+
27+
TEST(PtxOptsFromDebugOptionsTest, GenerateLineInfo) {
28+
xla::DebugOptions debug_options;
29+
debug_options.set_xla_gpu_generate_line_info(true);
30+
31+
EXPECT_THAT(PtxOptsFromDebugOptions(debug_options).extra_flags,
32+
Contains("--generate-line-info"));
33+
}
34+
35+
TEST(PtxOptsFromDebugOptionsTest, GenerateDebugInfo) {
36+
xla::DebugOptions debug_options;
37+
debug_options.set_xla_gpu_generate_debug_info(true);
38+
39+
EXPECT_THAT(PtxOptsFromDebugOptions(debug_options).extra_flags,
40+
Contains("--device-debug"));
41+
}
42+
43+
} // namespace
44+
} // namespace xla::gpu

xla/xla.proto

+10-3
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,14 @@ message DebugOptions {
314314
// If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3).
315315
bool xla_gpu_disable_gpuasm_optimizations = 103;
316316

317+
// If true, we generate debug info when compiling PTX. This is useful for
318+
// profiling and debugging.
319+
bool xla_gpu_generate_debug_info = 348;
320+
321+
// If true, we generate line info when compiling PTX. This is useful for
322+
// profiling and debugging.
323+
bool xla_gpu_generate_line_info = 349;
324+
317325
enum ShapeChecks {
318326
// Do not insert any shape checks for dynamically shaped operations; output
319327
// buffers might contain garbage data if shapes don't match.
@@ -455,8 +463,7 @@ message DebugOptions {
455463
// memory, or have bugs.
456464
bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;
457465

458-
// Extra parameters to pass the GPU assembler.
459-
string xla_gpu_asm_extra_flags = 141;
466+
reserved 141; // was xla_gpu_asm_extra_flags
460467

461468
// Per-heap size constraint. New heaps will be created if per-heap max size is
462469
// reached.
@@ -1062,7 +1069,7 @@ message DebugOptions {
10621069
// be deterministic, although with additional overhead.
10631070
bool xla_gpu_enable_scatter_determinism_expander = 345;
10641071

1065-
// Next id: 348
1072+
// Next id: 350
10661073

10671074
// Extra options to pass to the compilation backend (e.g. LLVM); specific
10681075
// interpretation of these values is left to the backend.

0 commit comments

Comments
 (0)