Skip to content

Commit 773ea41

Browse files
authoredMar 1, 2024
[CINN] Add IntrinsicOps into ir_codes_collector (#60556) (#62245)
This PR fixed a bug of running Resnet PaddleClas. The bug is due to vectorize introduce an intrinsic GetAddr and we didn't collect the tensor of GetAddr in ir_node_collector, this would caused tensor alias won't create in cuda code. TODO: we may modify IntrinsicOp in the near future
1 parent 521dc70 commit 773ea41

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed
 

‎paddle/cinn/ir/ir_base.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,23 @@ class Dim;
110110
macro__(Product) \
111111
macro__(Sum) \
112112
macro__(PrimitiveNode) \
113-
macro__(IntrinsicOp) \
114113
macro__(_BufferRange_) \
115114
macro__(ScheduleBlock) \
116115
macro__(ScheduleBlockRealize) \
117116
macro__(_Dim_) \
118117

118+
#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \
119+
macro__(IntrinsicOp) \
119120

120121
#define NODETY_FORALL(__m) \
121122
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
122123
NODETY_OP_FOR_EACH(__m) \
124+
NODETY_CONTROL_OP_FOR_INTRINSIC(__m) \
125+
NODETY_CONTROL_OP_FOR_EACH(__m)
126+
127+
#define NODETY_FORALL_EXCEPT_INTRINSIC(__m) \
128+
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
129+
NODETY_OP_FOR_EACH(__m) \
123130
NODETY_CONTROL_OP_FOR_EACH(__m)
124131
// clang-format on
125132

‎paddle/cinn/ir/utils/ir_nodes_collector.cc

+66-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
1616
#include <glog/logging.h>
1717

18+
#include "paddle/cinn/ir/intrinsic_ops.h"
19+
#include "paddle/cinn/ir/ir.h"
1820
#include "paddle/cinn/ir/ir_mutator.h"
1921
#include "paddle/cinn/ir/ir_printer.h"
2022

@@ -71,8 +73,71 @@ struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
7173
} \
7274
}
7375

74-
NODETY_FORALL(__m)
76+
NODETY_FORALL_EXCEPT_INTRINSIC(__m)
7577
#undef __m
78+
79+
void Visit(const ir::IntrinsicOp* op) {
80+
switch (op->getKind()) {
81+
#define __(x) \
82+
case ir::IntrinsicKind::k##x: \
83+
Visit(llvm::dyn_cast<ir::intrinsics::x>(op)); \
84+
break;
85+
86+
INTRINSIC_KIND_FOR_EACH(__)
87+
#undef __
88+
}
89+
}
90+
91+
void Visit(const ir::intrinsics::GetAddr* x) {
92+
if (x->data.defined()) {
93+
Visit(&(x->data));
94+
}
95+
}
96+
97+
void Visit(const ir::intrinsics::BufferGetDataHandle* x) {
98+
if (x->buffer.defined()) {
99+
Visit(&(x->buffer));
100+
}
101+
}
102+
103+
void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) {
104+
if (x->buffer.defined()) {
105+
Visit(&(x->buffer));
106+
}
107+
}
108+
109+
void Visit(const ir::intrinsics::PodValueToX* x) {
110+
if (x->pod_value_ptr.defined()) {
111+
Visit(&(x->pod_value_ptr));
112+
}
113+
}
114+
115+
void Visit(const ir::intrinsics::BufferCreate* x) {
116+
if (x->buffer.defined()) {
117+
Visit(&(x->buffer));
118+
}
119+
}
120+
121+
void Visit(const ir::intrinsics::ArgsConstruct* x) {
122+
if (x->var.defined()) {
123+
Expr convert = Expr(x->var);
124+
Visit(&convert);
125+
}
126+
for (int i = 0; i < x->args.size(); ++i) {
127+
if (x->args[i].defined()) {
128+
Visit(&(x->args[i]));
129+
}
130+
}
131+
}
132+
133+
void Visit(const ir::intrinsics::BuiltinIntrin* x) {
134+
for (int i = 0; i < x->args.size(); ++i) {
135+
if (x->args[i].defined()) {
136+
Visit(&(x->args[i]));
137+
}
138+
}
139+
}
140+
76141
std::set<void*> visited_;
77142
};
78143

0 commit comments

Comments
 (0)
Please sign in to comment.