@@ -28,9 +28,15 @@ using std::pair;
28
28
using std::make_pair;
29
29
using std::ostringstream;
30
30
31
- // Immediates and broadcasts of immediates
31
+ // Things that we can constant fold: Immediates and broadcasts of
32
+ // immediates.
32
33
bool is_simple_const (Expr e) {
33
- return (!e.as <Cast>()) && is_const (e);
34
+ if (e.as <IntImm>()) return true ;
35
+ if (e.as <FloatImm>()) return true ;
36
+ if (const Broadcast *b = e.as <Broadcast>()) {
37
+ return is_simple_const (b->value );
38
+ }
39
+ return false ;
34
40
}
35
41
36
42
// Is a constant representable as a certain type
@@ -62,8 +68,8 @@ class Simplify : public IRMutator {
62
68
bounds_info.set_containing_scope (bi);
63
69
}
64
70
65
- /*
66
71
// Uncomment to debug all Expr mutations.
72
+ /*
67
73
Expr mutate(Expr e) {
68
74
Expr new_e = IRMutator::mutate(e);
69
75
debug(0) << e << " -> " << new_e << "\n";
@@ -684,6 +690,12 @@ class Simplify : public IRMutator {
684
690
}
685
691
}
686
692
693
+ ModulusRemainder mod_rem (0 , 1 );
694
+ if (ramp_a && ramp_a->base .type () == Int (32 )) {
695
+ // Do modulus remainder analysis on the base.
696
+ mod_rem = modulus_remainder (ramp_a->base , alignment_info);
697
+ }
698
+
687
699
if (is_zero (a) && !is_zero (b)) {
688
700
expr = a;
689
701
} else if (is_one (b)) {
@@ -702,19 +714,17 @@ class Simplify : public IRMutator {
702
714
}
703
715
} else if (broadcast_a && broadcast_b) {
704
716
expr = mutate (Broadcast::make (Div::make (broadcast_a->value , broadcast_b->value ), broadcast_a->width ));
705
- } else if (ramp_a && broadcast_b &&
706
- const_int (broadcast_b->value , &ib) && ib &&
707
- const_int (ramp_a->stride , &ia) && ((ia % ib) == 0 )) {
708
- // ramp(x, ia, w) / broadcast(ib, w) -> ramp(x/ib, ia/ib, w) when ib divides ia
709
- expr = mutate (Ramp::make (ramp_a->base /ib, ia/ib, ramp_a->width ));
710
- } else if (ramp_a && broadcast_b &&
711
- mul_a_a && const_int (mul_a_a->b , &ia) && ia &&
712
- const_int (broadcast_b->value , &ib) && ib &&
713
- const_int (ramp_a->stride , &ic) &&
714
- (ib % ia) == 0 &&
715
- std::abs (ic * (broadcast_b->width - 1 )) < std::abs (ia)) {
716
- // ramp(x*a, c, w) / broadcast(b, w) -> broadcast(x / (b/a), w) when c*(w-1) < a and a divides b
717
- expr = mutate (Broadcast::make (mul_a_a->a / div_imp (ib, ia), broadcast_b->width ));
717
+ } else if (ramp_a && const_int (ramp_a->stride , &ia) &&
718
+ broadcast_b && const_int (broadcast_b->value , &ib) && ib &&
719
+ ia % ib == 0 ) {
720
+ // ramp(x, 4, w) / broadcast(2, w) -> ramp(x / 2, 2, w)
721
+ expr = mutate (Ramp::make (ramp_a->base / ib, div_imp (ia, ib), ramp_a->width ));
722
+ } else if (ramp_a && ramp_a->base .type () == Int (32 ) && const_int (ramp_a->stride , &ia) &&
723
+ broadcast_b && const_int (broadcast_b->value , &ib) && ib != 0 &&
724
+ mod_rem.modulus % ib == 0 &&
725
+ div_imp (mod_rem.remainder , ib) == div_imp (mod_rem.remainder + (ramp_a->width -1 )*ia, ib)) {
726
+ // ramp(k*z + x, y, w) / z = broadcast(k, w) if x/z == (x + (w-1)*y)/z
727
+ expr = mutate (Broadcast::make (ramp_a->base / ib, ramp_a->width ));
718
728
} else if (div_a &&
719
729
const_int (div_a->b , &ia) && ia >= 0 &&
720
730
const_int (b, &ib) && ib >= 0 ) {
@@ -768,7 +778,7 @@ class Simplify : public IRMutator {
768
778
Expr a = mutate (op->a );
769
779
Expr b = mutate (op->b );
770
780
771
- int ia = 0 , ia2 = 0 , ib = 0 ;
781
+ int ia = 0 , ib = 0 ;
772
782
float fa = 0 .0f , fb = 0 .0f ;
773
783
const Broadcast *broadcast_a = a.as <Broadcast>();
774
784
const Broadcast *broadcast_b = b.as <Broadcast>();
@@ -792,6 +802,14 @@ class Simplify : public IRMutator {
792
802
mod_rem = modulus_remainder (a, alignment_info);
793
803
}
794
804
805
+ // If the RHS is a constant and the LHS is a ramp, do modulus
806
+ // remainder analysis on the base.
807
+ if (broadcast_b &&
808
+ const_int (broadcast_b->value , &ib) && ib &&
809
+ ramp_a && ramp_a->base .type () == Int (32 )) {
810
+ mod_rem = modulus_remainder (ramp_a->base , alignment_info);
811
+ }
812
+
795
813
if (is_zero (a) && !is_zero (b)) {
796
814
expr = a;
797
815
} else if (const_int (a, &ia) && const_int (b, &ib) && ib) {
@@ -823,12 +841,18 @@ class Simplify : public IRMutator {
823
841
ia % ib == 0 ) {
824
842
// ramp(x, 4, w) % broadcast(2, w)
825
843
expr = mutate (Broadcast::make (ramp_a->base % ib, ramp_a->width ));
826
- } else if (ramp_a && const_int (ramp_a->base , &ia) &&
827
- const_int (ramp_a->stride , &ia2) &&
844
+ } else if (ramp_a && ramp_a->base .type () == Int (32 ) && const_int (ramp_a->stride , &ia) &&
845
+ broadcast_b && const_int (broadcast_b->value , &ib) && ib != 0 &&
846
+ mod_rem.modulus % ib == 0 &&
847
+ div_imp (mod_rem.remainder , ib) == div_imp (mod_rem.remainder + (ramp_a->width -1 )*ia, ib)) {
848
+ // ramp(k*z + x, y, w) % z = ramp(x, y, w) if x/z == (x + (w-1)*y)/z
849
+ expr = mutate (Ramp::make (mod_imp (mod_rem.remainder , ib), ramp_a->stride , ramp_a->width ));
850
+ } else if (ramp_a && ramp_a->base .type () == Int (32 ) &&
851
+ const_int (ramp_a->stride , &ia) && !is_const (ramp_a->base ) &&
828
852
broadcast_b && const_int (broadcast_b->value , &ib) && ib != 0 &&
829
- div_imp (ia, ib) == div_imp (ia + ramp_a-> width *ia2, ib) ) {
830
- // ramp(x, y, w) % broadcast(z, w) = ramp(x % z , y, w) if x/z == (x + w*y)/ z
831
- expr = mutate (Ramp::make (mod_imp (ia , ib), ramp_a->stride , ramp_a->width ));
853
+ mod_rem. modulus % ib == 0 ) {
854
+ // ramp(k*z + x, y, w) % z = ramp(x, y, w) % z
855
+ expr = mutate (Ramp::make (mod_imp (mod_rem. remainder , ib), ramp_a->stride , ramp_a->width ) % ib );
832
856
} else if (a.same_as (op->a ) && b.same_as (op->b )) {
833
857
expr = op;
834
858
} else {
@@ -2426,15 +2450,22 @@ void simplify_test() {
2426
2450
check (Expr (Broadcast::make (y, 4 )) / Expr (Broadcast::make (x, 4 )),
2427
2451
Expr (Broadcast::make (y/x, 4 )));
2428
2452
check (Expr (Ramp::make (x, 4 , 4 )) / 2 , Ramp::make (x/2 , 2 , 4 ));
2453
+ check (Expr (Ramp::make (x, -4 , 7 )) / 2 , Ramp::make (x/2 , -2 , 7 ));
2454
+ check (Expr (Ramp::make (x, 4 , 5 )) / -2 , Ramp::make (x/-2 , -2 , 5 ));
2455
+ check (Expr (Ramp::make (x, -8 , 5 )) / -2 , Ramp::make (x/-2 , 4 , 5 ));
2429
2456
2430
2457
check (Expr (Ramp::make (4 *x, 1 , 4 )) / 4 , Broadcast::make (x, 4 ));
2431
2458
check (Expr (Ramp::make (x*4 , 1 , 3 )) / 4 , Broadcast::make (x, 3 ));
2432
2459
check (Expr (Ramp::make (x*8 , 2 , 4 )) / 8 , Broadcast::make (x, 4 ));
2433
2460
check (Expr (Ramp::make (x*8 , 3 , 3 )) / 8 , Broadcast::make (x, 3 ));
2434
2461
check (Expr (Ramp::make (0 , 1 , 8 )) % 16 , Expr (Ramp::make (0 , 1 , 8 )));
2435
- check (Expr (Ramp::make (8 , 1 , 8 )) % 16 , Expr (Ramp::make (8 , 1 , 8 ) % 16 ));
2462
+ check (Expr (Ramp::make (8 , 1 , 8 )) % 16 , Expr (Ramp::make (8 , 1 , 8 )));
2463
+ check (Expr (Ramp::make (9 , 1 , 8 )) % 16 , Expr (Ramp::make (9 , 1 , 8 )) % 16 );
2436
2464
check (Expr (Ramp::make (16 , 1 , 8 )) % 16 , Expr (Ramp::make (0 , 1 , 8 )));
2437
- check (Expr (Ramp::make (0 , 1 , 8 )) % 8 , Expr (Ramp::make (0 , 1 , 8 ) % 8 ));
2465
+ check (Expr (Ramp::make (0 , 1 , 8 )) % 8 , Expr (Ramp::make (0 , 1 , 8 )));
2466
+ check (Expr (Ramp::make (x*8 +17 , 1 , 4 )) % 8 , Expr (Ramp::make (1 , 1 , 4 )));
2467
+ check (Expr (Ramp::make (x*8 +17 , 1 , 8 )) % 8 , Expr (Ramp::make (1 , 1 , 8 ) % 8 ));
2468
+
2438
2469
2439
2470
check (Expr (7 ) % 2 , 1 );
2440
2471
check (Expr (7 .25f ) % 2 .0f , 1 .25f );
0 commit comments