|
15 | 15 | #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
|
16 | 16 | #include <glog/logging.h>
|
17 | 17 |
|
| 18 | +#include "paddle/cinn/ir/intrinsic_ops.h" |
| 19 | +#include "paddle/cinn/ir/ir.h" |
18 | 20 | #include "paddle/cinn/ir/ir_mutator.h"
|
19 | 21 | #include "paddle/cinn/ir/ir_printer.h"
|
20 | 22 |
|
@@ -71,8 +73,71 @@ struct IrNodesCollector : public IRVisitorRequireReImpl<void> {
|
71 | 73 | } \
|
72 | 74 | }
|
73 | 75 |
|
74 |
| - NODETY_FORALL(__m) |
| 76 | + NODETY_FORALL_EXCEPT_INTRINSIC(__m) |
75 | 77 | #undef __m
|
| 78 | + |
| 79 | + void Visit(const ir::IntrinsicOp* op) { |
| 80 | + switch (op->getKind()) { |
| 81 | +#define __(x) \ |
| 82 | + case ir::IntrinsicKind::k##x: \ |
| 83 | + Visit(llvm::dyn_cast<ir::intrinsics::x>(op)); \ |
| 84 | + break; |
| 85 | + |
| 86 | + INTRINSIC_KIND_FOR_EACH(__) |
| 87 | +#undef __ |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + void Visit(const ir::intrinsics::GetAddr* x) { |
| 92 | + if (x->data.defined()) { |
| 93 | + Visit(&(x->data)); |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + void Visit(const ir::intrinsics::BufferGetDataHandle* x) { |
| 98 | + if (x->buffer.defined()) { |
| 99 | + Visit(&(x->buffer)); |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) { |
| 104 | + if (x->buffer.defined()) { |
| 105 | + Visit(&(x->buffer)); |
| 106 | + } |
| 107 | + } |
| 108 | + |
| 109 | + void Visit(const ir::intrinsics::PodValueToX* x) { |
| 110 | + if (x->pod_value_ptr.defined()) { |
| 111 | + Visit(&(x->pod_value_ptr)); |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + void Visit(const ir::intrinsics::BufferCreate* x) { |
| 116 | + if (x->buffer.defined()) { |
| 117 | + Visit(&(x->buffer)); |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + void Visit(const ir::intrinsics::ArgsConstruct* x) { |
| 122 | + if (x->var.defined()) { |
| 123 | + Expr convert = Expr(x->var); |
| 124 | + Visit(&convert); |
| 125 | + } |
| 126 | + for (int i = 0; i < x->args.size(); ++i) { |
| 127 | + if (x->args[i].defined()) { |
| 128 | + Visit(&(x->args[i])); |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + void Visit(const ir::intrinsics::BuiltinIntrin* x) { |
| 134 | + for (int i = 0; i < x->args.size(); ++i) { |
| 135 | + if (x->args[i].defined()) { |
| 136 | + Visit(&(x->args[i])); |
| 137 | + } |
| 138 | + } |
| 139 | + } |
| 140 | + |
76 | 141 | std::set<void*> visited_;
|
77 | 142 | };
|
78 | 143 |
|
|
0 commit comments