Skip to content

Commit cfa0905

Browse files
authored
Merge pull request halide#3280 from halide/perform_inline_with_order
Use realization order when inlining to avoid extra works
2 parents e5d1fd2 + 8b73456 commit cfa0905

5 files changed

+45
-17
lines changed

src/AutoSchedule.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -2999,10 +2999,10 @@ Partitioner::analyze_spatial_locality(const FStage &stg,
29992999
Definition def = get_stage_definition(stg.func, stg.stage_num);
30003000
// Perform inlining on the all the values and the args in the stage.
30013001
for (auto &val : def.values()) {
3002-
val = perform_inline(val, dep_analysis.env, inlines);
3002+
val = perform_inline(val, dep_analysis.env, inlines, dep_analysis.order);
30033003
}
30043004
for (auto &arg : def.args()) {
3005-
arg = perform_inline(arg, dep_analysis.env, inlines);
3005+
arg = perform_inline(arg, dep_analysis.env, inlines, dep_analysis.order);
30063006
}
30073007
def.accept(&find);
30083008

@@ -3433,7 +3433,7 @@ string generate_schedules(const vector<Function> &outputs, const Target &target,
34333433
// Initialize the cost model.
34343434
// Compute the expression costs for each function in the pipeline.
34353435
debug(2) << "Initializing region costs...\n";
3436-
RegionCosts costs(env);
3436+
RegionCosts costs(env, order);
34373437
if (debug::debug_level() >= 3) {
34383438
costs.disp_func_costs();
34393439
}
@@ -3468,7 +3468,7 @@ string generate_schedules(const vector<Function> &outputs, const Target &target,
34683468
debug(2) << "Re-computing function value bounds...\n";
34693469
func_val_bounds = compute_function_value_bounds(order, env);
34703470
debug(2) << "Re-initializing region costs...\n";
3471-
RegionCosts costs(env);
3471+
RegionCosts costs(env, order);
34723472
debug(2) << "Re-initializing dependence analysis...\n";
34733473
dep_analysis = DependenceAnalysis(env, order, func_val_bounds);
34743474
debug(2) << "Re-computing pipeline bounds...\n";

src/AutoScheduleUtils.cpp

+19-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ vector<DimBounds> get_stage_bounds(Function f, const DimBounds &pure_bounds) {
155155
}
156156

157157
Expr perform_inline(Expr e, const map<string, Function> &env,
158-
const set<string> &inlines) {
158+
const set<string> &inlines,
159+
const vector<string> &order) {
159160
if (inlines.empty()) {
160161
return e;
161162
}
@@ -168,8 +169,24 @@ Expr perform_inline(Expr e, const map<string, Function> &env,
168169
// Find all the function calls in the current expression.
169170
FindAllCalls find;
170171
inlined_expr.accept(&find);
171-
const set<string> &calls = find.funcs_called;
172+
const set<string> &calls_unsorted = find.funcs_called;
173+
174+
vector<string> calls(calls_unsorted.begin(), calls_unsorted.end());
175+
// Sort 'calls' based on the realization order in descending order
176+
// if provided (i.e. last to be realized comes first).
177+
if (!order.empty()) {
178+
std::sort(calls.begin(), calls.end(),
179+
[&order](const string &lhs, const string &rhs){
180+
const auto &iter_lhs = std::find(order.begin(), order.end(), lhs);
181+
const auto &iter_rhs = std::find(order.begin(), order.end(), rhs);
182+
return iter_lhs > iter_rhs;
183+
}
184+
);
185+
}
186+
172187
// Check if any of the calls are in the set of functions to be inlined.
188+
// Inline from the last function to be realized to avoid extra
189+
// inlining works.
173190
for (const auto &call : calls) {
174191
if (inlines.find(call) != inlines.end()) {
175192
Function prod_func = env.at(call);

src/AutoScheduleUtils.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,12 @@ DimBounds get_stage_bounds(Function f, int stage_num, const DimBounds &pure_boun
8484
std::vector<DimBounds> get_stage_bounds(Function f, const DimBounds &pure_bounds);
8585

8686
/** Recursively inline all the functions in the set 'inlines' into the
87-
* expression 'e' and return the resulting expression. */
87+
* expression 'e' and return the resulting expression. If 'order' is
88+
* passed, inlining will be done in the reverse order of function realization
89+
* to avoid extra inlining works. */
8890
Expr perform_inline(Expr e, const std::map<std::string, Function> &env,
89-
const std::set<std::string> &inlines = std::set<std::string>());
91+
const std::set<std::string> &inlines = std::set<std::string>(),
92+
const std::vector<std::string> &order = std::vector<std::string>());
9093

9194
/** Return all functions that are directly called by a function stage (f, stage). */
9295
std::set<std::string> get_parents(Function f, int stage);

src/RegionCosts.cpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,9 @@ map<string, Expr> compute_expr_detailed_byte_loads(Expr expr) {
324324

325325
} // anonymous namespace
326326

327-
RegionCosts::RegionCosts(const map<string, Function> &_env) : env(_env) {
327+
RegionCosts::RegionCosts(const map<string, Function> &_env,
328+
const vector<string> &_order)
329+
: env(_env), order(_order) {
328330
for (const auto &kv : env) {
329331
// Pre-compute the function costs without any inlining.
330332
func_cost[kv.first] = get_func_cost(kv.second);
@@ -442,7 +444,7 @@ RegionCosts::stage_detailed_load_costs(string func, int stage,
442444
} else {
443445
Definition def = get_stage_definition(curr_f, stage);
444446
for (const auto &e : def.values()) {
445-
Expr inlined_expr = perform_inline(e, env, inlines);
447+
Expr inlined_expr = perform_inline(e, env, inlines, order);
446448
inlined_expr = simplify(inlined_expr);
447449

448450
map<string, Expr> expr_load_costs = compute_expr_detailed_byte_loads(inlined_expr);
@@ -553,7 +555,8 @@ RegionCosts::detailed_load_costs(const map<string, Box> &regions,
553555
return load_costs;
554556
}
555557

556-
Cost RegionCosts::get_func_stage_cost(const Function &f, int stage, const set<string> &inlines) {
558+
Cost RegionCosts::get_func_stage_cost(const Function &f, int stage,
559+
const set<string> &inlines) {
557560
if (f.has_extern_definition()) {
558561
return Cost();
559562
}
@@ -563,7 +566,7 @@ Cost RegionCosts::get_func_stage_cost(const Function &f, int stage, const set<st
563566
Cost cost(0, 0);
564567

565568
for (const auto &e : def.values()) {
566-
Expr inlined_expr = perform_inline(e, env, inlines);
569+
Expr inlined_expr = perform_inline(e, env, inlines, order);
567570
inlined_expr = simplify(inlined_expr);
568571

569572
Cost expr_cost = compute_expr_cost(inlined_expr);
@@ -578,7 +581,7 @@ Cost RegionCosts::get_func_stage_cost(const Function &f, int stage, const set<st
578581

579582
if (!f.is_pure()) {
580583
for (const auto &arg : def.args()) {
581-
Expr inlined_arg = perform_inline(arg, env, inlines);
584+
Expr inlined_arg = perform_inline(arg, env, inlines, order);
582585
inlined_arg = simplify(inlined_arg);
583586

584587
Cost expr_cost = compute_expr_cost(inlined_arg);
@@ -639,7 +642,7 @@ Expr RegionCosts::region_footprint(const map<string, Box> &regions,
639642
}
640643
}
641644

642-
vector<string> order = topological_order(outs, env);
645+
vector<string> top_order = topological_order(outs, env);
643646

644647
Expr working_set_size = make_zero(Int(64));
645648
Expr curr_size = make_zero(Int(64));
@@ -657,7 +660,7 @@ Expr RegionCosts::region_footprint(const map<string, Box> &regions,
657660
}
658661
}
659662

660-
for (const auto &f : order) {
663+
for (const auto &f : top_order) {
661664
if (regions.find(f) != regions.end()) {
662665
curr_size += get_element(func_sizes, f);
663666
}

src/RegionCosts.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ struct Cost {
4141
struct RegionCosts {
4242
/** An environment map which contains all functions in the pipeline. */
4343
std::map<std::string, Function> env;
44+
/** Realization order of functions in the pipeline. The first function to
45+
* be realized comes first. */
46+
std::vector<std::string> order;
4447
/** A map containing the cost of computing a value in each stage of a
4548
* function. The number of entries in the vector is equal to the number of
4649
* stages in the function. */
@@ -130,8 +133,10 @@ struct RegionCosts {
130133
void disp_func_costs();
131134

132135
/** Construct a region cost object for the pipeline. 'env' is a map of all
133-
* functions in the pipeline.*/
134-
RegionCosts(const std::map<std::string, Function> &env);
136+
* functions in the pipeline. 'order' is the realization order of functions
137+
* in the pipeline. The first function to be realized comes first. */
138+
RegionCosts(const std::map<std::string, Function> &env,
139+
const std::vector<std::string> &order);
135140
};
136141

137142
/** Return true if the cost of inlining a function is equivalent to the

0 commit comments

Comments
 (0)