Skip to content

Commit

Permalink
Add special case for printing broadcast shuffles. (halide#5565)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharletg authored Dec 16, 2020
1 parent 94da4f6 commit 34d35a3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,10 @@ Expr unbroadcast_lossless_cast(Type ty, Expr x) {
}
// Check if shuffle can be treated as a broadcast.
if (const Shuffle *shuff = x.as<Shuffle>()) {
int factor = ty.lanes();
if (shuff->is_broadcast(factor)) {
int factor = x.type().lanes() / ty.lanes();
if (shuff->is_broadcast() && shuff->broadcast_factor() % factor == 0) {
x = Shuffle::make(shuff->vectors, std::vector<int>(shuff->indices.begin(),
shuff->indices.begin() + factor));
shuff->indices.begin() + ty.lanes()));
}
}
}
Expand Down
34 changes: 27 additions & 7 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,9 @@ Expr Shuffle::make_concat(const std::vector<Expr> &vectors) {
return make(vectors, indices);
}

Expr Shuffle::make_broadcast(Expr vector, int lanes) {
std::vector<int> indices(lanes * vector.type().lanes());
for (int ix = 0; ix < lanes; ix++) {
Expr Shuffle::make_broadcast(Expr vector, int factor) {
std::vector<int> indices(factor * vector.type().lanes());
for (int ix = 0; ix < factor; ix++) {
std::iota(indices.begin() + ix * vector.type().lanes(),
indices.begin() + (ix + 1) * vector.type().lanes(), 0);
}
Expand All @@ -791,20 +791,40 @@ Expr Shuffle::make_extract_element(Expr vector, int i) {
return make_slice(std::move(vector), i, 1, 1);
}

bool Shuffle::is_broadcast(int factor) const {
bool Shuffle::is_broadcast() const {
int lanes = indices.size();
// Don't consider broadcast factor < 2
if (factor < 2 || factor > lanes) {
int factor = broadcast_factor();
if (factor == 0 || factor >= lanes) {
return false;
}
int broadcasted_lanes = lanes / factor;

if (broadcasted_lanes < 2 || broadcasted_lanes >= lanes || lanes % broadcasted_lanes != 0) {
return false;
}
for (int i = 0; i < lanes; i++) {
if (indices[i % factor] != indices[i]) {
if (indices[i % broadcasted_lanes] != indices[i]) {
return false;
}
}
return true;
}

int Shuffle::broadcast_factor() const {
int lanes = indices.size();
int broadcasted_lanes = 0;
for (; broadcasted_lanes < lanes; broadcasted_lanes++) {
if (indices[broadcasted_lanes] != broadcasted_lanes) {
break;
}
}
if (broadcasted_lanes > 0) {
return lanes / broadcasted_lanes;
} else {
return 0;
}
}

bool Shuffle::is_interleave() const {
int lanes = vectors.front().type().lanes();

Expand Down
5 changes: 3 additions & 2 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ struct Shuffle : public ExprNode<Shuffle> {

/** Convenience constructor for making a shuffle representing a
* broadcast of a vector. */
static Expr make_broadcast(Expr vector, int lanes);
static Expr make_broadcast(Expr vector, int factor);

/** Convenience constructor for making a shuffle representing a
* contiguous subset of a vector. */
Expand All @@ -779,7 +779,8 @@ struct Shuffle : public ExprNode<Shuffle> {
* A uint8 shuffle of with 4*n lanes and indices:
* 0, 1, 2, 3, 0, 1, 2, 3, ....., 0, 1, 2, 3
* can be represented as a uint32 broadcast with n lanes (factor = 4). */
bool is_broadcast(int factor) const;
bool is_broadcast() const;
int broadcast_factor() const;

/** Check if this shuffle is a concatenation of the vector
* arguments. */
Expand Down
4 changes: 4 additions & 0 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,10 @@ void IRPrinter::visit(const Shuffle *op) {
<< ", " << op->slice_stride()
<< ", " << op->indices.size()
<< ")";
} else if (op->is_broadcast()) {
stream << "broadcast(";
print_list(op->vectors);
stream << ", " << op->broadcast_factor() << ")";
} else {
stream << "shuffle(";
print_list(op->vectors);
Expand Down

0 comments on commit 34d35a3

Please sign in to comment.