Skip to content

Commit 3ce1f20

Browse files
authored
Merge pull request halide#3279 from halide/compute_with_remove_is_right_level
Remove is_the_right_level since compute_at aliasing at fused group is not currently supported
2 parents 9c51d7d + c34fe9e commit 3ce1f20

File tree

2 files changed

+23
-57
lines changed

2 files changed

+23
-57
lines changed

src/Func.h

+21
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,27 @@ class Stage {
191191
* the stage we are calling compute_with on should not have specializations,
192192
* e.g. f2.compute_with(f1, x) is allowed only if f2 has no specializations.
193193
*
194+
* Also, if a producer should be computed at the fused loop level,
195+
* the function passed to compute_at() needs to be the "parent". Consider
196+
* the following code:
197+
\code
198+
input(x, y) = x + y;
199+
f(x, y) = input(x, y);
200+
f(x, y) += 5;
201+
g(x, y) = x - y;
202+
g(x, y) += 10;
203+
f.compute_with(g, y);
204+
f.update().compute_with(g.update(), y);
205+
\endcode
206+
*
207+
* To compute 'input' at the fused loop level at dimension y, we specify
208+
* input.compute_at(g, y) instead of input.compute_at(f, y) since 'g' is
209+
* the "parent" for this fused loop (i.e. 'g' is computed before 'f'
210+
* is computed). On the other hand, to compute 'input' at the innermost
211+
* dimension of 'f', we specify input.compute_at(f, x) instead of
212+
* input.compute_at(g, x) since the x dimension of 'f' is not fused
213+
* (only the y dimension is).
214+
*
194215
* Given the constraints, this has a variety of uses. Consider the
195216
* following code:
196217
\code

src/ScheduleFunctions.cpp

+2-57
Original file line numberDiff line numberDiff line change
@@ -795,61 +795,6 @@ class InjectRealization : public IRMutator2 {
795795
found_compute_level(false), target(t), env(env) {}
796796

797797
private:
798-
// Determine if 'loop_name' is the right level to inject produce/realize node
799-
// of 'func'. If 'loop_name' is a fused group, we should inject it at the
800-
// fused parent loop of the group.
801-
bool is_the_right_level(const string &loop_name) {
802-
if (loop_name == LoopLevel::root().lock().to_string()) {
803-
return true;
804-
}
805-
806-
vector<string> v = split_string(loop_name, ".");
807-
internal_assert(v.size() > 2);
808-
const string &func_name = v[0];
809-
const string &var = v[v.size()-1];
810-
811-
int stage = -1;
812-
for (size_t i = 1; i < v.size() - 1; ++i) {
813-
if (v[i].substr(0, 1) == "s") {
814-
string str = v[i].substr(1, v[i].size() - 1);
815-
bool has_only_digits = (str.find_first_not_of( "0123456789" ) == string::npos);
816-
if (has_only_digits) {
817-
stage = atoi(str.c_str());
818-
}
819-
}
820-
}
821-
internal_assert(stage >= 0);
822-
823-
const auto &it = env.find(func_name);
824-
internal_assert(it != env.end());
825-
const Function &f = it->second;
826-
internal_assert(stage <= (int)f.updates().size());
827-
828-
if (f.has_extern_definition()) {
829-
return true;
830-
}
831-
832-
const Definition &def = (stage == 0) ? f.definition() : f.update(stage - 1);
833-
const LoopLevel &fuse_level = def.schedule().fuse_level().level;
834-
if (fuse_level.is_inlined() || fuse_level.is_root()) {
835-
// It isn't fused to anyone
836-
return true;
837-
} else {
838-
// Need to find out if it is fused at 'var'
839-
const vector<Dim> &dims = def.schedule().dims();
840-
const auto &it1 = std::find_if(dims.begin(), dims.end(),
841-
[&fuse_level](const Dim &d) { return var_name_match(d.var, fuse_level.var().name()); });
842-
internal_assert(it1 != dims.end());
843-
844-
const auto &it2 = std::find_if(dims.begin(), dims.end(),
845-
[&var](const Dim &d) { return var_name_match(d.var, var); });
846-
internal_assert(it2 != dims.end());
847-
848-
return it2 < it1;
849-
}
850-
return false;
851-
}
852-
853798
Stmt build_pipeline(Stmt consumer) {
854799
pair<Stmt, Stmt> realization = build_production(env, func, target);
855800

@@ -951,7 +896,7 @@ class InjectRealization : public IRMutator2 {
951896

952897
body = mutate(body);
953898

954-
if (compute_level.match(for_loop->name) && is_the_right_level(for_loop->name)) {
899+
if (compute_level.match(for_loop->name)) {
955900
debug(3) << "Found compute level\n";
956901
if (!function_is_already_realized_in_stmt(func, body) &&
957902
(function_is_used_in_stmt(func, body) || is_output)) {
@@ -960,7 +905,7 @@ class InjectRealization : public IRMutator2 {
960905
found_compute_level = true;
961906
}
962907

963-
if (store_level.match(for_loop->name) && is_the_right_level(for_loop->name)) {
908+
if (store_level.match(for_loop->name)) {
964909
debug(3) << "Found store level\n";
965910
internal_assert(found_compute_level)
966911
<< "The compute loop level was not found within the store loop level!\n";

0 commit comments

Comments
 (0)