Skip to content

Commit c62501f

Browse files
committed
Merge pull request halide#496 from halide/ramp_div_mod_simplify_rules
Rework and add new rules for simplifying ramp/x and ramp%x
2 parents e071700 + f348390 commit c62501f

File tree

1 file changed

+55
-24
lines changed

1 file changed

+55
-24
lines changed

src/Simplify.cpp

+55-24
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@ using std::pair;
2828
using std::make_pair;
2929
using std::ostringstream;
3030

31-
// Immediates and broadcasts of immediates
31+
// Things that we can constant fold: Immediates and broadcasts of
32+
// immediates.
3233
bool is_simple_const(Expr e) {
33-
return (!e.as<Cast>()) && is_const(e);
34+
if (e.as<IntImm>()) return true;
35+
if (e.as<FloatImm>()) return true;
36+
if (const Broadcast *b = e.as<Broadcast>()) {
37+
return is_simple_const(b->value);
38+
}
39+
return false;
3440
}
3541

3642
// Is a constant representable as a certain type
@@ -62,8 +68,8 @@ class Simplify : public IRMutator {
6268
bounds_info.set_containing_scope(bi);
6369
}
6470

65-
/*
6671
// Uncomment to debug all Expr mutations.
72+
/*
6773
Expr mutate(Expr e) {
6874
Expr new_e = IRMutator::mutate(e);
6975
debug(0) << e << " -> " << new_e << "\n";
@@ -684,6 +690,12 @@ class Simplify : public IRMutator {
684690
}
685691
}
686692

693+
ModulusRemainder mod_rem(0, 1);
694+
if (ramp_a && ramp_a->base.type() == Int(32)) {
695+
// Do modulus remainder analysis on the base.
696+
mod_rem = modulus_remainder(ramp_a->base, alignment_info);
697+
}
698+
687699
if (is_zero(a) && !is_zero(b)) {
688700
expr = a;
689701
} else if (is_one(b)) {
@@ -702,19 +714,17 @@ class Simplify : public IRMutator {
702714
}
703715
} else if (broadcast_a && broadcast_b) {
704716
expr = mutate(Broadcast::make(Div::make(broadcast_a->value, broadcast_b->value), broadcast_a->width));
705-
} else if (ramp_a && broadcast_b &&
706-
const_int(broadcast_b->value, &ib) && ib &&
707-
const_int(ramp_a->stride, &ia) && ((ia % ib) == 0)) {
708-
// ramp(x, ia, w) / broadcast(ib, w) -> ramp(x/ib, ia/ib, w) when ib divides ia
709-
expr = mutate(Ramp::make(ramp_a->base/ib, ia/ib, ramp_a->width));
710-
} else if (ramp_a && broadcast_b &&
711-
mul_a_a && const_int(mul_a_a->b, &ia) && ia &&
712-
const_int(broadcast_b->value, &ib) && ib &&
713-
const_int(ramp_a->stride, &ic) &&
714-
(ib % ia) == 0 &&
715-
std::abs(ic * (broadcast_b->width - 1)) < std::abs(ia)) {
716-
// ramp(x*a, c, w) / broadcast(b, w) -> broadcast(x / (b/a), w) when c*(w-1) < a and a divides b
717-
expr = mutate(Broadcast::make(mul_a_a->a / div_imp(ib, ia), broadcast_b->width));
717+
} else if (ramp_a && const_int(ramp_a->stride, &ia) &&
718+
broadcast_b && const_int(broadcast_b->value, &ib) && ib &&
719+
ia % ib == 0) {
720+
// ramp(x, 4, w) / broadcast(2, w) -> ramp(x / 2, 2, w)
721+
expr = mutate(Ramp::make(ramp_a->base / ib, div_imp(ia, ib), ramp_a->width));
722+
} else if (ramp_a && ramp_a->base.type() == Int(32) && const_int(ramp_a->stride, &ia) &&
723+
broadcast_b && const_int(broadcast_b->value, &ib) && ib != 0 &&
724+
mod_rem.modulus % ib == 0 &&
725+
div_imp(mod_rem.remainder, ib) == div_imp(mod_rem.remainder + (ramp_a->width-1)*ia, ib)) {
726+
// ramp(k*z + x, y, w) / z = broadcast(k, w) if x/z == (x + (w-1)*y)/z
727+
expr = mutate(Broadcast::make(ramp_a->base / ib, ramp_a->width));
718728
} else if (div_a &&
719729
const_int(div_a->b, &ia) && ia >= 0 &&
720730
const_int(b, &ib) && ib >= 0) {
@@ -768,7 +778,7 @@ class Simplify : public IRMutator {
768778
Expr a = mutate(op->a);
769779
Expr b = mutate(op->b);
770780

771-
int ia = 0, ia2 = 0, ib = 0;
781+
int ia = 0, ib = 0;
772782
float fa = 0.0f, fb = 0.0f;
773783
const Broadcast *broadcast_a = a.as<Broadcast>();
774784
const Broadcast *broadcast_b = b.as<Broadcast>();
@@ -792,6 +802,14 @@ class Simplify : public IRMutator {
792802
mod_rem = modulus_remainder(a, alignment_info);
793803
}
794804

805+
// If the RHS is a constant and the LHS is a ramp, do modulus
806+
// remainder analysis on the base.
807+
if (broadcast_b &&
808+
const_int(broadcast_b->value, &ib) && ib &&
809+
ramp_a && ramp_a->base.type() == Int(32)) {
810+
mod_rem = modulus_remainder(ramp_a->base, alignment_info);
811+
}
812+
795813
if (is_zero(a) && !is_zero(b)) {
796814
expr = a;
797815
} else if (const_int(a, &ia) && const_int(b, &ib) && ib) {
@@ -823,12 +841,18 @@ class Simplify : public IRMutator {
823841
ia % ib == 0) {
824842
// ramp(x, 4, w) % broadcast(2, w)
825843
expr = mutate(Broadcast::make(ramp_a->base % ib, ramp_a->width));
826-
} else if (ramp_a && const_int(ramp_a->base, &ia) &&
827-
const_int(ramp_a->stride, &ia2) &&
844+
} else if (ramp_a && ramp_a->base.type() == Int(32) && const_int(ramp_a->stride, &ia) &&
845+
broadcast_b && const_int(broadcast_b->value, &ib) && ib != 0 &&
846+
mod_rem.modulus % ib == 0 &&
847+
div_imp(mod_rem.remainder, ib) == div_imp(mod_rem.remainder + (ramp_a->width-1)*ia, ib)) {
848+
// ramp(k*z + x, y, w) % z = ramp(x, y, w) if x/z == (x + (w-1)*y)/z
849+
expr = mutate(Ramp::make(mod_imp(mod_rem.remainder, ib), ramp_a->stride, ramp_a->width));
850+
} else if (ramp_a && ramp_a->base.type() == Int(32) &&
851+
const_int(ramp_a->stride, &ia) && !is_const(ramp_a->base) &&
828852
broadcast_b && const_int(broadcast_b->value, &ib) && ib != 0 &&
829-
div_imp(ia, ib) == div_imp(ia + ramp_a->width*ia2, ib)) {
830-
// ramp(x, y, w) % broadcast(z, w) = ramp(x % z, y, w) if x/z == (x + w*y)/z
831-
expr = mutate(Ramp::make(mod_imp(ia, ib), ramp_a->stride, ramp_a->width));
853+
mod_rem.modulus % ib == 0) {
854+
// ramp(k*z + x, y, w) % z = ramp(x, y, w) % z
855+
expr = mutate(Ramp::make(mod_imp(mod_rem.remainder, ib), ramp_a->stride, ramp_a->width) % ib);
832856
} else if (a.same_as(op->a) && b.same_as(op->b)) {
833857
expr = op;
834858
} else {
@@ -2426,15 +2450,22 @@ void simplify_test() {
24262450
check(Expr(Broadcast::make(y, 4)) / Expr(Broadcast::make(x, 4)),
24272451
Expr(Broadcast::make(y/x, 4)));
24282452
check(Expr(Ramp::make(x, 4, 4)) / 2, Ramp::make(x/2, 2, 4));
2453+
check(Expr(Ramp::make(x, -4, 7)) / 2, Ramp::make(x/2, -2, 7));
2454+
check(Expr(Ramp::make(x, 4, 5)) / -2, Ramp::make(x/-2, -2, 5));
2455+
check(Expr(Ramp::make(x, -8, 5)) / -2, Ramp::make(x/-2, 4, 5));
24292456

24302457
check(Expr(Ramp::make(4*x, 1, 4)) / 4, Broadcast::make(x, 4));
24312458
check(Expr(Ramp::make(x*4, 1, 3)) / 4, Broadcast::make(x, 3));
24322459
check(Expr(Ramp::make(x*8, 2, 4)) / 8, Broadcast::make(x, 4));
24332460
check(Expr(Ramp::make(x*8, 3, 3)) / 8, Broadcast::make(x, 3));
24342461
check(Expr(Ramp::make(0, 1, 8)) % 16, Expr(Ramp::make(0, 1, 8)));
2435-
check(Expr(Ramp::make(8, 1, 8)) % 16, Expr(Ramp::make(8, 1, 8) % 16));
2462+
check(Expr(Ramp::make(8, 1, 8)) % 16, Expr(Ramp::make(8, 1, 8)));
2463+
check(Expr(Ramp::make(9, 1, 8)) % 16, Expr(Ramp::make(9, 1, 8)) % 16);
24362464
check(Expr(Ramp::make(16, 1, 8)) % 16, Expr(Ramp::make(0, 1, 8)));
2437-
check(Expr(Ramp::make(0, 1, 8)) % 8, Expr(Ramp::make(0, 1, 8) % 8));
2465+
check(Expr(Ramp::make(0, 1, 8)) % 8, Expr(Ramp::make(0, 1, 8)));
2466+
check(Expr(Ramp::make(x*8+17, 1, 4)) % 8, Expr(Ramp::make(1, 1, 4)));
2467+
check(Expr(Ramp::make(x*8+17, 1, 8)) % 8, Expr(Ramp::make(1, 1, 8) % 8));
2468+
24382469

24392470
check(Expr(7) % 2, 1);
24402471
check(Expr(7.25f) % 2.0f, 1.25f);

0 commit comments

Comments
 (0)