Skip to content

Commit

Permalink
batching if huge number of descriptors
Browse files Browse the repository at this point in the history
  • Loading branch information
ducha-aiki committed Jul 11, 2019
1 parent 338f195 commit c30204a
Showing 1 changed file with 73 additions and 39 deletions.
112 changes: 73 additions & 39 deletions imagerepresentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
#ifdef _OPENMP
#include <omp.h>
#endif

#include <algorithm>
#include <iterator>

#define VERBOSE 1
#include <zmq.hpp>


inline static bool endsWith(const std::string& str, const std::string& suffix)
{
return str.size() >= suffix.size() && 0 == str.compare(str.size()-suffix.size(), suffix.size(), suffix);
Expand All @@ -20,51 +22,83 @@ std::vector<std::vector<float> > DescribeWithZmq(zmqDescriptorParams par,
AffineRegionVector &kps,
SynthImage &temp_img1){

int kp_size = kps.size() ;
std::vector<std::vector<float> > out;
cv::Mat patches;
int odd_patch_size = par.patchSize ;
if (kps.size() > 0){
ExtractPatchesColumn(kps,temp_img1, patches,
par.mrSize,
odd_patch_size,
false,
false,
true);
}
int batch_size = 2000;
std::cout << kp_size << std::endl;
if (kp_size <= batch_size){



std::vector<uchar> bufff1;
cv::imencode(".png",patches,bufff1);
cv::Mat patches;
int odd_patch_size = par.patchSize ;
if (kps.size() > 0){
ExtractPatchesColumn(kps,temp_img1, patches,
par.mrSize,
odd_patch_size,
false,
false,
true);

zmq::context_t context (1);
int socket_mode = ZMQ_REQ;

zmq::socket_t socket(context, socket_mode);
socket.connect(par.port);
std::vector<uchar> bufff1;
cv::imencode(".png",patches,bufff1);

zmq::message_t mods_to_cnn(bufff1.size()) ;
memcpy ((void *) mods_to_cnn.data (), bufff1.data(), bufff1.size());
zmq::message_t cnn_to_mods;
zmq::context_t context (1);
int socket_mode = ZMQ_REQ;

zmq::socket_t socket(context, socket_mode);
socket.connect(par.port);

zmq::message_t mods_to_cnn(bufff1.size()) ;
memcpy ((void *) mods_to_cnn.data (), bufff1.data(), bufff1.size());
zmq::message_t cnn_to_mods;
#pragma omp critical
{
socket.send(mods_to_cnn);
// Get the reply.
socket.recv(&cnn_to_mods);
}
std::vector<float> inMsg(cnn_to_mods.size() / sizeof(float));
std::memcpy(inMsg.data(), cnn_to_mods.data(), cnn_to_mods.size());
const int desc_size = inMsg.size() /kps.size() ;

for (int img_num=0; img_num<kps.size(); img_num++)
{
std::vector<float> curr_desc(desc_size);
int offset = img_num*desc_size;
for (int i = 0; i < desc_size; ++i) {
const float v1 = inMsg[i+offset];
curr_desc[i] = v1;
{
socket.send(mods_to_cnn);
// Get the reply.
socket.recv(&cnn_to_mods);
}
std::vector<float> inMsg(cnn_to_mods.size() / sizeof(float));
std::memcpy(inMsg.data(), cnn_to_mods.data(), cnn_to_mods.size());
const int desc_size = inMsg.size() /kps.size() ;

for (int img_num=0; img_num<kps.size(); img_num++)
{
std::vector<float> curr_desc(desc_size);
int offset = img_num*desc_size;
for (int i = 0; i < desc_size; ++i) {
const float v1 = inMsg[i+offset];
curr_desc[i] = v1;
}
out.push_back(curr_desc);
}
socket.close();
}
} else {
int num_batches = (kp_size / batch_size) + 1;
AffineRegionVector::const_iterator startIter(kps.cbegin());
AffineRegionVector::const_iterator endIter(kps.cbegin());
int num_done = 0;
int to_add = 0;
for (int bi = 0; bi < num_batches; bi++) {
num_done += batch_size;
if (num_done > kp_size) {
num_done -= batch_size;
to_add = kp_size - num_done;
} else {
to_add = batch_size;
}
std::advance(endIter, to_add);

AffineRegionVector current_kps(startIter, endIter);
std::vector<std::vector<float> > current_out = DescribeWithZmq(par,
current_kps,
temp_img1);
std::move(current_out.begin(), current_out.end(), std::back_inserter(out));
std::advance(startIter, to_add);
}
out.push_back(curr_desc);
}
socket.close();
return out;
}

Expand Down Expand Up @@ -1318,7 +1352,7 @@ void ImageRepresentation::LoadRegions(std::string fname) {
}
kpfile.close();
}
AffineRegionVector ImageRepresentation::PreLoadRegionsNPZ(std::string fname) {
AffineRegionVector ImageRepresentation::PreLoadRegionsNPZ(std::string fname) {
cnpy::npz_t my_npz = cnpy::npz_load(fname);
std::vector<std::string> keys;
bool A_is_here = false;
Expand Down

0 comments on commit c30204a

Please sign in to comment.