forked from halide/Halide
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDebugToFile.cpp
More file actions
156 lines (129 loc) · 5.18 KB
/
DebugToFile.cpp
File metadata and controls
156 lines (129 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#include <map>
#include <sstream>
#include <vector>
#include "DebugToFile.h"
#include "IRMutator.h"
#include "IROperator.h"
namespace Halide {
namespace Internal {
using std::map;
using std::ostringstream;
using std::string;
using std::vector;
class DebugToFile : public IRMutator {
const map<string, Function> &env;
using IRMutator::visit;
Stmt visit(const Realize *op) override {
map<string, Function>::const_iterator iter = env.find(op->name);
if (iter != env.end() && !iter->second.debug_file().empty()) {
Function f = iter->second;
vector<Expr> args;
user_assert(op->types.size() == 1)
<< "debug_to_file doesn't handle functions with multiple values yet\n";
// The name of the file
args.push_back(f.debug_file());
// Inject loads to the corners of the function so that any
// passes doing further analysis of buffer use understand
// what we're doing (e.g. so we trigger a copy-back from a
// device pointer).
Expr num_elements = 1;
for (size_t i = 0; i < op->bounds.size(); i++) {
num_elements *= op->bounds[i].extent;
}
int type_code = 0;
Type t = op->types[0];
if (t == Float(32)) {
type_code = 0;
} else if (t == Float(64)) {
type_code = 1;
} else if (t == UInt(8) || t == UInt(1)) {
type_code = 2;
} else if (t == Int(8)) {
type_code = 3;
} else if (t == UInt(16)) {
type_code = 4;
} else if (t == Int(16)) {
type_code = 5;
} else if (t == UInt(32)) {
type_code = 6;
} else if (t == Int(32)) {
type_code = 7;
} else if (t == UInt(64)) {
type_code = 8;
} else if (t == Int(64)) {
type_code = 9;
} else {
user_error << "Type " << t << " not supported for debug_to_file\n";
}
args.push_back(type_code);
Expr buf = Variable::make(Handle(), f.name() + ".buffer");
args.push_back(buf);
Expr call = Call::make(Int(32), Call::debug_to_file, args, Call::Intrinsic);
string call_result_name = unique_name("debug_to_file_result");
Expr call_result_var = Variable::make(Int(32), call_result_name);
Stmt body = AssertStmt::make(call_result_var == 0,
Call::make(Int(32), "halide_error_debug_to_file_failed",
{f.name(), f.debug_file(), call_result_var},
Call::Extern));
body = LetStmt::make(call_result_name, call, body);
body = Block::make(mutate(op->body), body);
return Realize::make(op->name, op->types, op->memory_type, op->bounds, op->condition, body);
} else {
return IRMutator::visit(op);
}
}
public:
DebugToFile(const map<string, Function> &e) : env(e) {}
};
// IR mutator that removes every Realize node named after one of `outputs`,
// replacing it with its (recursively mutated) body. In debug_to_file these
// are the temporary Realize nodes added by AddDummyRealizations.
class RemoveDummyRealizations : public IRMutator {
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const Realize *op) override {
        // Bind by const reference; the loop only reads each Function's name.
        for (const Function &f : outputs) {
            if (op->name == f.name()) {
                return mutate(op->body);
            }
        }
        return IRMutator::visit(op);
    }

public:
    explicit RemoveDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};
// IR mutator that wraps each ProducerConsumer node named after one of
// `outputs` in a Realize node. The Realize uses the output's types,
// MemoryType::Auto, a const_true() condition, and symbolic bounds built from
// the variables "<name>.min.<i>" and "<name>.extent.<i>" for each dimension i.
class AddDummyRealizations : public IRMutator {
    const vector<Function> &outputs;

    using IRMutator::visit;

    Stmt visit(const ProducerConsumer *op) override {
        Stmt s = IRMutator::visit(op);
        // Bind by const reference; only const accessors are used below.
        for (const Function &out : outputs) {
            if (op->name != out.name()) {
                continue;
            }
            vector<Range> output_bounds;
            for (int i = 0; i < out.dimensions(); i++) {
                string dim = std::to_string(i);
                Expr min = Variable::make(Int(32), out.name() + ".min." + dim);
                Expr extent = Variable::make(Int(32), out.name() + ".extent." + dim);
                output_bounds.emplace_back(min, extent);
            }
            // Only the first matching output wraps the node.
            return Realize::make(out.name(),
                                 out.output_types(),
                                 MemoryType::Auto,
                                 output_bounds,
                                 const_true(),
                                 s);
        }
        return s;
    }

public:
    explicit AddDummyRealizations(const vector<Function> &o)
        : outputs(o) {
    }
};
// Inserts debug_to_file writes into `s` for every Func in `env` that has a
// non-empty debug_file(). Each write is appended after the body of the Func's
// Realize node. The produce nodes of the output Funcs are temporarily wrapped
// in Realize nodes so that their writes are placed as well. Those wrappers are
// stripped again before returning.
Stmt debug_to_file(Stmt s, const vector<Function> &outputs, const map<string, Function> &env) {
    // Stage 1: give each output's produce node a temporary Realize wrapper.
    Stmt wrapped = AddDummyRealizations(outputs).mutate(s);

    // Stage 2: append the debug_to_file calls after matching Realize bodies.
    Stmt instrumented = DebugToFile(env).mutate(wrapped);

    // Stage 3: drop the temporary Realize wrappers around the outputs.
    return RemoveDummyRealizations(outputs).mutate(instrumented);
}
} // namespace Internal
} // namespace Halide