@@ -2001,6 +2001,113 @@ class RearrangeExpressions : public IRMutator2 {
2001
2001
}
2002
2002
};
2003
2003
2004
+ // Try generating vgathers instead of shuffles.
2005
+ // At present, we request VTCM memory with single page allocation flag for all
2006
+ // store_in allocations. So it's always safe to generate a vgather.
2007
+ // Expressions which generate vgathers are of the form:
2008
+ // out(x) = lut(foo(x))
2009
+ // For vgathers out and lut should be in VTCM in a single page.
2010
+ class ScatterGatherGenerator : public IRMutator2 {
2011
+ Scope<Interval> bounds;
2012
+ std::unordered_map<string, const Allocate *> allocations;
2013
+
2014
+ using IRMutator2::visit;
2015
+
2016
+ template <typename NodeType, typename T>
2017
+ NodeType visit_let (const T *op) {
2018
+ // We only care about vector lets.
2019
+ if (op->value .type ().is_vector ()) {
2020
+ bounds.push (op->name , bounds_of_expr_in_scope (op->value , bounds));
2021
+ }
2022
+ NodeType node = IRMutator2::visit (op);
2023
+ if (op->value .type ().is_vector ()) {
2024
+ bounds.pop (op->name );
2025
+ }
2026
+ return node;
2027
+ }
2028
+
2029
+ Expr visit (const Let *op) { return visit_let<Expr>(op); }
2030
+
2031
+ Stmt visit (const LetStmt *op) { return visit_let<Stmt>(op); }
2032
+
2033
+ Stmt visit (const Allocate *op) {
2034
+ // Create a map of the allocation
2035
+ allocations[op->name ] = op;
2036
+ return IRMutator2::visit (op);
2037
+ }
2038
+
2039
+ // Try to match expressions of the form:
2040
+ // out(x) = lut(foo(x))
2041
+ // to generate vgathers. Here, out and lut should have
2042
+ // store_in(MemoryType::VTCM) directive. If a vgather is found return Call
2043
+ // Expr to vgather, otherwise Expr().
2044
+ Expr make_gather (const Load *op, Expr dst_base, Expr dst_index) {
2045
+ Type ty = op->type ;
2046
+ const Allocate *alloc = allocations[op->name ];
2047
+ // The lut should be in VTCM.
2048
+ if (!alloc || alloc->memory_type != MemoryType::VTCM) {
2049
+ return Expr ();
2050
+ }
2051
+ // HVX has only 16 or 32-bit gathers. Predicated vgathers are not
2052
+ // supported yet.
2053
+ if (op->index .as <Ramp>() || !is_one (op->predicate ) || !ty.is_vector () ||
2054
+ ty.bits () == 8 ) {
2055
+ return Expr ();
2056
+ }
2057
+ Expr index = mutate (ty.bytes () * op->index );
2058
+ Interval index_bounds = bounds_of_expr_in_scope (index , bounds);
2059
+ if (ty.bits () == 16 && index_bounds.is_bounded ()) {
2060
+ Expr index_span = span_of_bounds (index_bounds);
2061
+ index_span = common_subexpression_elimination (index_span);
2062
+ index_span = simplify (index_span);
2063
+ // We need to downcast the index values to 16 bit signed. So all the
2064
+ // the indices must be less than 1 << 15.
2065
+ if (!can_prove (index_span < std::numeric_limits<int16_t >::max ())) {
2066
+ return Expr ();
2067
+ }
2068
+ }
2069
+ // Calculate the size of the buffer lut in bytes.
2070
+ Expr size = ty.bytes ();
2071
+ for (size_t i = 0 ; i < alloc->extents .size (); i++) {
2072
+ size *= alloc->extents [i];
2073
+ }
2074
+ Expr src = Variable::make (Handle (), op->name );
2075
+ Expr new_index = mutate (cast (ty.with_code (Type::Int), index ));
2076
+ dst_index = mutate (dst_index);
2077
+
2078
+ return Call::make (ty, " gather" , {dst_base, dst_index, src, size-1 , new_index},
2079
+ Call::Intrinsic);
2080
+ }
2081
+
2082
+ Stmt visit (const Store *op) {
2083
+ // HVX has only 16 or 32-bit gathers. Predicated vgathers are not
2084
+ // supported yet.
2085
+ Type ty = op->value .type ();
2086
+ if (!is_one (op->predicate ) || !ty.is_vector () || ty.bits () == 8 ) {
2087
+ return IRMutator2::visit (op);
2088
+ }
2089
+ // To use vgathers, the destination address must be VTCM memory.
2090
+ const Allocate *alloc = allocations[op->name ];
2091
+ if (!alloc || alloc->memory_type != MemoryType::VTCM) {
2092
+ return IRMutator2::visit (op);
2093
+ }
2094
+ // The source for a gather must also be a buffer in VTCM.
2095
+ if (op->index .as <Ramp>() && op->value .as <Load>()) {
2096
+ // Check for vgathers
2097
+ Expr dst_base = Variable::make (Handle (), op->name );
2098
+ Expr dst_index = op->index .as <Ramp>()->base ;
2099
+ Expr value = make_gather (op->value .as <Load>(), dst_base, dst_index);
2100
+ if (value.defined ()) {
2101
+ // Found a vgather.
2102
+ // Function make_gather already mutates all the call arguements,
2103
+ // so no need to mutate again.
2104
+ return Evaluate::make (value);
2105
+ }
2106
+ }
2107
+ return IRMutator2::visit (op);
2108
+ }
2109
+ };
2110
+
2004
2111
} // namespace
2005
2112
2006
2113
Stmt optimize_hexagon_shuffles (Stmt s, int lut_alignment) {
@@ -2017,6 +2124,12 @@ Stmt vtmpy_generator(Stmt s) {
2017
2124
return s;
2018
2125
}
2019
2126
2127
+ Stmt scatter_gather_generator (Stmt s) {
2128
+ // Generate vscatter-vgather instruction if target >= v65
2129
+ s = ScatterGatherGenerator ().mutate (s);
2130
+ return s;
2131
+ }
2132
+
2020
2133
Stmt optimize_hexagon_instructions (Stmt s, Target t, Scope<ModulusRemainder> &alignment_info) {
2021
2134
// Convert some expressions to an equivalent form which get better
2022
2135
// optimized in later stages for hexagon
0 commit comments