@@ -124,109 +124,6 @@ inline bool HasOneUse(const HloInstruction* instr) {
124
124
return instr->user_count () == 1 ;
125
125
}
126
126
127
- // Supports two types of broadcast of parameters. Either to one batch
128
- // dim, or one reduction dim. For example the following cases are supported:
129
- //
130
- // Case #1:
131
- // p = f32[a] parameter(0)
132
- // b = f32[a,x] broadcast(p), dimensions={0}
133
- //
134
- // Case #2:
135
- // p = f32[a] parameter(0)
136
- // b = f32[x,a] broadcast(p), dimensions={1}
137
- //
138
- // Case #3:
139
- // p = f32[a,b] parameter(0)
140
- // b = f32[x,a,b] broadcast(p), dimensions={1,2}
141
- //
142
- // Other broadcast tiling patterns are currently unsupported.
143
- // See b/328049138 for details.
144
- //
145
- // Unsupported case #1:
146
- // p = f32[a] parameter(0)
147
- // b = f32[x,a,y] broadcast(p), dimensions={1}
148
- //
149
- // Unsupported case #2:
150
- // p = f32[a,b] parameter(0)
151
- // b = f32[x,a,y,b] broadcast(p), dimensions={1,3}
152
- //
153
- // Unsupported case #3:
154
- // p = f32[a] parameter(0)
155
- // b = f32[x,y,a] broadcast(p), dimensions={2}
156
- //
157
- // Unsupported case #4:
158
- // p = f32[a,b] parameter(0)
159
- // b = f32[a,x,b] broadcast(p), dimensions={0,2}
160
- //
161
- // Unsupported case #5:
162
- // p = f32[] parameter(0)
163
- // b = f32[x] broadcast(p), dimensions={}
164
- bool IsBatchOrReductionDimBroadcast (const HloInstruction& hlo) {
165
- CHECK_EQ (hlo.opcode (), HloOpcode::kBroadcast )
166
- << " Expected broadcast " << hlo.ToShortString ();
167
- CHECK_EQ (hlo.operand (0 )->opcode (), HloOpcode::kParameter )
168
- << " Expected parameter " << hlo.operand (0 )->ToShortString ();
169
-
170
- const HloBroadcastInstruction* broadcast =
171
- Cast<HloBroadcastInstruction>(&hlo);
172
-
173
- const HloParameterInstruction* parameter =
174
- Cast<HloParameterInstruction>(hlo.operand (0 ));
175
-
176
- // Support only one dim broadcast. Scalar parameters are handled elsewhere.
177
- if (broadcast->dimensions ().empty () ||
178
- parameter->shape ().dimensions_size () + 1 !=
179
- broadcast->shape ().dimensions_size ()) {
180
- return false ;
181
- }
182
-
183
- // It is enough to ensure that the broadcast does not preserve both last, and
184
- // first dimensions of the parameter at the same time. Otherwise the broadcast
185
- // is the unsupported case #4.
186
- //
187
- // Preserve the first dim:
188
- // p = f32[a,b] parameter(0)
189
- // b1 = f32[a,b,c] broadcast(p), dimensions={0,1}
190
- bool preserve_first_dim = broadcast->dimensions ().front () == 0 ;
191
- // Preserve the last dim:
192
- // p = f32[a,b] parameter(0)
193
- // b1 = f32[c,a,b] broadcast(p), dimensions={1,2}
194
- bool preserve_last_dim = broadcast->dimensions ().back () ==
195
- broadcast->shape ().dimensions_size () - 1 ;
196
- // We do not want to preserve both first and last dim, as it means the
197
- // broadcast is not expanding on outermost dims.
198
- return !(preserve_first_dim && preserve_last_dim);
199
- }
200
-
201
- bool IsBroadcastOfAScalar (const HloInstruction& hlo) {
202
- CHECK_EQ (hlo.opcode (), HloOpcode::kBroadcast )
203
- << " Expected broadcast " << hlo.ToShortString ();
204
- return ShapeUtil::IsScalar (hlo.operand (0 )->shape ());
205
- }
206
-
207
- bool IsSingleRowParameterBroadcast (const HloInstruction& hlo) {
208
- CHECK_EQ (hlo.opcode (), HloOpcode::kBroadcast )
209
- << " Expected broadcast " << hlo.ToShortString ();
210
- CHECK_EQ (hlo.operand (0 )->opcode (), HloOpcode::kParameter )
211
- << " Expected parameter " << hlo.operand (0 )->ToShortString ();
212
-
213
- const HloBroadcastInstruction* broadcast =
214
- Cast<HloBroadcastInstruction>(&hlo);
215
- const HloParameterInstruction* parameter =
216
- Cast<HloParameterInstruction>(hlo.operand (0 ));
217
-
218
- if (parameter->shape ().dimensions_size () != 1 ) {
219
- return false ;
220
- }
221
- return broadcast->dimensions ()[0 ] == broadcast->shape ().dimensions_size () - 1 ;
222
- }
223
-
224
- bool IsSupportedBroadcastOfParameter (const HloInstruction& hlo) {
225
- return IsBroadcastOfParameter (hlo) &&
226
- (IsBatchOrReductionDimBroadcast (hlo) || IsBroadcastOfAScalar (hlo) ||
227
- IsSingleRowParameterBroadcast (hlo));
228
- }
229
-
230
127
// Chooses which operand to use for fusion processing. Taking in a unary or
231
128
// binary instruction, returns the first non-splat operand. If none is
232
129
// present, returns any operand.
@@ -238,7 +135,7 @@ HloInstruction* ChooseOperandForFusionProcessing(HloInstruction* instr) {
238
135
// broadcast of any op.
239
136
if (instr->operand_count () > 1 &&
240
137
(IsBroadcastOfScalarConstant (*instr->operand (0 )) ||
241
- IsSupportedBroadcastOfParameter (*instr->operand (0 )))) {
138
+ IsBroadcastOfParameter (*instr->operand (0 )))) {
242
139
return instr->mutable_operand (1 );
243
140
}
244
141
return instr->mutable_operand (0 );
@@ -284,9 +181,9 @@ bool IsTriviallyFusible(HloInstruction* instr,
284
181
// TODO(b/326217416): Extend the broadcast of splat constants/parameters to
285
182
// a broadcast of any op.
286
183
if ((IsBroadcastOfScalarConstant (*operand_0) ||
287
- IsSupportedBroadcastOfParameter (*operand_0)) ^
184
+ IsBroadcastOfParameter (*operand_0)) ^
288
185
(IsBroadcastOfScalarConstant (*operand_1) ||
289
- IsSupportedBroadcastOfParameter (*operand_1))) {
186
+ IsBroadcastOfParameter (*operand_1))) {
290
187
return static_cast <bool >(
291
188
IsTritonSupportedInstruction (*instr, gpu_version));
292
189
}
0 commit comments