Merge pull request halide#1055 from halide/thread_safety

Attempt to make the compiler thread-safe.

abadams committed Mar 1, 2016
2 parents b6748c8 + 871974d commit 65bbac2

Showing 14 changed files with 226 additions and 144 deletions.
4 changes: 4 additions & 0 deletions src/CodeGen_LLVM.cpp
@@ -1,6 +1,7 @@
 #include <iostream>
 #include <limits>
 #include <sstream>
+#include <mutex>
 
 #include "IRPrinter.h"
 #include "CodeGen_LLVM.h"
@@ -310,6 +311,9 @@ CodeGen_LLVM *CodeGen_LLVM::new_for_target(const Target &target,
 }
 
 void CodeGen_LLVM::initialize_llvm() {
+    static std::mutex initialize_llvm_mutex;
+    std::lock_guard<std::mutex> lock(initialize_llvm_mutex);
+
     // Initialize the targets we want to generate code for which are enabled
     // in llvm configuration
     if (!llvm_initialized) {
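The hunk above serializes LLVM's one-time global initialization behind a function-local mutex. A minimal sketch of the same pattern, using illustrative names rather than Halide's, and assuming C++11 (where construction of the local static mutex is itself thread-safe):

#include <mutex>

// One-time initialization guarded by a function-local mutex. Concurrent
// callers serialize on the lock, so the initialized-flag check cannot race.
void initialize_backend() {
    static std::mutex init_mutex;
    static bool initialized = false;

    std::lock_guard<std::mutex> lock(init_mutex);
    if (!initialized) {
        initialized = true;
        // ... register targets and passes exactly once ...
    }
}

std::call_once with a std::once_flag would express the same intent; the lock-plus-flag form has the advantage of leaving the existing llvm_initialized check untouched.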
39 changes: 20 additions & 19 deletions src/Deinterleave.cpp
@@ -19,11 +19,12 @@ class StoreCollector : public IRMutator {
 public:
     const std::string store_name;
     const int store_stride, max_stores;
-    std::vector<LetStmt>& let_stmts;
-    std::vector<Store>& stores;
+    std::vector<Stmt> &let_stmts;
+    std::vector<Stmt> &stores;
 
-    StoreCollector(const std::string& name, int stride, int ms,
-                   std::vector<LetStmt>& lets, std::vector<Store>& ss) :
+    StoreCollector(const std::string &name, int stride, int ms,
+                   std::vector<Stmt> &lets,
+                   std::vector<Stmt> &ss) :
         store_name(name), store_stride(stride), max_stores(ms),
         let_stmts(lets), stores(ss), collecting(true) {}
 private:
@@ -41,7 +42,7 @@ class StoreCollector : public IRMutator {
     // These are lets that we've encountered since the last collected
     // store. If we collect another store, these "potential" lets
     // become lets used by the collected stores.
-    std::vector<LetStmt> potential_lets;
+    std::vector<Stmt> potential_lets;
 
     void visit(const Load *op) {
         if (!collecting) {
@@ -97,7 +98,7 @@ class StoreCollector : public IRMutator {
         }
 
         // This store is good, collect it and replace with a no-op.
-        stores.push_back(*op);
+        stores.push_back(op);
         stmt = Evaluate::make(0);
 
         // Because we collected this store, we need to save the
@@ -115,7 +116,7 @@
 
         // If we're still collecting, we need to save the let as a potential let.
         if (collecting) {
-            potential_lets.push_back(*op);
+            potential_lets.push_back(op);
         }
     }
 
@@ -136,7 +137,7 @@
 };
 
 Stmt collect_strided_stores(Stmt stmt, const std::string& name, int stride, int max_stores,
-                            std::vector<LetStmt> lets, std::vector<Store>& stores) {
+                            std::vector<Stmt> lets, std::vector<Stmt> &stores) {
 
     StoreCollector collect(name, stride, max_stores, lets, stores);
     return collect.mutate(stmt);
@@ -535,9 +536,9 @@ class Interleaver : public IRMutator {
         if (!op->rest.defined()) goto fail;
 
         // Gather all the let stmts surrounding the first.
-        std::vector<LetStmt> let_stmts;
+        std::vector<Stmt> let_stmts;
         while (let) {
-            let_stmts.push_back(*let);
+            let_stmts.push_back(let);
             store = let->body.as<Store>();
             let = let->body.as<LetStmt>();
         }
@@ -560,8 +561,8 @@
         const int64_t expected_stores = stride == 1 ? lanes : stride;
 
         // Collect the rest of the stores.
-        std::vector<Store> stores;
-        stores.push_back(*store);
+        std::vector<Stmt> stores;
+        stores.push_back(store);
         Stmt rest = collect_strided_stores(op->rest, store->name,
                                            stride, expected_stores,
                                            let_stmts, stores);
@@ -584,7 +585,7 @@
         Buffer load_image;
         Parameter load_param;
         for (size_t i = 0; i < stores.size(); ++i) {
-            const Ramp *ri = stores[i].index.as<Ramp>();
+            const Ramp *ri = stores[i].as<Store>()->index.as<Ramp>();
             internal_assert(ri);
 
             // Mismatched store vector laness.
@@ -607,7 +608,7 @@
 
             // This case only triggers if we have an immediate load of the correct stride on the RHS.
             // TODO: Could we consider mutating the RHS so that we can handle more complex Expr's than just loads?
-            const Load *load = stores[i].value.as<Load>();
+            const Load *load = stores[i].as<Store>()->value.as<Load>();
             if (!load) goto fail;
 
             const Ramp *ramp = load->index.as<Ramp>();
@@ -634,7 +635,7 @@
             }
 
             if (j == 0) {
-                base = stores[i].index.as<Ramp>()->base;
+                base = stores[i].as<Store>()->index.as<Ramp>()->base;
             }
 
             // The offset is not between zero and the stride.
@@ -644,9 +645,9 @@
             if (args[j].defined()) goto fail;
 
             if (stride == 1) {
-                args[j] = Load::make(t, load_name, stores[i].index, load_image, load_param);
+                args[j] = Load::make(t, load_name, stores[i].as<Store>()->index, load_image, load_param);
             } else {
-                args[j] = stores[i].value;
+                args[j] = stores[i].as<Store>()->value;
             }
         }
 
@@ -665,8 +666,8 @@
 
         // Rewrap the let statements we pulled off.
         while (!let_stmts.empty()) {
-            LetStmt let = let_stmts.back();
-            stmt = LetStmt::make(let.name, let.value, stmt);
+            const LetStmt *let = let_stmts.back().as<LetStmt>();
+            stmt = LetStmt::make(let->name, let->value, stmt);
             let_stmts.pop_back();
         }
 
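The Deinterleave changes stop copying IR nodes by value (std::vector<LetStmt>, std::vector<Store>) and instead collect generic reference-counted Stmt handles, downcasting with as<Store>() at each use. A hypothetical sketch of that container idiom, with simplified stand-ins rather than Halide's types:

#include <cassert>
#include <vector>

struct Node { virtual ~Node() {} };
struct Store : Node { int index = 0; };

// Stand-in for Halide's Stmt handle: a const pointer plus a checked
// downcast, mirroring stores[i].as<Store>()->index in the diff.
struct Handle {
    const Node *ptr;
    explicit Handle(const Node *p) : ptr(p) {}
    template<typename T>
    const T *as() const { return dynamic_cast<const T *>(ptr); }
};

int main() {
    std::vector<Handle> stores;    // was a vector of Store copies
    Store s;
    stores.push_back(Handle(&s));  // push_back(op): share the node, don't copy it
    const Store *st = stores[0].as<Store>();
    assert(st && st->index == 0);  // recover the concrete type at the point of use
    return 0;
}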
14 changes: 7 additions & 7 deletions src/Expr.h
@@ -127,14 +127,14 @@ struct IRHandle : public IntrusivePtr<const IRNode> {
 struct IntImm : public ExprNode<IntImm> {
     int64_t value;
 
-    static IntImm *make(Type t, int64_t value) {
+    static const IntImm *make(Type t, int64_t value) {
         internal_assert(t.is_int() && t.is_scalar()) << "IntImm must be a scalar Int\n";
         internal_assert(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
             << "IntImm must be 8, 16, 32, or 64-bit\n";
 
         if (t.bits() == 32 && value >= -8 && value <= 8 &&
-            !small_int_cache[(int)value + 8].ref_count.is_zero()) {
-            return &small_int_cache[(int)value + 8];
+            small_int_cache[(int)value + 8]) {
+            return small_int_cache[(int)value + 8];
         }
 
         IntImm *node = new IntImm;
@@ -149,14 +149,14 @@ struct IntImm : public ExprNode<IntImm> {
 
 private:
     /** ints from -8 to 8 */
-    EXPORT static IntImm small_int_cache[17];
+    EXPORT static const IntImm *small_int_cache[17];
 };
 
 /** Unsigned integer constants */
 struct UIntImm : public ExprNode<UIntImm> {
     uint64_t value;
 
-    static UIntImm *make(Type t, uint64_t value) {
+    static const UIntImm *make(Type t, uint64_t value) {
         internal_assert(t.is_uint() && t.is_scalar())
             << "UIntImm must be a scalar UInt\n";
         internal_assert(t.bits() == 1 || t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
@@ -175,7 +175,7 @@ struct UIntImm : public ExprNode<UIntImm> {
 struct FloatImm : public ExprNode<FloatImm> {
     double value;
 
-    static FloatImm *make(Type t, double value) {
+    static const FloatImm *make(Type t, double value) {
         internal_assert(t.is_float() && t.is_scalar()) << "FloatImm must be a scalar Float\n";
         FloatImm *node = new FloatImm;
         node->type = t;
@@ -201,7 +201,7 @@ struct FloatImm : public ExprNode<FloatImm> {
 struct StringImm : public ExprNode<StringImm> {
     std::string value;
 
-    static StringImm *make(const std::string &val) {
+    static const StringImm *make(const std::string &val) {
         StringImm *node = new StringImm;
         node->type = Handle();
         node->value = val;
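The Expr.h factory methods now hand out const pointers, and the small-integer cache stores const IntImm pointers rather than mutable node objects; a null entry, not a reference-count probe, marks an unfilled slot. A minimal sketch of the idiom with illustrative types, not Halide's:

struct IntNode {
    int bits;
    long long value;
};

// Cache for 32-bit ints in [-8, 8], assumed to be filled once at startup.
static const IntNode *small_int_cache[17];

const IntNode *make_int32(long long value) {
    if (value >= -8 && value <= 8 && small_int_cache[(int)value + 8]) {
        return small_int_cache[(int)value + 8];  // shared immutable node
    }
    return new IntNode{32, value};  // fresh heap node for everything else
}

Because the cached nodes are immutable after creation, any number of threads can share them without locks.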
124 changes: 75 additions & 49 deletions src/FastIntegerDivide.cpp
@@ -1,3 +1,5 @@
+#include <mutex>
+
 #include "FastIntegerDivide.h"
 #include "IntegerDivisionTable.h"
 
@@ -6,82 +8,107 @@ namespace Halide {
 using namespace Halide::Internal::IntegerDivision;
 
 namespace IntegerDivideTable {
+
 Image<uint8_t> integer_divide_table_u8() {
-    static Image<uint8_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_u8[i][2];
-            im(i, 1) = table_runtime_u8[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint8_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_u8[i][2];
+                im(i, 1) = table_runtime_u8[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 
 Image<uint8_t> integer_divide_table_s8() {
-    static Image<uint8_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_s8[i][2];
-            im(i, 1) = table_runtime_s8[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint8_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_s8[i][2];
+                im(i, 1) = table_runtime_s8[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 
 Image<uint16_t> integer_divide_table_u16() {
-    static Image<uint16_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_u16[i][2];
-            im(i, 1) = table_runtime_u16[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint16_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_u16[i][2];
+                im(i, 1) = table_runtime_u16[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 
 Image<uint16_t> integer_divide_table_s16() {
-    static Image<uint16_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_s16[i][2];
-            im(i, 1) = table_runtime_s16[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint16_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_s16[i][2];
+                im(i, 1) = table_runtime_s16[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 
 Image<uint32_t> integer_divide_table_u32() {
-    static Image<uint32_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_u32[i][2];
-            im(i, 1) = table_runtime_u32[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint32_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_u32[i][2];
+                im(i, 1) = table_runtime_u32[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 
 Image<uint32_t> integer_divide_table_s32() {
-    static Image<uint32_t> im(256, 2);
-    static bool initialized = false;
-    if (!initialized) {
-        initialized = true;
-        for (size_t i = 0; i < 256; i++) {
-            im(i, 0) = table_runtime_s32[i][2];
-            im(i, 1) = table_runtime_s32[i][3];
+    static std::mutex initialize_lock;
+    std::lock_guard<std::mutex> lock_guard(initialize_lock);
+    {
+        static Image<uint32_t> im(256, 2);
+        static bool initialized = false;
+        if (!initialized) {
+            initialized = true;
+            for (size_t i = 0; i < 256; i++) {
+                im(i, 0) = table_runtime_s32[i][2];
+                im(i, 1) = table_runtime_s32[i][3];
+            }
         }
+        return im;
     }
-    return im;
 }
 }
 
@@ -200,4 +227,3 @@ Expr fast_integer_divide(Expr numerator, Expr denominator) {
 
 }
 }
-
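Each lookup-table accessor above gains a function-local mutex so that the first caller populates the static Image while any concurrent caller waits. A sketch of the pattern with illustrative contents, not the real division tables:

#include <mutex>
#include <vector>

// Lazy, lock-guarded table initialization. The mutex serializes the first
// call that fills the table; later calls take the lock briefly and return
// the already-built data.
const std::vector<int> &lazy_table() {
    static std::mutex initialize_lock;
    std::lock_guard<std::mutex> lock(initialize_lock);

    static std::vector<int> table;
    static bool initialized = false;
    if (!initialized) {
        initialized = true;
        for (int i = 0; i < 256; i++) {
            table.push_back(i * i);  // stand-in for table_runtime_* data
        }
    }
    return table;
}

Under C++11 magic statics, a local static initialized directly in its declaration would already be constructed thread-safely without an explicit lock; the hand-rolled mutex-plus-flag form presumably keeps the code portable to toolchains that lacked that guarantee at the time.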
7 changes: 4 additions & 3 deletions src/Function.cpp
@@ -1,5 +1,6 @@
 #include <set>
 #include <stdlib.h>
+#include <atomic>
 
 #include "IR.h"
 #include "Function.h"
@@ -208,7 +209,7 @@ class FreezeFunctions : public IRGraphVisitor {
 
 // A counter to use in tagging random variables
 namespace {
-static int rand_counter = 0;
+static std::atomic<int> rand_counter;
 }
 
 Function::Function() : contents(new FunctionContents) {
@@ -461,8 +462,8 @@ void Function::define_update(const vector<Expr> &_args, vector<Expr> values) {
     }
 
     for (int i = 0; i < counter.count; i++) {
-        contents.ptr->ref_count.decrement();
-        internal_assert(!contents.ptr->ref_count.is_zero());
+        int count = contents.ptr->ref_count.decrement();
+        internal_assert(count != 0);
     }
 
     // First add any reduction domain
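Two small races are closed here: the global rand_counter becomes a std::atomic<int>, and the ref-count check no longer re-reads the counter after decrementing it, instead asserting on the value the decrement itself returns. A simplified sketch of both fixes, not Halide's types:

#include <atomic>
#include <cassert>

// A plain int counter becomes atomic so concurrent increments cannot race.
static std::atomic<int> rand_counter(0);

int next_tag() {
    return rand_counter++;            // atomic fetch-and-increment
}

int main() {
    std::atomic<int> ref_count(2);
    int count = --ref_count;          // one atomic op, returning the new value
    assert(count != 0);               // test the snapshot, not the live counter,
                                      // which another thread may change meanwhile
    return 0;
}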