Skip to content

Commit

Permalink
openmp
Browse files Browse the repository at this point in the history
  • Loading branch information
xmlyqing00 committed Jan 17, 2019
1 parent 503772c commit 8bc4b8d
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 46 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ if (APPLE)
link_directories(/usr/local/lib)
endif()

# OpenMP
find_package(OpenMP)
if (OPENMP_FOUND)
# set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
endif()

# Set default variables
set(SRC_ROOT_PATH ${DOC_REASSEMBLY_SOURCE_DIR}/src)
set(CMAKE_CXX_STANDARD 11)
Expand Down
2 changes: 2 additions & 0 deletions include/path_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <iostream>
#include <vector>
#include <map>
#include <atomic>
#include <algorithm>

#include <stripe_pair.h>
Expand All @@ -18,6 +19,7 @@ class PathManager {
map< vector<int>, pair<int,int> > sol_paths; // sol; word_cnt, sol_cnt;

PathManager(int _vertices_n, int _sols_n);
~PathManager();

void add_sol_words( const map< vector<int>, int > & sol_words);
void print_sol_paths();
Expand Down
6 changes: 4 additions & 2 deletions include/stripes_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>
#include <deque>
#include <string>
#include <omp.h>
#include <map>
#include <random>
#include <tesseract/baseapi.h>
Expand Down Expand Up @@ -72,7 +73,8 @@ class StripesSolver {
const vector<int> * sol_x=nullptr);

private:


omp_lock_t lock;
Metric metric_mode;
Composition composition_mode;
bool real_flag;
Expand All @@ -84,7 +86,7 @@ class StripesSolver {

// Tesseract
const string tesseract_model_path {"data/tesseract_model/"};
tesseract::TessBaseAPI * ocr;
// tesseract::TessBaseAPI * ocr;
const double conf_thres {80};

// Compatibility
Expand Down
1 change: 0 additions & 1 deletion include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ enum class PuzzleType {
};

const double eps = 1e-8;
const int U_a = 1.2;
const cv::Scalar seam_color_red(100, 100, 200);
const cv::Scalar seam_color_green(100, 200, 100);

Expand Down
5 changes: 5 additions & 0 deletions src/solver/path_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
PathManager::PathManager(int _nodes_n, int _sols_n) :
nodes_n(_nodes_n),
sols_n(_sols_n) {

}

PathManager::~PathManager() {
}

void PathManager::add_sol_words(const map< vector<int>, int > & sol_words) {
Expand All @@ -11,6 +15,7 @@ void PathManager::add_sol_words(const map< vector<int>, int > & sol_words) {

const vector<int> sol_path = iter.first;
int word_cnt = iter.second;

if (sol_paths.find(sol_path) != sol_paths.end()) {
auto val = sol_paths[sol_path];
val.first += word_cnt;
Expand Down
151 changes: 108 additions & 43 deletions src/solver/stripes_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,28 @@ StripesSolver::StripesSolver(const string & _puzzle_folder, int _stripes_n, int
path_manager(_stripes_n, sols_n),
real_flag(_real_flag) {

ocr = new tesseract::TessBaseAPI();
if (ocr->Init(tesseract_model_path.c_str(), "eng", tesseract::OEM_TESSERACT_ONLY)) {
cerr << "Could not initialize tesseract." << endl;
exit(-1);
}

string white_chars = "";
for (int i = 0; i < 10; i++) white_chars += to_string(i);
for (int i = 0; i < 26; i++) white_chars += char(int('A') + i);
for (int i = 0; i < 26; i++) white_chars += char(int('a') + i);
bool x = ocr->SetVariable("tessedit_char_whitelist", white_chars.c_str());
string black_chars = ",<.>/?;:\'\"[{]}\\|";
ocr->SetVariable("tessedit_char_blacklist", black_chars.c_str());

ocr->SetVariable("language_model_penalty_non_freq_dict_word", "5");
ocr->SetVariable("language_model_penalty_non_dict_word", "1");

#ifdef DEBUG
FILE * file = fopen("tmp/variables.txt", "w");
ocr->PrintVariables(file);
fclose(file);
#endif
// ocr = new tesseract::TessBaseAPI();
// if (ocr->Init(tesseract_model_path.c_str(), "eng", tesseract::OEM_TESSERACT_ONLY)) {
// cerr << "Could not initialize tesseract." << endl;
// exit(-1);
// }

// string white_chars = "";
// for (int i = 0; i < 10; i++) white_chars += to_string(i);
// for (int i = 0; i < 26; i++) white_chars += char(int('A') + i);
// for (int i = 0; i < 26; i++) white_chars += char(int('a') + i);
// bool x = ocr->SetVariable("tessedit_char_whitelist", white_chars.c_str());
// string black_chars = ",<.>/?;:\'\"[{]}\\|";
// ocr->SetVariable("tessedit_char_blacklist", black_chars.c_str());

// ocr->SetVariable("language_model_penalty_non_freq_dict_word", "5");
// ocr->SetVariable("language_model_penalty_non_dict_word", "1");

// #ifdef DEBUG
// FILE * file = fopen("tmp/variables.txt", "w");
// ocr->PrintVariables(file);
// fclose(file);
// #endif

// Read ground truth order.
ifstream fin(puzzle_folder + "order.txt", ios::in);
Expand All @@ -53,10 +53,14 @@ StripesSolver::StripesSolver(const string & _puzzle_folder, int _stripes_n, int
stripes.push_back(move(stripe_img));
}

omp_init_lock(&lock);

}

StripesSolver::~StripesSolver() {
ocr->End();
// ocr->End();

omp_destroy_lock(&lock);
}

void StripesSolver::save_result(const string & case_name, bool benchmark_flag) {
Expand Down Expand Up @@ -254,7 +258,23 @@ double StripesSolver::m_metric_char(const cv::Mat & piece0, const cv::Mat & piec
int margin_piece1;
cv::Mat && merged_img = merge_imgs(piece0, piece1, real_flag, &seam_x, &margin_piece1);


tesseract::TessBaseAPI * ocr = new tesseract::TessBaseAPI();
if (ocr->Init(tesseract_model_path.c_str(), "eng", tesseract::OEM_TESSERACT_ONLY)) {
cerr << "Could not initialize tesseract." << endl;
exit(-1);
}

string white_chars = "";
for (int i = 0; i < 10; i++) white_chars += to_string(i);
for (int i = 0; i < 26; i++) white_chars += char(int('A') + i);
for (int i = 0; i < 26; i++) white_chars += char(int('a') + i);
bool x = ocr->SetVariable("tessedit_char_whitelist", white_chars.c_str());
string black_chars = ",<.>/?;:\'\"[{]}\\|";
ocr->SetVariable("tessedit_char_blacklist", black_chars.c_str());

ocr->SetVariable("language_model_penalty_non_freq_dict_word", "5");
ocr->SetVariable("language_model_penalty_non_dict_word", "1");

const int max_m_width = min(piece0.cols, piece1.cols);
const tesseract::PageIteratorLevel tesseract_level {tesseract::RIL_SYMBOL};

Expand Down Expand Up @@ -296,6 +316,8 @@ double StripesSolver::m_metric_char(const cv::Mat & piece0, const cv::Mat & piec
// cv::waitKey();
#endif

ocr->End();

return -m_metric_score;

}
Expand Down Expand Up @@ -353,7 +375,8 @@ void StripesSolver::m_metric_word() {
bool valid_flag = false;

for (int j = 0; j < stripes_n; j++) {
if (i == j || pixel_graph[i][j] < 0) continue;
if (i == j) continue;
// if (!real_flag && pixel_graph[i][j] < 0) continue;
score_max = max(score_max, pixel_graph[i][j]);
score_min = min(score_min, pixel_graph[i][j]);
valid_flag = true;
Expand All @@ -363,7 +386,8 @@ void StripesSolver::m_metric_word() {

double score_delta = score_max - score_min;
for (int j = 0; j < stripes_n; j++) {
if (i == j || pixel_graph[i][j] < 0) continue;
if (i == j) continue;
// if (!real_flag && pixel_graph[i][j] < 0) continue;
double score = (pixel_graph[i][j] - score_min) / score_delta;
s_l[i][j] = exp(-score);
}
Expand All @@ -378,7 +402,8 @@ void StripesSolver::m_metric_word() {
bool valid_flag = false;

for (int i = 0; i < stripes_n; i++) {
if (i == j || pixel_graph[i][j] < 0) continue;
if (i == j) continue;
// if (!real_flag && pixel_graph[i][j] < 0) continue;
score_max = max(score_max, pixel_graph[i][j]);
score_min = min(score_min, pixel_graph[i][j]);
valid_flag = true;
Expand All @@ -388,7 +413,8 @@ void StripesSolver::m_metric_word() {

double score_delta = score_max - score_min;
for (int i = 0; i < stripes_n; i++) {
if (i == j || pixel_graph[i][j] < 0) continue;
if (i == j) continue;
// if (!real_flag && pixel_graph[i][j] < 0) continue;
double score = (pixel_graph[i][j] - score_min) / score_delta;
s_r[i][j] = exp(-score);
}
Expand All @@ -404,6 +430,9 @@ void StripesSolver::m_metric_word() {
}

// Compute stripe_pairs
double U_a = 1.1;
if (real_flag) U_a = 1.1;

vector< vector<StripePair> > compose_next;
for (int i = 0; i < stripes_n; i++) {

Expand All @@ -426,7 +455,7 @@ void StripesSolver::m_metric_word() {
for (int j = 0; j < next_pairs.size(); j++) {
double alpha = U_a * (next_pairs[j].m_score / worst_score - 1);
double exp_alpha = exp(alpha);
next_pairs[j].ac_prob = (exp_alpha - 1) / (exp_alpha + 1);
next_pairs[j].ac_prob = max((double)0, (exp_alpha - 1)) / (exp_alpha + 1);
}

compose_next.push_back(move(next_pairs));
Expand Down Expand Up @@ -457,23 +486,28 @@ void StripesSolver::m_metric_word() {
}

for (int i = 0; i < stripes_n; i++) {
for (int j = 0; j < stripes_n; j++) {
if (composition_cnt[i][j] == 0) continue;
cout << i << " " << j << " " << composition_cnt[i][j] << endl;
for (int j = 0; j < stripes_n - 1; j++) {
if (gt_order[j] == i) {
cout << "Occurence cnt " << i << " " << gt_order[j+1] << " " << composition_cnt[i][gt_order[j+1]] << endl;
break;
}
}
}

int sol_idx = 0;

cout << "Detect words on solution: ";

vector<int> sol_x;
for (const auto & sol: candidate_sols) {

#pragma omp parallel for num_threads(20)
// for (const auto & sol: candidate_sols) {
for (int i = 0; i < candidate_sols.size(); i++) {

const auto & sol = candidate_sols[i];

++sol_idx;
if (sol_idx % 20 == 0) cout << sol_idx << " " << flush;

sol_x.clear();
vector<int> sol_x;
cv::Mat composition_img = compose_img(sol, real_flag, &sol_x);
cv::Mat tmp_img = word_detection(composition_img, sol, sol_x);

Expand All @@ -483,6 +517,8 @@ void StripesSolver::m_metric_word() {
cv::imwrite("tmp/sol_" + to_string(sol_idx) + ".png", tmp_img);
// cv::imshow("Tmp img", tmp_img);
// cv::waitKey();
#else
if (sol_idx % 20 == 0) cout << sol_idx << " " << flush;
#endif
}
cout << endl;
Expand Down Expand Up @@ -564,14 +600,15 @@ void StripesSolver::m_metric() {

pixel_graph = vector< vector<double> >(stripes_n, vector<double>(stripes_n, 0));
double m_score_p, m_score_c;

#pragma omp parallel for num_threads(20)
for (int i = 0; i < stripes_n; i++) {

#pragma omp parallel for num_threads(20)
for (int j = 0; j < stripes_n; j++) {

if (i == j) continue;
#ifdef DEBUG
cout << "Init " << i << " " << j << " ";
#endif

double m_score = 0;
switch (metric_mode) {
case Metric::PIXEL:
Expand All @@ -581,20 +618,26 @@ void StripesSolver::m_metric() {
m_score = m_metric_char(stripes[i], stripes[j]);
case Metric::WORD:
m_score_p = m_metric_pixel(stripes[i], stripes[j], real_flag);
m_score_c = m_metric_char(stripes[i], stripes[j]);
m_score = m_score_p * 2 + m_score_c;
// if (real_flag) {
m_score_c = m_metric_char(stripes[i], stripes[j]);
m_score = m_score_p * 2 + m_score_c;
// } else {
// m_score = m_score_p;
// }
#ifdef DEBUG
cout << m_score << endl;
cout << "Init " << i << " " << j << " " << m_score << endl;
#endif
break;
default:
break;
}

omp_set_lock(&lock);
pixel_graph[i][j] = m_score;
if (m_score > -eps) {
if (m_score_p > -eps) {
stripe_pairs.push_back(StripePair(i, j, m_score, 1.0, true));
}
omp_unset_lock(&lock);

}

Expand All @@ -609,6 +652,7 @@ void StripesSolver::m_metric() {
#endif

if (metric_mode == Metric::WORD) {
cout << "[INFO] Calculate word metric." << endl;
m_metric_word();
};

Expand Down Expand Up @@ -776,6 +820,23 @@ cv::Mat StripesSolver::word_detection( const cv::Mat & img,
const tesseract::PageIteratorLevel tesseract_level {tesseract::RIL_WORD};
const cv::Scalar color_blue(200, 0, 0);

tesseract::TessBaseAPI * ocr = new tesseract::TessBaseAPI();
if (ocr->Init(tesseract_model_path.c_str(), "eng", tesseract::OEM_TESSERACT_ONLY)) {
cerr << "Could not initialize tesseract." << endl;
exit(-1);
}

string white_chars = "";
for (int i = 0; i < 10; i++) white_chars += to_string(i);
for (int i = 0; i < 26; i++) white_chars += char(int('A') + i);
for (int i = 0; i < 26; i++) white_chars += char(int('a') + i);
bool x = ocr->SetVariable("tessedit_char_whitelist", white_chars.c_str());
string black_chars = ",<.>/?;:\'\"[{]}\\|";
ocr->SetVariable("tessedit_char_blacklist", black_chars.c_str());

ocr->SetVariable("language_model_penalty_non_freq_dict_word", "5");
ocr->SetVariable("language_model_penalty_non_dict_word", "1");

ocr->SetImage(img.data, img.cols, img.rows, 3, img.step);
ocr->Recognize(0);

Expand Down Expand Up @@ -825,7 +886,11 @@ cv::Mat StripesSolver::word_detection( const cv::Mat & img,
} while (ocr_iter->Next(tesseract_level));
}

omp_set_lock(&lock);
path_manager.add_sol_words(sol_words);
omp_unset_lock(&lock);

ocr->End();

return img_bbox;

Expand Down

0 comments on commit 8bc4b8d

Please sign in to comment.