forked from halide/Halide
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAutoSchedule.cpp
3530 lines (3075 loc) · 142 KB
/
AutoSchedule.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <algorithm>
#include <regex>
#include "AutoSchedule.h"
#include "AutoScheduleUtils.h"
#include "ExprUsesVar.h"
#include "FindCalls.h"
#include "Func.h"
#include "Inline.h"
#include "IREquality.h"
#include "ParallelRVar.h"
#include "RealizationOrder.h"
#include "RegionCosts.h"
#include "Scope.h"
#include "Simplify.h"
#include "Util.h"
namespace Halide {
namespace Internal {
using std::string;
using std::vector;
using std::map;
using std::set;
using std::deque;
using std::pair;
using std::make_pair;
namespace {
// Parse 's' as an int. A user error is raised unless the extraction succeeds
// and nothing is left in the stream afterwards.
int string_to_int(const std::string &s) {
    int parsed_value;
    std::istringstream iss(s);
    iss >> parsed_value;
    // get() is only reached when the extraction did not fail; it must then
    // report end-of-input, i.e. no trailing characters.
    user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << s;
    return parsed_value;
}
// Return true if at least one dimension of the box is not bounded.
bool is_box_unbounded(const Box &b) {
    const size_t num_dims = b.size();
    size_t dim = 0;
    // Advance past every bounded dimension; stopping early means an
    // unbounded one was found.
    while ((dim < num_dims) && b[dim].is_bounded()) {
        ++dim;
    }
    return dim != num_dims;
}
// Run simplify() on the min and max expressions of every dimension of 'b',
// updating the box in place.
void simplify_box(Box &b) {
    for (size_t dim = 0, num_dims = b.size(); dim < num_dims; ++dim) {
        Interval &range = b[dim];
        range.min = simplify(range.min);
        range.max = simplify(range.max);
    }
}
// Merge every region of 'partial' into 'result'. A function that is not yet
// present in 'result' is inserted unchanged; otherwise its box is combined
// with the existing one through merge_boxes().
void merge_regions(map<string, Box> &result, const map<string, Box> &partial) {
    for (const auto &reg : partial) {
        auto iter = result.find(reg.first);
        if (iter == result.end()) {
            result.emplace(reg.first, reg.second);
        } else {
            merge_boxes(iter->second, reg.second);
        }
    }
}
// Return a copy of 'name' in which every non-alphanumeric character has been
// replaced by '_'. If 'name' starts with a digit, a '_' is prepended first.
std::string get_sanitized_name(std::string name) {
    // The <cctype> classifiers have undefined behavior for negative values
    // other than EOF, so every char is converted through unsigned char before
    // being classified.
    if (!name.empty() && isdigit(static_cast<unsigned char>(name[0]))) {
        name = "_" + name;
    }
    for (char &c : name) {
        if (!isalnum(static_cast<unsigned char>(c))) {
            c = '_';
        }
    }
    return name;
}
// Representation of a function stage in the pipeline.
struct FStage {
Function func;
uint32_t stage_num;
FStage(Function func, uint32_t stage_num) : func(func), stage_num(stage_num) {}
bool operator==(const FStage &other_stage) const {
return (func.name() == other_stage.func.name()) &&
(stage_num == other_stage.stage_num);
}
bool operator<(const FStage &other_stage) const {
return func.name() < other_stage.func.name() ||
((func.name() == other_stage.func.name()) &&
(stage_num < other_stage.stage_num));
}
friend std::ostream& operator<<(std::ostream &stream, const FStage &s) {
if (s.stage_num == 0) {
stream << s.func.name();
} else {
stream << s.func.name() << ".update(" << (s.stage_num - 1) << ")";
}
return stream;
}
};
// Check if all the pipeline outputs have estimates specified
// on each of their dimensions; otherwise, throw an assertion.
void check_estimates_on_outputs(const vector<Function> &outputs) {
for (const auto &out : outputs) {
const vector<Bound> &estimates = out.schedule().estimates();
// Check if the estimate for each dimension of the output is available
// and is an integer. If there are duplicates for the estimate of a
// dimension, we only check the last defined estimate (which min and
// extent values are defined) since it is the one that would be
// eventually used.
Bound est;
for (const auto &arg : out.args()) {
bool found = false;
for (int i = (int)estimates.size() - 1; i >= 0; --i) {
if ((estimates[i].var == arg) && estimates[i].min.defined() &&
estimates[i].extent.defined()) {
found = true;
est = estimates[i];
break;
}
}
user_assert(found && est.min.type().is_int() && est.extent.type().is_int())
<< "Please provide a valid estimate for dimension "
<< est.var << " of output \"" << out.name() << "\"\n";
}
}
}
struct DependenceAnalysis {
// Map containing all the functions in the pipeline.
const map<string, Function> &env;
const vector<string> ℴ
const FuncValueBounds &func_val_bounds;
struct RegionsRequiredQuery {
string f;
int stage;
set<string> prods;
bool only_regions_computed;
RegionsRequiredQuery(const string &f, int stage, const set<string> &prods,
bool only_regions_computed)
: f(f), stage(stage), prods(prods),
only_regions_computed(only_regions_computed) {}
bool operator==(const RegionsRequiredQuery &other) const {
return (f == other.f) && (stage == other.stage) && (prods == other.prods) &&
(only_regions_computed == other.only_regions_computed);
}
bool operator<(const RegionsRequiredQuery &other) const {
if (f < other.f) {
return true;
} else if (f > other.f) {
return false;
}
if (stage < other.stage) {
return true;
} else if (stage > other.stage) {
return false;
}
if (only_regions_computed < other.only_regions_computed) {
return true;
} else if (only_regions_computed > other.only_regions_computed) {
return false;
}
return prods < other.prods;
}
};
struct RegionsRequired {
DimBounds bounds;
// Regions required to compute 'bounds' given a particular
// RegionsRequiredQuery.
map<string, Box> regions;
RegionsRequired(const DimBounds &b, const map<string, Box> &r)
: bounds(b), regions(r) {}
};
// Cache for bounds queries (bound queries with the same parameters are
// common during the grouping process).
map<RegionsRequiredQuery, vector<RegionsRequired>> regions_required_cache;
DependenceAnalysis(const map<string, Function> &env, const vector<string> &order,
const FuncValueBounds &func_val_bounds)
: env(env), order(order), func_val_bounds(func_val_bounds) {}
// Return the regions of the producers ('prods') required to compute the region
// of the function stage ('f', 'stage_num') specified by 'bounds'. When
// 'only_regions_computed' is set to true, this only returns the computed
// regions and not the total allocated regions.
map<string, Box> regions_required(Function f, int stage_num,
const DimBounds &bounds,
const set<string> &prods,
bool only_regions_computed,
const Scope<Interval> *input_estimates);
// Return the regions of the producers ('prods') required to compute the region
// of the function specified by 'pure_bounds'. When 'only_regions_computed'
// is set to true, this only returns the computed regions and not the total
// allocated regions.
map<string, Box> regions_required(Function f,
const DimBounds &pure_bounds,
const set<string> &prods,
bool only_regions_computed,
const Scope<Interval> *input_estimates);
// Return redundantly computed regions of producers ('prods') while computing
// a region of the function stage ('f', 'stage_num') specified by 'bounds'.
// 'var' is the dimension along which redundant computation is accounted for.
// When 'only_regions_computed' is set to true, this only returns the computed
// regions and not the total allocated regions. When 'only_regions_computed'
// is set to true, this only returns the computed regions and not the total
// allocated regions.
map<string, Box> redundant_regions(Function f, int stage_num, string var,
const DimBounds &bounds,
const set<string> &prods,
bool only_regions_computed,
const Scope<Interval> *input_estimates);
// Return overlapping regions of producers ('prods') while computing a function
// stage along each of the dimensions.
vector<map<string, Box>>
overlap_regions(Function f, int stage_num, const DimBounds &bounds,
const set<string> &prods, bool only_regions_computed,
const Scope<Interval> *input_estimates);
};
// Return the regions of the producers ('prods') required to compute the region
// of the function specified by 'pure_bounds'. The per-stage answers (pure
// definition plus every update) are merged into a single map.
map<string, Box>
DependenceAnalysis::regions_required(Function f, const DimBounds &pure_bounds,
                                     const set<string> &prods,
                                     bool only_regions_computed,
                                     const Scope<Interval> *input_estimates) {
    map<string, Box> merged;
    const int stage_count = static_cast<int>(f.updates().size() + 1);
    for (int stage = 0; stage != stage_count; ++stage) {
        // Derive this stage's bounds from the pure bounds, then ask the
        // per-stage overload for what it needs.
        const DimBounds stage_bounds = get_stage_bounds(f, stage, pure_bounds);
        merge_regions(merged,
                      regions_required(f, stage, stage_bounds, prods,
                                       only_regions_computed, input_estimates));
    }
    return merged;
}
// A function stage paired with the bounds of its dimensions.
//
// NOTE(review): operator< only compares the stage and the *number* of bounds
// entries, whereas operator== compares the bounds themselves. In an ordered
// container, two entries for the same stage with equally many but different
// bounds are therefore treated as equivalent.
struct StageBounds {
    FStage f_stage;
    DimBounds bounds;

    StageBounds(const FStage &fs, const DimBounds &b) : f_stage(fs), bounds(b) {}
    StageBounds(Function func, uint32_t stage_num, const DimBounds &b) :
        f_stage(func, stage_num), bounds(b) {}

    bool operator==(const StageBounds &other) const {
        if (!(f_stage == other.f_stage)) {
            return false;
        }
        return bounds == other.bounds;
    }

    bool operator<(const StageBounds &other) const {
        if (f_stage < other.f_stage) {
            return true;
        }
        if (!(f_stage == other.f_stage)) {
            return false;
        }
        return bounds.size() < other.bounds.size();
    }

    friend std::ostream& operator<<(std::ostream &stream, const StageBounds &s) {
        stream << "Stage: " << s.f_stage << "\n";
        stream << "Bounds:\n";
        for (const auto &dim : s.bounds) {
            const auto &range = dim.second;
            stream << "\t" << dim.first << " -> [" << range.min << ", " << range.max << "]\n";
        }
        stream << "\n";
        return stream;
    }
};
// Helper function to queue regions that need to be traversed. 'fs_bounds' is
// the queue into which the regions specified by 'prod_func' and 'region'
// will be added. Stages whose (stage, bounds) pair is already in 'visited'
// are skipped. If a stage is already queued, its bounds are widened to cover
// the new bounds as well.
void queue_func_regions(map<FStage, DimBounds> &fs_bounds,
                        const Function &prod_func, const Box &region,
                        const set<StageBounds> &visited) {
    DimBounds prod_pure_bounds;
    const vector<string> &args = prod_func.args();

    internal_assert(region.size() == args.size());

    // The region only specifies the extent of each dimension
    // by position. Populating a map which is keyed by name.
    for (size_t v = 0; v < args.size(); v++) {
        prod_pure_bounds[args[v]] = region[v];
    }

    // Get the bounds of all stages in a function from the
    // bounds on the pure dimensions.
    vector<DimBounds> prod_bounds = get_stage_bounds(prod_func, prod_pure_bounds);

    size_t num_stages = prod_func.updates().size() + 1;

    internal_assert(prod_bounds.size() == num_stages);

    // Add all stages of a function into the queue.
    for (size_t prod_s = 0; prod_s < num_stages; prod_s++) {
        StageBounds sb(prod_func, prod_s, prod_bounds[prod_s]);
        if (visited.find(sb) != visited.end()) {
            continue;
        }

        auto iter = fs_bounds.find(sb.f_stage);
        if (iter == fs_bounds.end()) {
            fs_bounds.emplace(sb.f_stage, sb.bounds);
            continue;
        }

        // The stage is already queued: take the union of the queued bounds
        // and the new ones, dimension by dimension.
        DimBounds &curr_bounds = iter->second;
        for (const auto &b : sb.bounds) {
            auto b_iter = curr_bounds.find(b.first);
            if (b_iter == curr_bounds.end()) {
                curr_bounds.emplace(b.first, b.second);
                continue;
            }

            Interval &queued = b_iter->second;
            // A side that is unbounded in either interval stays unbounded.
            if (queued.has_lower_bound() && b.second.has_lower_bound()) {
                queued.min = simplify(Interval::make_min(queued.min, b.second.min));
            } else {
                queued.min = Interval::neg_inf;
            }

            if (queued.has_upper_bound() && b.second.has_upper_bound()) {
                queued.max = simplify(Interval::make_max(queued.max, b.second.max));
            } else {
                queued.max = Interval::pos_inf;
            }
        }
    }
}
// Helper function for merging 'curr_regions' to the global map of regions
// and adding them to the queue of regions that need to be traversed.
// 'prods' is the set of producer functions that are under consideration.
void merge_and_queue_regions(map<FStage, DimBounds> &fs_bounds,
map<string, Box> ®ions,
map<string, Box> &curr_regions,
const set<string> &prods,
const map<string, Function> &env,
bool only_regions_computed,
string curr_func_name,
const set<StageBounds>& visited) {
for (const auto ® : curr_regions) {
// Merge region with an existing region of a function in the
// global map. Do not merge the parent function itself to the region
// when querying only for the values computed.
if (!only_regions_computed || (only_regions_computed && (reg.first != curr_func_name))) {
auto iter = regions.find(reg.first);
if (iter == regions.end()) {
regions.emplace(reg.first, reg.second);
} else {
merge_boxes(iter->second, reg.second);
}
}
// Skip adding the current region into to the queue if the function
// is not in 'prods'.
if (prods.find(reg.first) == prods.end()) {
continue;
}
const auto &it = env.find(reg.first);
if ((it != env.end()) && (reg.first != curr_func_name)) {
// Add all stages of the function representing the
// region into the queue.
queue_func_regions(fs_bounds, it->second, reg.second, visited);
}
}
}
// Return the regions of the producers ('prods') required to compute the region
// of the function stage ('f', 'stage_num') specified by 'bounds'. Answers are
// memoized in 'regions_required_cache'. Bounds that do not simplify to an
// integer constant are replaced by the function's constant integer estimates
// where such estimates exist.
map<string, Box>
DependenceAnalysis::regions_required(Function f, int stage_num,
                                     const DimBounds &bounds,
                                     const set<string> &prods,
                                     bool only_regions_computed,
                                     const Scope<Interval> *input_estimates) {
    // Check the cache if we've already computed this previously.
    RegionsRequiredQuery query(f.name(), stage_num, prods, only_regions_computed);
    const auto &iter = regions_required_cache.find(query);
    if (iter != regions_required_cache.end()) {
        const auto &it = std::find_if(iter->second.begin(), iter->second.end(),
            [&bounds](const RegionsRequired &r) { return (r.bounds == bounds); });
        if (it != iter->second.end()) {
            internal_assert((iter->first == query) && (it->bounds == bounds));
            return it->regions;
        }
    }

    // Iteratively compute the required regions by traversing the chain
    // of dependencies.

    // Map of all the required regions.
    map<string, Box> regions;
    // Queue of stages (with their bounds) that still have to be processed.
    map<FStage, DimBounds> fs_bounds;
    set<StageBounds> visited;

    // Add the query function and its region to the queue.
    fs_bounds.emplace(FStage(f, stage_num), bounds);

    while (!fs_bounds.empty()) {
        // Walk 'order' back-to-front and process every queued stage found.
        for (int i = order.size() - 1; i >= 0; --i) {
            const Function &curr_func = env.find(order[i])->second;
            const int num_stages = curr_func.updates().size() + 1;
            for (int curr_stage = 0; curr_stage < num_stages; ++curr_stage) {
                FStage s(curr_func, curr_stage);

                auto fs_iter = fs_bounds.find(s);
                if (fs_iter == fs_bounds.end()) {
                    continue;
                }

                DimBounds curr_bounds = fs_iter->second;
                visited.insert(StageBounds(s, curr_bounds));

                Definition def = get_stage_definition(s.func, s.stage_num);

                // Scope for containing all the estimates on parameters and intervals.
                Scope<Interval> curr_scope;
                curr_scope.set_containing_scope(input_estimates);

                const vector<Dim> &dims = def.schedule().dims();

                // Substitute parameter estimates into the bounds and add them to the
                // current scope. The last entry of 'dims' is skipped.
                for (int d = 0; d < (int)dims.size() - 1; d++) {
                    string var_name = dims[d].var;
                    internal_assert(curr_bounds.find(var_name) != curr_bounds.end());

                    Expr lower = SubstituteVarEstimates().mutate(get_element(curr_bounds, dims[d].var).min);
                    Expr upper = SubstituteVarEstimates().mutate(get_element(curr_bounds, dims[d].var).max);
                    Interval simple_bounds = Interval(simplify(lower), simplify(upper));
                    curr_scope.push(var_name, simple_bounds);
                }

                // If the function has an extern definition, there is no visibility into
                // the expression defining the function. So the regions required will be
                // the entire domain of the inputs to the extern func. Use the estimates
                // on the inputs to the extern function if available.
                //
                // TODO: Query the extern function for bounds of the functions which it
                // it depends on. This can be done by calling the extern func in the
                // bounds query mode.
                if (s.func.has_extern_definition()) {
                    for (const ExternFuncArgument &arg : s.func.extern_arguments()) {
                        if (arg.is_func()) {
                            // If the argument is an entire function, the bounds of the
                            // function required are unknown. Create an infinite region
                            // of the correct dimension, update the region map, and
                            // add it to the queue.
                            string prod_name = Function(arg.func).name();
                            const Function &prod_func = get_element(env, prod_name);
                            map<string, Box> prod_reg;
                            const vector<string> &args = prod_func.args();
                            for (size_t v = 0; v < args.size(); v++) {
                                prod_reg[prod_name].push_back(Interval());
                            }
                            merge_and_queue_regions(fs_bounds, regions, prod_reg, prods, env,
                                                    only_regions_computed, s.func.name(), visited);
                        } else if (arg.is_expr()) {
                            // Find the boxes required for the expression and add the regions
                            // to the queue.
                            Expr subs_arg = SubstituteVarEstimates().mutate(arg.expr);
                            map<string, Box> arg_regions = boxes_required(subs_arg, curr_scope, func_val_bounds);
                            merge_and_queue_regions(fs_bounds, regions, arg_regions, prods, env,
                                                    only_regions_computed, s.func.name(), visited);
                        } else if (arg.is_image_param() || arg.is_buffer()) {
                            // If the argument is an image or a buffer, the required
                            // bounds are unknown. Create an infinite region of the
                            // correct dimension and update the region map.
                            Buffer<> buf;
                            if (arg.is_image_param()) {
                                buf = arg.image_param.buffer();
                            } else {
                                buf = arg.buffer;
                            }
                            map<string, Box> buf_reg;
                            for (int v = 0; v < buf.dimensions(); v++) {
                                buf_reg[buf.name()].push_back(Interval());
                            }
                            merge_regions(regions, buf_reg);
                        }
                    }
                }

                // Find the regions required for each value of the current function stage,
                // update the region map, and add them to the queue.
                for (const auto &val : def.values()) {
                    // Substitute the parameter estimates into the expression and get
                    // the regions required for the expression.
                    Expr subs_val = SubstituteVarEstimates().mutate(val);
                    map<string, Box> curr_regions = boxes_required(subs_val, curr_scope, func_val_bounds);

                    // Arguments to the definition may require regions of functions.
                    // For example, update definitions in histograms where the bin is
                    // based on the value of a function.
                    Box left_reg;
                    for (const Expr &arg : def.args()) {
                        Expr subs_arg = SubstituteVarEstimates().mutate(arg);
                        map<string, Box> arg_regions = boxes_required(subs_arg, curr_scope, func_val_bounds);

                        // Merge the regions with the regions found while looking at
                        // the values.
                        merge_regions(curr_regions, arg_regions);

                        Interval arg_bounds = bounds_of_expr_in_scope(arg, curr_scope, func_val_bounds);
                        left_reg.push_back(arg_bounds);
                    }

                    // The region written by this stage counts as a region of
                    // the stage's own function.
                    auto self_iter = curr_regions.find(s.func.name());
                    if (self_iter == curr_regions.end()) {
                        curr_regions.emplace(s.func.name(), left_reg);
                    } else {
                        merge_boxes(self_iter->second, left_reg);
                    }

                    // Update the region map, and add 'curr_regions' to the queue.
                    merge_and_queue_regions(fs_bounds, regions, curr_regions, prods, env,
                                            only_regions_computed, s.func.name(), visited);
                }

                // Remove processed region from the queue.
                fs_bounds.erase(fs_iter);
            }
        }
    }

    // Simplify the bounds on each region and substitute global pipeline
    // bounds for function regions which lower and upper bounds could not be
    // determined.
    map<string, Box> concrete_regions;

    for (auto &f_reg : regions) {
        simplify_box(f_reg.second);

        const auto env_iter = env.find(f_reg.first);
        const bool in_env = (env_iter != env.end());

        Box concrete_box;
        for (size_t i = 0; i < f_reg.second.size(); i++) {
            Expr lower = f_reg.second[i].min;
            Expr upper = f_reg.second[i].max;

            // Decide up front which sides need replacing, so that assigning
            // one side below does not affect the decision for the other.
            const bool need_lower = !lower.as<IntImm>();
            const bool need_upper = !upper.as<IntImm>();

            if (in_env && (need_lower || need_upper)) {
                const Function &curr_f = env_iter->second;
                if (i < curr_f.args().size()) {
                    // Every matching estimate is applied in turn, so the last
                    // usable estimate for this dimension wins.
                    for (const auto &b : curr_f.schedule().estimates()) {
                        if (b.var != curr_f.args()[i]) {
                            continue;
                        }
                        // Only constant integer estimates can be substituted.
                        // as<IntImm>() yields null for anything else, which
                        // must not be dereferenced.
                        const IntImm *bmin = b.min.as<IntImm>();
                        const IntImm *bextent = b.extent.as<IntImm>();
                        if (need_lower && bmin) {
                            lower = Expr(bmin->value);
                        }
                        if (need_upper && bmin && bextent) {
                            upper = Expr(bmin->value + bextent->value - 1);
                        }
                    }
                }
            }

            Interval concrete_bounds = Interval(lower, upper);
            concrete_box.push_back(concrete_bounds);
        }
        concrete_regions[f_reg.first] = concrete_box;
    }

    regions_required_cache[query].push_back(RegionsRequired(bounds, concrete_regions));
    return concrete_regions;
}
// Return redundantly computed regions of producers ('prods') while computing a
// region of the function stage ('f', 'stage_num') specified by 'bounds'. 'var'
// is the dimension along which redundant computation is accounted for. The
// result is, per producer, the intersection of the regions required for
// 'bounds' and for 'bounds' shifted by its own length along 'var'.
map<string, Box>
DependenceAnalysis::redundant_regions(Function f, int stage_num, string var,
                                      const DimBounds &bounds,
                                      const set<string> &prods,
                                      bool only_regions_computed,
                                      const Scope<Interval> *input_estimates) {
    // Find the regions required to compute the region of 'f' specified
    // by 'bounds'.
    map<string, Box> regions = regions_required(
        f, stage_num, bounds, prods, only_regions_computed, input_estimates);

    // Shift the bounds by the size of the interval along the direction
    // of var.
    DimBounds shifted_bounds;

    for (const auto &b : bounds) {
        if (b.first == var) {
            Expr len = b.second.max - b.second.min + 1;
            Interval bound = Interval(b.second.min + len, b.second.max + len);
            shifted_bounds[b.first] = bound;
        } else {
            shifted_bounds[b.first] = b.second;
        }
    }

    // Find the regions required to compute the region of f specified
    // by shifted_bounds.
    map<string, Box> regions_shifted = regions_required(
        f, stage_num, shifted_bounds, prods, only_regions_computed, input_estimates);

    // Compute the overlaps between 'regions_shifted' and the original
    // regions required.
    map<string, Box> overlaps;
    for (const auto &reg : regions) {
        auto iter = regions_shifted.find(reg.first);
        // The iterator comes from 'regions_shifted', so it must be compared
        // against that map's end().
        if (iter == regions_shifted.end()) {
            // It will be interesting to log cases where this actually happens
            // i.e., the shifted regions do not contain a function that was
            // there in the original regions.
            continue;
        }

        const Box &b = reg.second;
        const Box &b_shifted = iter->second;

        // The boxes should be of the same size.
        internal_assert(b.size() == b_shifted.size());

        Box b_intersect;
        for (uint32_t i = 0; i < b.size(); i++) {
            b_intersect.push_back(Interval::make_intersection(b[i], b_shifted[i]));
        }

        // A function should appear once in the regions and therefore cannot
        // already be present in the overlaps map.
        internal_assert(overlaps.find(reg.first) == overlaps.end());
        overlaps.emplace(reg.first, b_intersect);
    }

    // Simplify the bounds of each of the overlap regions.
    for (auto &f : overlaps) {
        simplify_box(f.second);
    }

    return overlaps;
}
// Return, for each dimension of the stage ('f', 'stage_num'), the producer
// regions ('prods') that redundant_regions() reports along that dimension.
// Entry d of the result corresponds to dims[d] of the stage's schedule; the
// last entry of dims is not visited.
vector<map<string, Box>>
DependenceAnalysis::overlap_regions(Function f, int stage_num,
                                    const DimBounds &bounds,
                                    const set<string> &prods,
                                    bool only_regions_computed,
                                    const Scope<Interval> *input_estimates) {
    Definition def = get_stage_definition(f, stage_num);
    const vector<Dim> &dims = def.schedule().dims();

    vector<map<string, Box>> conc_overlaps;
    const int num_loop_dims = (int)dims.size() - 1;
    for (int d = 0; d < num_loop_dims; d++) {
        conc_overlaps.push_back(
            redundant_regions(f, stage_num, dims[d].var, bounds,
                              prods, only_regions_computed, input_estimates));
    }
    return conc_overlaps;
}
// Return the regions of each function required for computing the
// outputs of the pipeline. The output bounds are taken from the estimates on
// each output; an internal error is raised if a dimension has no estimate
// with both min and extent defined.
map<string, Box> get_pipeline_bounds(DependenceAnalysis &analysis,
                                     const vector<Function> &outputs,
                                     const Scope<Interval> *input_estimates) {
    map<string, Box> pipeline_bounds;

    // Find the regions required for each of the outputs and merge them
    // to compute the full pipeline_bounds.
    for (const auto &out : outputs) {
        DimBounds pure_bounds;
        Box out_box;
        // Use the estimates on the output for determining the output bounds.
        // If there are duplicates, use the most recent estimate.
        const auto &estimates = out.schedule().estimates();
        for (const auto &arg : out.args()) {
            int i;
            for (i = estimates.size() - 1; i >= 0; --i) {
                const auto &est = estimates[i];
                if ((est.var == arg) && est.min.defined() && est.extent.defined()) {
                    Interval I = Interval(est.min, simplify(est.min + est.extent - 1));
                    pure_bounds.emplace(arg, I);
                    out_box.push_back(I);
                    break;
                }
            }
            // 'i' drops below zero only when the scan found nothing.
            internal_assert(i >= 0) << "Could not find estimate for " << arg << "\n";
        }

        // Every function in the environment is considered a producer.
        // 'const auto &' binds to the map's own elements; spelling the type
        // as pair<string, Function> would copy each element into a temporary.
        set<string> prods;
        for (const auto &fpair : analysis.env) {
            prods.insert(fpair.first);
        }

        map<string, Box> regions = analysis.regions_required(out, pure_bounds, prods,
                                                             false, input_estimates);

        // Add the output region to the pipeline bounds as well.
        regions.emplace(out.name(), out_box);

        merge_regions(pipeline_bounds, regions);
    }

    return pipeline_bounds;
}
struct AutoSchedule {
struct Stage {
string function;
size_t stage;
Stage(const string &f, size_t s) : function(f), stage(s) {}
bool operator==(const Stage &other) const {
return (function == other.function) && (stage == other.stage);
}
bool operator<(const Stage &other) const {
return (function < other.function) || ((function == other.function) && (stage < other.stage));
}
};
const map<string, Function> &env;
// Contain maps from function name to realization order.
map<string, size_t> realization_order;
// Cache for storing all internal vars/rvars that have been declared during
// the course of schedule generation, to ensure that we don't introduce any
// duplicates in the string representation of the schedules.
map<string, VarOrRVar> internal_vars;
// Store the list of schedules applied to some function stages (most recent
// schedule is placed last in the list).
map<string, map<int, vector<string>>> func_schedules;
// Store the list of vars/rvars used in the schedule applied to some
// function stages.
map<string, map<int, set<string>>> used_vars;
AutoSchedule(const map<string, Function> &env, const vector<string> &order) : env(env) {
for (size_t i = 0; i < order.size(); ++i) {
realization_order.emplace(order[i], i);
}
// Allocate a slot in 'used_vars' for each function stages in the pipeline
for (const auto &iter : env) {
for (size_t i = 0; i < iter.second.updates().size() + 1; ++i) {
used_vars[iter.first][i];
}
}
}
// Given a function name, return a string representation of getting the
// function handle
string get_func_handle(const string &name) const {
size_t index = get_element(realization_order, name);
return "pipeline.get_func(" + std::to_string(index) + ")";
}
friend std::ostream& operator<<(std::ostream &stream, const AutoSchedule &sched) {
for (const auto &iter : sched.internal_vars) {
if (iter.second.is_rvar) {
stream << "RVar ";
} else {
stream << "Var ";
}
stream << iter.first << "(\"" << iter.first << "\");\n";
}
stream << "\n";
// Declare all the functions + schedules
std::ostringstream func_ss;
std::ostringstream schedule_ss;
for (const auto &f : sched.func_schedules) {
const string &fname = get_sanitized_name(f.first);
func_ss << "Func " << fname << " = " << sched.get_func_handle(f.first) << ";\n";
schedule_ss << "{\n";
// Declare all the Vars and RVars that are actually used in the schedule
const Function &func = get_element(sched.env, f.first);
for (size_t i = 0; i < func.args().size(); ++i) {
if (sched.used_vars.at(func.name()).at(0).find(func.args()[i])
!= sched.used_vars.at(func.name()).at(0).end()) {
schedule_ss << " Var " << func.args()[i] << " = "
<< fname << ".args()[" << i << "];\n";
}
}
set<string> declared_rvars;
for (size_t i = 0; i < func.updates().size(); ++i) {
const vector<ReductionVariable> &rvars = func.updates()[i].schedule().rvars();
const set<string> &var_list = sched.used_vars.at(func.name()).at(i);
for (size_t j = 0; j < rvars.size(); ++j) {
if ((var_list.find(rvars[j].var) == var_list.end()) ||
(declared_rvars.find(rvars[j].var) != declared_rvars.end())) {
continue;
}
declared_rvars.insert(rvars[j].var);
schedule_ss << " RVar " << rvars[j].var << "("
<< fname << ".update(" << i << ").get_schedule().rvars()[" << j << "].var);\n";
}
}
for (const auto &s : f.second) {
internal_assert(!s.second.empty());
schedule_ss << " " << fname;
if (s.first > 0) {
schedule_ss << ".update(" << std::to_string(s.first - 1) << ")";
}
for (size_t i = 0; i < s.second.size(); ++i) {
schedule_ss << "\n ." << s.second[i];
}
schedule_ss << ";\n";
}
schedule_ss << "}\n";
}
stream << func_ss.str() << "\n";
stream << schedule_ss.str() << "\n";
return stream;
}
void push_schedule(const string &stage_name, size_t stage_num,
const string &sched, const set<string> &vars) {
vector<string> v = split_string(stage_name, ".");
internal_assert(!v.empty());
used_vars[v[0]][stage_num].insert(vars.begin(), vars.end());
// If the previous schedule applied is the same as this one,
// there is no need to re-apply the schedule
auto &schedules = func_schedules[v[0]][stage_num];
if (schedules.empty()) {
schedules.push_back(sched);
} else {
if (schedules[schedules.size()-1] != sched) {
schedules.push_back(sched);
}
}
}
};
// Implement the grouping algorithm and the cost model for making the grouping
// choices.
struct Partitioner {
// GroupingChoice encodes the grouping of the 'prod' function into the 'cons' stage.
struct GroupingChoice {
    // Name of the producer function being grouped.
    string prod;
    // Consumer stage that the producer is grouped into.
    FStage cons;

    GroupingChoice(const string &prod, const FStage &cons) : prod(prod), cons(cons) {}

    // Two choices are equal when both the producer name and the consumer
    // stage match.
    bool operator==(const GroupingChoice &other) const {
        if (prod != other.prod) {
            return false;
        }
        return cons == other.cons;
    }

    // Lexicographic ordering: compare producer names first and fall back
    // to the consumer stage only when the names are equal.
    bool operator<(const GroupingChoice &other) const {
        if (prod != other.prod) {
            return prod < other.prod;
        }
        return cons < other.cons;
    }

    // Print the choice as "Choice: <prod> -> <cons>" followed by a newline.
    friend std::ostream& operator<<(std::ostream &stream, const GroupingChoice &choice) {
        return stream << "Choice: " << choice.prod << " -> " << choice.cons << '\n';
    }
};
// A group is a sub-pipeline with a single output. Members of a group are
// either inlined into the consumer functions within the group or computed
// at tiles of the output, specified by 'tile_sizes'.
//
// TODO: The restriction of computing either at the inline or tile level
// makes the space of scheduling choices for a group very tractable.
// However, the restriction might miss good schedules which can only be
// realized by computing the members of the group at different levels of
// the group.
//
// There are two approaches to extend the space of schedules considered:
// 1) Recursive grouping: Treat the problem of determining the compute levels
// within a group as a smaller instance of the grouping problem with
// different parameters for the input, output sizes, and cache model.
//
// 2) Tightening: Always compute a function at the lowest level possible
// without introducing redundant work. This is a restricted form of recursive
// grouping which does not explore the trade-off between redundant work and
// locality.
//
// Either approach can be implemented as a post process for each group
// after the initial grouping process finishes. The cost model may
// already make sub-optimal higher level partitioning when it is not aware
// of the benefits of the post processing. However, it should strictly be
// an improvement over the initial grouping. As a first step, it is good
// to make it a post process.
//
// Incorporating the recursive grouping process into the cost model can be
// tricky and can potentially make the cost of analyzing a group
// prohibitive, as it requires solving smaller instances of the grouping
// problem for analyzing each configuration. On the other hand, tightening
// can be integrated into the cost model without significantly increasing
// the time to analyze a grouping configuration.
//
// TODO: Add sliding window optimizations. For a start, it may be enough to
// implement sliding window as a post-pass by moving the store level of all
// the members of the group to the outermost serial loop. This could possibly
// be incorporated in the cost model with some effort. Line-buffering
// presents additional challenges for this post-processing strategy though.
// A typical line-buffer would use a terrible tile size for tiling, but its
// performance will improve significantly once sliding window is turned on.
//
// TODO: Register tiling is an important transformation especially for
// benchmarks with significant reuse of the data (like matrix multiply and
// convolutional layers). The mechanism for realizing register tiling is to
// completely unroll small tiles of the innermost kernels. Unrolling
// interacts with vectorization, storage layout, and depends on the outer
// level tiling.
struct Group {
    // The output stage representing the group.
    FStage output;
    // Functions that belong to the group.
    vector<FStage> members;
    // Members of the group which are inlined.
    set<string> inlined;
    // Tile sizes along dimensions of the output function of the group.
    map<string, Expr> tile_sizes;

    // Build a group from its output stage and member stages. 'inlined' and
    // 'tile_sizes' start out empty.
    Group(const FStage &output, const vector<FStage> &members)
        : output(output), members(members) {}

    // Print the group as four lines: the output stage, the member list,
    // the inlined set, and the tile sizes as (name, size) pairs. Each list
    // is wrapped in braces with ", " between elements.
    friend std::ostream& operator<<(std::ostream &stream, const Group &g) {
        stream << "Output FStage: " << g.output << '\n';

        stream << "Members: " << '{';
        for (size_t i = 0; i < g.members.size(); ++i) {
            if (i > 0) {
                stream << ", ";
            }
            stream << g.members[i];
        }
        stream << "}" << '\n';

        // Decide whether a separator is needed by comparing against
        // begin() directly. Calling std::distance on set/map iterators is
        // linear per call, which would make these loops quadratic.
        stream << "Inlined: " << '{';
        for (auto iter = g.inlined.begin(); iter != g.inlined.end(); ++iter) {
            if (iter != g.inlined.begin()) {
                stream << ", ";
            }
            stream << *iter;
        }
        stream << "}" << '\n';

        stream << "Tile sizes: " << "{";
        for (auto iter = g.tile_sizes.begin(); iter != g.tile_sizes.end(); ++iter) {
            if (iter != g.tile_sizes.begin()) {
                stream << ", ";
            }
            stream << "(" << iter->first << ", " << iter->second << ")";
        }
        stream << "}" << '\n';

        return stream;
    }
};
// Result of the analysis of a group.
struct GroupAnalysis {
    // Estimate of the arithmetic and memory cost for computing the group.
    Cost cost;
    // Estimate of the parallelism that can be exploited while computing
    // the group.
    Expr parallelism;

    // Default state: a default-constructed cost and parallelism.
    GroupAnalysis() : cost(Cost()), parallelism(Expr()) {}

    GroupAnalysis(const Cost &c, Expr p) : cost(c), parallelism(std::move(p)) {}

    // The analysis is defined only when both the cost and the parallelism
    // estimate are defined.
    inline bool defined() const {
        if (!cost.defined()) {
            return false;
        }
        return parallelism.defined();
    }

    // Simplify the cost, and the parallelism expression when there is one.
    void simplify() {
        cost.simplify();
        if (!parallelism.defined()) {
            // Nothing to simplify; leave the parallelism untouched.
            return;
        }
        parallelism = Internal::simplify(parallelism);
    }

    // Print as "[arith cost:<a>, memory cost:<m>, parallelism:<p>]" followed
    // by a newline.
    friend std::ostream& operator<<(std::ostream &stream, const GroupAnalysis &analysis) {
        return stream << "[arith cost:" << analysis.cost.arith << ", "
                      << "memory cost:" << analysis.cost.memory << ", "
                      << "parallelism:" << analysis.parallelism << "]\n";
    }
};