forked from halide/Halide
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlesson_13_tuples.cpp
300 lines (252 loc) · 11.1 KB
/
lesson_13_tuples.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
// Halide tutorial lesson 13: Tuples
// This lesson describes how to write Funcs that evaluate to multiple
// values.
// On linux, you can compile and run it like so:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++17
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13
// On os x:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_13 -std=c++17
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13
// If you have the entire Halide source tree, you can also build it by
// running:
// make tutorial_lesson_13_tuples
// in a shell with the current directory at the top of the halide
// source tree.
#include "Halide.h"
#include <algorithm>
#include <stdio.h>
using namespace Halide;
// Tutorial driver for Halide Tuples. Walks through, in order:
//   1) the ways of representing a collection of values per pixel (an
//      extra dimension, an array of Funcs, or a Tuple-valued Func),
//   2) realizing and indexing Tuple-valued Funcs,
//   3) a Tuple reduction (argmax) checked against plain C++,
//   4) a user-defined Complex type built on Tuples, used to render a
//      Mandelbrot set as ascii art.
// argc and argv are unused. Prints "Success!" and returns 0 when every
// step completes; the checks themselves are assert()s.
int main(int argc, char **argv) {

    // So far Funcs (such as the one below) have evaluated to a single
    // scalar value for each point in their domain.
    Func single_valued;
    Var x, y;
    single_valued(x, y) = x + y;

    // One way to write a Func that returns a collection of values is
    // to add an additional dimension that indexes that
    // collection. This is how we typically deal with color. For
    // example, the Func below represents a collection of three values
    // for every x, y coordinate indexed by c.
    Func color_image;
    Var c;
    color_image(x, y, c) = select(c == 0, 245,  // Red value
                                  c == 1, 42,   // Green value
                                  132);         // Blue value

    // Since this pattern appears quite often, Halide provides a
    // syntactic sugar to write the code above as the following,
    // using the "mux" function.
    // color_image(x, y, c) = mux(c, {245, 42, 132});

    // This method is often convenient because it makes it easy to
    // operate on this Func in a way that treats each item in the
    // collection equally:
    Func brighter;
    brighter(x, y, c) = color_image(x, y, c) + 10;

    // However this method is also inconvenient for three reasons.
    //
    // 1) Funcs are defined over an infinite domain, so users of this
    // Func can for example access color_image(x, y, -17), which is
    // not a meaningful value and is probably indicative of a bug.
    //
    // 2) It requires a select, which can impact performance if not
    // bounded and unrolled:
    // brighter.bound(c, 0, 3).unroll(c);
    //
    // 3) With this method, all values in the collection must have the
    // same type. While the above two issues are merely inconvenient,
    // this one is a hard limitation that makes it impossible to
    // express certain things in this way.

    // It is also possible to represent a collection of values as a
    // collection of Funcs:
    Func func_array[3];
    func_array[0](x, y) = x + y;
    func_array[1](x, y) = sin(x);
    func_array[2](x, y) = cos(y);

    // This method avoids the three problems above, but introduces a
    // new annoyance. Because these are separate Funcs, it is
    // difficult to schedule them so that they are all computed
    // together inside a single loop over x, y.

    // A third alternative is to define a Func as evaluating to a
    // Tuple instead of an Expr. A Tuple is a fixed-size collection of
    // Exprs. Each Expr in a Tuple may have a different type. The
    // following function evaluates to an integer value (x+y), and a
    // floating point value (sin(x*y)).
    Func multi_valued;
    multi_valued(x, y) = Tuple(x + y, sin(x * y));

    // Realizing a tuple-valued Func returns a collection of
    // Buffers. We call this a Realization. It's equivalent to a
    // std::vector of Buffer objects:
    {
        Realization r = multi_valued.realize({80, 60});
        assert(r.size() == 2);
        Buffer<int> im0 = r[0];
        Buffer<float> im1 = r[1];
        assert(im0(30, 40) == 30 + 40);
        assert(im1(30, 40) == sinf(30 * 40));
    }

    // All Tuple elements are evaluated together over the same domain
    // in the same loop nest, but stored in distinct allocations. The
    // equivalent C++ code to the above is:
    {
        // realize({80, 60}) requests an extent of 80 in x and 60 in
        // y, so x is the inner loop and the row stride is the width.
        const int width = 80, height = 60;
        int multi_valued_0[width * height];
        float multi_valued_1[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                multi_valued_0[x + width * y] = x + y;
                multi_valued_1[x + width * y] = sinf(x * y);
            }
        }
    }

    // When compiling ahead-of-time, a Tuple-valued Func evaluates
    // into multiple distinct output halide_buffer_t structs. These appear in
    // order at the end of the function signature:
    // int multi_valued(...input buffers and params...,
    //                  halide_buffer_t *output_1, halide_buffer_t *output_2);

    // You can construct a Tuple by passing multiple Exprs to the
    // Tuple constructor as we did above. Perhaps more elegantly, you
    // can also take advantage of initializer lists and just
    // enclose your Exprs in braces:
    Func multi_valued_2;
    multi_valued_2(x, y) = {x + y, sin(x * y)};

    // Calls to a multi-valued Func cannot be treated as Exprs. The
    // following is a syntax error:
    // Func consumer;
    // consumer(x, y) = multi_valued_2(x, y) + 10;

    // Instead you must index a Tuple with square brackets to retrieve
    // the individual Exprs:
    Expr integer_part = multi_valued_2(x, y)[0];
    Expr floating_part = multi_valued_2(x, y)[1];
    Func consumer;
    consumer(x, y) = {integer_part + 10, floating_part + 10.0f};

    // Tuple reductions.
    {
        // Tuples are particularly useful in reductions, as they allow
        // the reduction to maintain complex state as it walks along
        // its domain. The simplest example is an argmax.

        // First we create a Buffer to take the argmax over.
        Func input_func;
        input_func(x) = sin(x);
        Buffer<float> input = input_func.realize({100});

        // Then we define a 2-valued Tuple which tracks the index of
        // the maximum value and the value itself.
        Func arg_max;

        // Pure definition.
        arg_max() = {0, input(0)};

        // Update definition. The reduction domain starts at 1 because
        // element 0 was already consumed by the pure definition.
        RDom r(1, 99);
        Expr old_index = arg_max()[0];
        Expr old_max = arg_max()[1];
        Expr new_index = select(old_max < input(r), r, old_index);
        Expr new_max = max(input(r), old_max);
        arg_max() = {new_index, new_max};

        // The equivalent C++ is:
        int arg_max_0 = 0;
        float arg_max_1 = input(0);
        for (int r = 1; r < 100; r++) {
            int old_index = arg_max_0;
            float old_max = arg_max_1;
            int new_index = old_max < input(r) ? r : old_index;
            float new_max = std::max(input(r), old_max);
            // In a tuple update definition, all loads and computation
            // are done before any stores, so that all Tuple elements
            // are updated atomically with respect to recursive calls
            // to the same Func.
            arg_max_0 = new_index;
            arg_max_1 = new_max;
        }

        // Let's verify that the Halide and C++ found the same maximum
        // value and index.
        {
            Realization r = arg_max.realize();
            Buffer<int> r0 = r[0];
            Buffer<float> r1 = r[1];
            assert(arg_max_0 == r0(0));
            assert(arg_max_1 == r1(0));
        }

        // Halide provides argmax and argmin as built-in reductions
        // similar to sum, product, maximum, and minimum. They return
        // a Tuple consisting of the point in the reduction domain
        // corresponding to that value, and the value itself. In the
        // case of ties they return the first value found. We'll use
        // one of these in the following section.
    }

    // Tuples for user-defined types.
    {
        // Tuples can also be a convenient way to represent compound
        // objects such as complex numbers. Defining an object that
        // can be converted to and from a Tuple is one way to extend
        // Halide's type system with user-defined types.
        //
        // The constructors and the Tuple conversion are deliberately
        // left implicit: the code below relies on them to move between
        // Complex, Tuple and Func calls without explicit casts.
        struct Complex {
            Expr real, imag;

            // Construct from a Tuple
            Complex(Tuple t)
                : real(t[0]), imag(t[1]) {
            }

            // Construct from a pair of Exprs
            Complex(Expr r, Expr i)
                : real(r), imag(i) {
            }

            // Construct from a call to a Func by treating it as a Tuple
            Complex(FuncRef t)
                : Complex(Tuple(t)) {
            }

            // Convert to a Tuple
            operator Tuple() const {
                return {real, imag};
            }

            // Complex addition
            Complex operator+(const Complex &other) const {
                return {real + other.real, imag + other.imag};
            }

            // Complex multiplication
            Complex operator*(const Complex &other) const {
                return {real * other.real - imag * other.imag,
                        real * other.imag + imag * other.real};
            }

            // Complex magnitude, squared for efficiency
            Expr magnitude_squared() const {
                return real * real + imag * imag;
            }

            // Other complex operators would go here. The above are
            // sufficient for this example.
        };

        // Let's use the Complex struct to compute a Mandelbrot set.
        Func mandelbrot;

        // The initial complex value corresponding to an x, y coordinate
        // in our Func.
        Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f);

        // Pure definition.
        Var t;
        mandelbrot(x, y, t) = Complex(0.0f, 0.0f);

        // We'll use an update definition to take 12 steps.
        RDom r(1, 12);
        Complex current = mandelbrot(x, y, r - 1);

        // The following line uses the complex multiplication and
        // addition we defined above.
        mandelbrot(x, y, r) = current * current + initial;

        // We'll use another tuple reduction to compute the iteration
        // number where the value first escapes a circle of radius 4.
        // This can be expressed as an argmin of a boolean - we want
        // the index of the first time the given boolean expression is
        // false (we consider false to be less than true). The argmax
        // would return the index of the first time the expression is
        // true.
        Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f;
        Tuple first_escape = argmin(escape_condition);

        // We only want the index, not the value, but argmin returns
        // both, so we'll index the argmin Tuple expression using
        // square brackets to get the Expr representing the index.
        Func escape;
        escape(x, y) = first_escape[0];

        // Realize the pipeline and print the result as ascii art.
        // Each escape index selects one character from the palette.
        Buffer<int> result = escape.realize({61, 25});
        const char *code = " .:-~*={}&%#@";
        for (int y = 0; y < result.height(); y++) {
            for (int x = 0; x < result.width(); x++) {
                printf("%c", code[result(x, y)]);
            }
            printf("\n");
        }
    }

    printf("Success!\n");

    return 0;
}