Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cleanup] Use HloPredicateIs(Not)Op #19732

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[Cleanup] Use HloPredicateIs(Not)Op
PiperOrigin-RevId: 703499929
  • Loading branch information
frgossen authored and Google-ML-Automation committed Dec 6, 2024
commit 41bc5c538d47e7a4584bfd641b7961c99b5656ac
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/dot_normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License.
namespace xla::gpu {

bool DotNormalizer::InstructionMatchesPattern(HloInstruction* instruction) {
if (instruction->opcode() != HloOpcode::kDot) {
if (HloPredicateIsNotOp<HloOpcode::kDot>(instruction)) {
return false;
}
return instruction->dot_dimension_numbers()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ ENTRY main {

HloInstruction* while_instruction;
for (auto instr : module->entry_computation()->instructions()) {
if (instr->opcode() == HloOpcode::kWhile) {
if (HloPredicateIsOp<HloOpcode::kWhile>(instr)) {
while_instruction = instr;
}
}
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/transforms/gpusolver_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ absl::StatusOr<HloInstruction*> CreateCholesky(GpuSolverContext* context,
// Tries to rewrite a single convolution into a call to cudnn.
absl::StatusOr<bool> RunOnInstruction(GpuSolverContext* context,
HloInstruction* instruction) {
if (instruction->opcode() != HloOpcode::kCholesky) {
if (HloPredicateIsNotOp<HloOpcode::kCholesky>(instruction)) {
return false;
}

Expand All @@ -164,7 +164,7 @@ absl::StatusOr<bool> GpusolverRewriter::RunOnComputation(
HloComputation* computation) {
std::vector<HloInstruction*> cusolver_calls;
for (auto* hlo : computation->instructions()) {
if (hlo->opcode() == HloOpcode::kCholesky) {
if (HloPredicateIsOp<HloOpcode::kCholesky>(hlo)) {
cusolver_calls.push_back(hlo);
}
}
Expand Down
12 changes: 6 additions & 6 deletions xla/service/gpu/transforms/horizontal_loop_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr,
: &instr;

// Too large shapes are not easily profitable.
if (root->opcode() == HloOpcode::kTuple) {
if (HloPredicateIsOp<HloOpcode::kTuple>(root)) {
// Since all output shapes are the same, use the first shape as the
// representative.
root = root->operand(0);
Expand Down Expand Up @@ -264,8 +264,7 @@ bool AnyOperandIsSharedAmongFusions(
}

HloInstruction* LatestNonTrivialAncestor(HloInstruction* hlo) {
if (hlo->opcode() == HloOpcode::kGetTupleElement ||
hlo->opcode() == HloOpcode::kBitcast) {
if (HloPredicateIsOp<HloOpcode::kGetTupleElement, HloOpcode::kBitcast>(hlo)) {
return LatestNonTrivialAncestor(hlo->mutable_operand(0));
}
return hlo;
Expand Down Expand Up @@ -441,7 +440,7 @@ absl::StatusOr<bool> HorizontalLoopFusionImpl::FuseConsumerOperands(
std::vector<HloInstruction*> fusion_instrs;
for (HloInstruction* instr : fusibles) {
VLOG(2) << "next candidate: " << instr->ToString();
if (instr->opcode() == HloOpcode::kFusion) {
if (HloPredicateIsOp<HloOpcode::kFusion>(instr)) {
fusion_instrs.push_back(instr);
} else {
TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -502,8 +501,9 @@ absl::Status HorizontalLoopFusionImpl::CreateFusedComputation(
->fused_instructions_computation()
->MakeInstructionPostOrder();
for (HloInstruction* old_instr : def_to_use_order) {
if (old_instr->opcode() == HloOpcode::kParameter ||
(sliced_input_fusion && old_instr->opcode() == HloOpcode::kTuple &&
if (HloPredicateIsOp<HloOpcode::kParameter>(old_instr) ||
(sliced_input_fusion &&
HloPredicateIsOp<HloOpcode::kTuple>(old_instr) &&
old_instr == fused_fusion_instrs[i]->fused_expression_root())) {
// Parameters have been created, and we don't need tuples from
// multi-output fusions, as we will directly reference the tuple
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/transforms/sanitize_constant_names.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ absl::StatusOr<bool> SanitizeConstantNames::Run(
// Collect the names used for the non-constant HLO instructions.
for (HloComputation* computation : module->computations(execution_threads)) {
for (HloInstruction* instr : computation->instructions()) {
if (instr->opcode() == HloOpcode::kConstant) {
if (HloPredicateIsOp<HloOpcode::kConstant>(instr)) {
continue;
}

Expand All @@ -55,7 +55,7 @@ absl::StatusOr<bool> SanitizeConstantNames::Run(
// even though the non-constant HLO comes after in the HLO module.
for (HloComputation* computation : module->computations(execution_threads)) {
for (HloInstruction* instr : computation->instructions()) {
if (instr->opcode() != HloOpcode::kConstant) {
if (HloPredicateIsNotOp<HloOpcode::kConstant>(instr)) {
continue;
}
std::string sanitized_name = llvm_ir::SanitizeConstantName(*instr);
Expand Down
Loading