
Commit 700d976

changhuilin authored and Google-ML-Automation committed

[MPMD-GPU] Add is_subslice_topology to the IFRT Topology.

This is to make IFRT Topology fields consistent with `PjRtTopologyDescriptionProto` fields.

PiperOrigin-RevId: 694647266
1 parent 262533b commit 700d976

File tree

9 files changed: +43 -169 lines


xla/debug_options_flags.cc (+10)

@@ -2050,6 +2050,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
       debug_options->xla_enable_fast_math(),
       "Enable optimizations that assume finite math, i.e., no NaN."));
+  flag_list->push_back(tsl::Flag(
+      "xla_experimental_exec_time_optimization_effort",
+      float_setter_for(
+          &DebugOptions::set_xla_experimental_exec_time_optimization_effort),
+      debug_options->xla_experimental_exec_time_optimization_effort(),
+      "The execution time optimization effort to expend during compilation. "
+      "Takes range [-1.0, 1.0] where values < 0.0 indicate skipping passes "
+      "which might optimize the final runtime (thus improving compile time), "
+      "and values > 0.0 indicate running additional passes which may improve "
+      "runtime at the cost of compilation time."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_experimental_parallel_collective_overlap_limit",
       int32_setter_for(
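
For orientation, this is how such an effort knob could be consumed by a pass pipeline. The sketch below is not part of this commit: only the `DebugOptions` getter is real, while `AddOptionalPasses` and both pass names are illustrative placeholders.

// Hypothetical consumer of the new knob -- not from this commit.
// Only xla_experimental_exec_time_optimization_effort() is real; the
// helper and pass types are made-up placeholders.
void AddOptionalPasses(xla::HloPassPipeline& pipeline,
                       const xla::DebugOptions& debug_options) {
  const float effort =
      debug_options.xla_experimental_exec_time_optimization_effort();
  if (effort < 0.0f) {
    return;  // Favor compile time: skip passes that only improve runtime.
  }
  pipeline.AddPass<BaselineOptimizationPass>();  // placeholder pass
  if (effort > 0.0f) {
    // Spend extra compile time on passes that may improve runtime.
    pipeline.AddPass<ExpensiveOptimizationPass>();  // placeholder pass
  }
}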

xla/python/ifrt/topology.h (+2)

@@ -44,6 +44,8 @@ class Topology : public llvm::RTTIExtends<Topology, llvm::RTTIRoot> {
 
   virtual PjRtPlatformId platform_id() const = 0;
 
+  virtual bool is_subslice_topology() const = 0;
+
   // Returns an unordered list of descriptions for all devices in this topology.
   // TODO(phawkins): consider introducing an IFRT-specific API here instead of
   // delegating to PJRT.
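
Because the accessor is added as a pure virtual, every concrete `ifrt::Topology` subclass now has to override it. A minimal sketch of such an override for a hypothetical subclass (LLVM RTTI boilerplate and the remaining pure virtuals are elided):

// Illustrative subclass -- not from this commit.
class ExampleTopology final
    : public llvm::RTTIExtends<ExampleTopology, xla::ifrt::Topology> {
 public:
  // This example never models a subslice of a larger slice.
  bool is_subslice_topology() const override { return false; }
  // ... platform_name(), platform_id(), RTTI ID, etc. elided ...
};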

xla/python/pjrt_ifrt/pjrt_topology.cc (+4)

@@ -51,6 +51,10 @@ PjRtPlatformId PjRtTopology::platform_id() const {
   return description_->platform_id();
 }
 
+bool PjRtTopology::is_subslice_topology() const {
+  return description_->is_subslice_topology();
+}
+
 std::vector<std::unique_ptr<const PjRtDeviceDescription>>
 PjRtTopology::DeviceDescriptions() const {
   return description_->DeviceDescriptions();
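
The PJRT-backed implementation simply forwards to the wrapped `PjRtTopologyDescription`, so IFRT callers can query the property without reaching into PJRT. A hypothetical call site (not from this commit):

// Illustrative caller of the new IFRT-level accessor.
void LogTopologyKind(const xla::ifrt::Topology& topology) {
  if (topology.is_subslice_topology()) {
    LOG(INFO) << topology.platform_name() << ": subslice topology";
  } else {
    LOG(INFO) << topology.platform_name() << ": full topology";
  }
}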

xla/python/pjrt_ifrt/pjrt_topology.h (+1)

@@ -46,6 +46,7 @@ class PjRtTopology final : public llvm::RTTIExtends<PjRtTopology, Topology> {
   absl::string_view platform_name() const override;
   absl::string_view platform_version() const override;
   PjRtPlatformId platform_id() const override;
+  bool is_subslice_topology() const override;
 
   std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
       const override;

xla/python/xla.cc (+7)

@@ -858,6 +858,13 @@ NB_MODULE(xla_extension, m_nb) {
       .def_prop_ro(
           "platform_version",
           [](ifrt::Topology& topology) { return topology.platform_version(); })
+      .def_prop_ro(
+          "platform_id",
+          [](ifrt::Topology& topology) { return topology.platform_id(); })
+      .def_prop_ro("is_subslice_topology",
+                   [](ifrt::Topology& topology) {
+                     return topology.is_subslice_topology();
+                   })
       .def("serialize",
            [](ifrt::Topology& topology) -> nb::bytes {
              std::string serialized = ValueOrThrow(topology.Serialize());
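
For context, `def_prop_ro` binds a read-only Python property to a C++ lambda. A self-contained sketch of the same nanobind pattern on a made-up type (`Widget` and the module name are assumptions, not part of this change):

#include <nanobind/nanobind.h>

namespace nb = nanobind;

struct Widget {
  bool flagged = false;
};

NB_MODULE(example_ext, m) {
  nb::class_<Widget>(m, "Widget")
      .def(nb::init<>())
      // Read-only property, analogous to is_subslice_topology above.
      .def_prop_ro("is_flagged",
                   [](const Widget& w) { return w.flagged; });
}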

xla/python/xla_extension/__init__.pyi (+2)

@@ -738,6 +738,8 @@ class Executable:
 class DeviceTopology:
   platform: str
   platform_version: str
+  platform_id: int
+  is_subslice_topology: bool
   def _make_compile_only_devices(self) -> List[Device]: ...
   def serialize(self) -> bytes: ...
   def __getattr__(self, name: str) -> Any: ...

xla/service/gpu/transforms/softmax_rewriter_triton.cc (+3, -106)

@@ -124,109 +124,6 @@ inline bool HasOneUse(const HloInstruction* instr) {
   return instr->user_count() == 1;
 }
 
-// Supports two types of broadcast of parameters. Either to one batch
-// dim, or one reduction dim. For example the following cases are supported:
-//
-// Case #1:
-// p = f32[a] parameter(0)
-// b = f32[a,x] broadcast(p), dimensions={0}
-//
-// Case #2:
-// p = f32[a] parameter(0)
-// b = f32[x,a] broadcast(p), dimensions={1}
-//
-// Case #3:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,b] broadcast(p), dimensions={1,2}
-//
-// Other broadcast tiling patterns are currently unsupported.
-// See b/328049138 for details.
-//
-// Unsupported case #1:
-// p = f32[a] parameter(0)
-// b = f32[x,a,y] broadcast(p), dimensions={1}
-//
-// Unsupported case #2:
-// p = f32[a,b] parameter(0)
-// b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
-//
-// Unsupported case #3:
-// p = f32[a] parameter(0)
-// b = f32[x,y,a] broadcast(p), dimensions={2}
-//
-// Unsupported case #4:
-// p = f32[a,b] parameter(0)
-// b = f32[a,x,b] broadcast(p), dimensions={0,2}
-//
-// Unsupported case #5:
-// p = f32[] parameter(0)
-// b = f32[x] broadcast(p), dimensions={}
-bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
-      << "Expected parameter " << hlo.operand(0)->ToShortString();
-
-  const HloBroadcastInstruction* broadcast =
-      Cast<HloBroadcastInstruction>(&hlo);
-
-  const HloParameterInstruction* parameter =
-      Cast<HloParameterInstruction>(hlo.operand(0));
-
-  // Support only one dim broadcast. Scalar parameters are handled elsewhere.
-  if (broadcast->dimensions().empty() ||
-      parameter->shape().dimensions_size() + 1 !=
-          broadcast->shape().dimensions_size()) {
-    return false;
-  }
-
-  // It is enough to ensure that the broadcast does not preserve both last, and
-  // first dimensions of the parameter at the same time. Otherwise the broadcast
-  // is the unsupported case #4.
-  //
-  // Preserve the first dim:
-  //   p = f32[a,b] parameter(0)
-  //   b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
-  bool preserve_first_dim = broadcast->dimensions().front() == 0;
-  // Preserve the last dim:
-  //   p = f32[a,b] parameter(0)
-  //   b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
-  bool preserve_last_dim = broadcast->dimensions().back() ==
-                           broadcast->shape().dimensions_size() - 1;
-  // We do not want to preserve both first and last dim, as it means the
-  // broadcast is not expanding on outermost dims.
-  return !(preserve_first_dim && preserve_last_dim);
-}
-
-bool IsBroadcastOfAScalar(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  return ShapeUtil::IsScalar(hlo.operand(0)->shape());
-}
-
-bool IsSingleRowParameterBroadcast(const HloInstruction& hlo) {
-  CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast)
-      << "Expected broadcast " << hlo.ToShortString();
-  CHECK_EQ(hlo.operand(0)->opcode(), HloOpcode::kParameter)
-      << "Expected parameter " << hlo.operand(0)->ToShortString();
-
-  const HloBroadcastInstruction* broadcast =
-      Cast<HloBroadcastInstruction>(&hlo);
-  const HloParameterInstruction* parameter =
-      Cast<HloParameterInstruction>(hlo.operand(0));
-
-  if (parameter->shape().dimensions_size() != 1) {
-    return false;
-  }
-  return broadcast->dimensions()[0] == broadcast->shape().dimensions_size() - 1;
-}
-
-bool IsSupportedBroadcastOfParameter(const HloInstruction& hlo) {
-  return IsBroadcastOfParameter(hlo) &&
-         (IsBatchOrReductionDimBroadcast(hlo) || IsBroadcastOfAScalar(hlo) ||
-          IsSingleRowParameterBroadcast(hlo));
-}
-
 // Chooses which operand to use for fusion processing. Taking in a unary or
 // binary instruction, returns the first non-splat operand. If none is
 // present, returns any operand.
@@ -238,7 +135,7 @@ HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
   // broadcast of any op.
   if (instr->operand_count() > 1 &&
       (IsBroadcastOfScalarConstant(*instr->operand(0)) ||
-       IsSupportedBroadcastOfParameter(*instr->operand(0)))) {
+       IsBroadcastOfParameter(*instr->operand(0)))) {
     return instr->mutable_operand(1);
   }
   return instr->mutable_operand(0);
@@ -284,9 +181,9 @@ bool IsTriviallyFusible(HloInstruction* instr,
   // TODO(b/326217416): Extend the broadcast of splat constants/parameters to
   // a broadcast of any op.
   if ((IsBroadcastOfScalarConstant(*operand_0) ||
-       IsSupportedBroadcastOfParameter(*operand_0)) ^
+       IsBroadcastOfParameter(*operand_0)) ^
      (IsBroadcastOfScalarConstant(*operand_1) ||
-       IsSupportedBroadcastOfParameter(*operand_1))) {
+       IsBroadcastOfParameter(*operand_1))) {
     return static_cast<bool>(
         IsTritonSupportedInstruction(*instr, gpu_version));
   }
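
With this change, both call sites accept any broadcast of a parameter instead of only the batch/reduction-dim patterns enumerated in the deleted comment. For orientation, the remaining check is roughly of this shape (a sketch only: the real `IsBroadcastOfParameter` is defined elsewhere in XLA, and this body is an assumption, not its actual implementation):

bool IsBroadcastOfParameterSketch(const HloInstruction& hlo) {
  // A broadcast whose operand is a parameter, with no constraint on which
  // dimensions the broadcast expands.
  return hlo.opcode() == HloOpcode::kBroadcast &&
         hlo.operand(0)->opcode() == HloOpcode::kParameter;
}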

xla/service/gpu/transforms/softmax_rewriter_triton_test.cc (+5, -62)

@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 #include "xla/service/gpu/transforms/softmax_rewriter_triton.h"
 
-#include <cstdint>
 #include <memory>
 #include <string>
 #include <variant>
@@ -1436,9 +1435,8 @@ ENTRY main {
   EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
 }
 
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOf1DParameterAlongBothBatchAndReductionDimensions) {  // NOLINT(whitespace/line_length)
+TEST_F(SoftmaxRewriterTritonTest,
+       FusesBinaryElementwiseIfIntermediateDiamondOpIsBroadcastOfParameter) {
   const std::string hlo_string = R"(
 HloModule h1
 
@@ -1460,67 +1458,12 @@ ENTRY main {
   ROOT add1 = f32[64,32,16]{2,1,0} add(add_0, broadcast_0)
 })";
   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithBroadcastAlongBatchAndReductionDimAsParameter) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[8]{0} parameter(0)
-  p1 = f32[32,8,16]{2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[32,8]{1,0} reduce(p1, c), dimensions={2}, to_apply=add_computation
-  b0 = f32[32,8,16]{2,1,0} broadcast(r0), dimensions={0,1}
-  b1 = f32[32,8,16]{2,1,0} broadcast(p0), dimensions={1}
-  add0 = f32[32,8,16]{2,1,0} add(b1, p1)
-  ROOT add1 = f32[32,8,16]{2,1,0} add(add0, b0)
-})";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value());
-}
-
-TEST_F(
-    SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithPartialBroadcastToBatchDim) {  // NOLINT(whitespace/line_length)
-  const std::string hlo_string = R"(
-HloModule h1
-
-add_computation {
-  y = f32[] parameter(1)
-  x = f32[] parameter(0)
-  ROOT add = f32[] add(x, y)
-}
-
-ENTRY main {
-  p0 = f32[16,64]{1,0} parameter(0)
-  p1 = f32[8,16,32,64]{3,2,1,0} parameter(1)
-  c = f32[] constant(0)
-
-  r0 = f32[8,16,32]{2,1,0} reduce(p1, c), dimensions={3}, to_apply=add_computation
-  b0 = f32[8,16,32,64]{3,2,1,0} broadcast(r0), dimensions={0,1,2}
-  b1 = f32[8,16,32,64]{3,2,1,0} broadcast(p0), dimensions={1,3}
-  add0 = f32[8,16,32,64]{3,2,1,0} add(b1, p1)
-  ROOT add1 = f32[8,16,32,64]{3,2,1,0} add(add0, b0)
-}
-)";
-  auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value());
+  EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
 }
 
 TEST_F(
     SoftmaxRewriterTritonTest,
-    DoesNotFuseBinaryElementwiseIfIntermediateDiamondOpWithMultiDimBroadcastAlongBatchDimAsParameter) {  // NOLINT(whitespace/line_length)
+    FusesBinaryElementwiseIfIntermediateDiamondOpWithMultipleDimensionsAsParameter) {  // NOLINT(whitespace/line_length)
   const std::string hlo_string = R"(
 HloModule h1
 
@@ -1542,7 +1485,7 @@ ENTRY main {
   ROOT add1 = f32[128,64,32,16]{3,2,1,0} add(add0, b0)
 })";
   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
-  EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value());
+  EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value());
 }
 
 // Triton has a requirement that any tile in the program should not have more

xla/xla.proto (+9, -1)

@@ -43,6 +43,14 @@ message DebugOptions {
   //--------------------------------------------------------------------------//
   // go/keep-sorted start
 
+  // The execution time optimization effort to expend during compilation.
+  // See `exec_time_optimization_effort` for accepted ranges. This flag will
+  // override any changes set in `ExecutionOptions`. Most likely this is just a
+  // temporary measure before https://github.com/jax-ml/jax/issues/24715 is in.
+  //
+  // TODO(b/377871215): Check whether we still need this.
+  float xla_experimental_exec_time_optimization_effort = 346;
+
   // go/keep-sorted end
 
   //--------------------------------------------------------------------------//
@@ -1051,7 +1059,7 @@ message DebugOptions {
   // be deterministic, although with additional overhead.
   bool xla_gpu_enable_scatter_determinism_expander = 345;
 
-  // Next id: 346
+  // Next id: 347
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
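
The new field is reachable through the standard generated protobuf accessors, or via XLA_FLAGS (e.g. XLA_FLAGS=--xla_experimental_exec_time_optimization_effort=0.5) once the flag above is registered. A minimal sketch of the programmatic path:

#include "xla/xla.pb.h"

// Sets the new field via the generated proto setter; per the flag text,
// meaningful values lie in [-1.0, 1.0].
xla::DebugOptions MakeDebugOptions() {
  xla::DebugOptions debug_options;
  debug_options.set_xla_experimental_exec_time_optimization_effort(0.5f);
  return debug_options;
}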
