Skip to content

Commit

Permalink
net convergence at 140
Browse files Browse the repository at this point in the history
  • Loading branch information
xmlyqing00 committed Nov 29, 2018
1 parent 54c5972 commit 96454ff
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ add_executable(build-dataset ${EVALUATOR_PATH}/build_dataset.cpp)
target_link_libraries(build-dataset ${OpenCV_LIBRARIES})

add_executable(train-evaluator ${EVALUATOR_PATH}/train_evaluator.cpp)
target_link_libraries(train-evaluator ${TORCH_LIBRARIES} compatibility-net compatibility-dataset)
target_link_libraries(train-evaluator ${TORCH_LIBRARIES} compatibility-net compatibility-dataset utils)

# add-noise
add_executable(add-noise ${SRC_ROOT_PATH}/add_noise.cpp)
Expand Down
1 change: 1 addition & 0 deletions include/train_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>
#include <torch/torch.h>

#include <utils.h>
#include <compatibility_net.h>
#include <compatibility_dataset.h>

Expand Down
7 changes: 7 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#ifndef UTILS_H
#define UTILS_H

#include <chrono>
#include <ctime>
#include <string>
#include <opencv2/opencv.hpp>

using namespace std;

enum class PuzzleType {
STRIPES,
SQUARES
Expand All @@ -18,4 +23,6 @@ cv::Mat merge_imgs(const cv::Mat & in_img0, const cv::Mat & in_img1);

bool cross_seam(const cv::Rect & bbox, int seam_x);

void print_timestamp();

#endif
23 changes: 13 additions & 10 deletions src/evaluator/train_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ void test( CompatibilityNet & comp_net,
target = squeeze(target, /*dim*/1);

Tensor output = comp_net.forward(data);
cout << target[0] << endl;
cout << output[0] << endl;

test_loss += nll_loss(output, target, symbols_w, Reduction::Sum).template item<float>();
auto pred = output.argmax(1);
correct_n += pred.eq(target).sum().template item<int64_t>();
Expand All @@ -80,10 +79,10 @@ int main(int argc, char ** argv) {
int epochs = 500;
int batch_size = 128;
double lr = 1e-2;
double alpha = 0.9;
double momentum = 0.9;

// Parse command line parameters
const string opt_str = "e:b:l:a:";
const string opt_str = "e:b:l:m:";
int opt = getopt(argc, argv, opt_str.c_str());

while (opt != -1) {
Expand All @@ -97,8 +96,8 @@ int main(int argc, char ** argv) {
case 'l':
lr = atof(optarg);
break;
case 'a':
alpha = atof(optarg);
case 'm':
momentum = atof(optarg);
break;
}

Expand All @@ -108,7 +107,7 @@ int main(int argc, char ** argv) {
cout << "Total epochs: \t" << epochs << endl;
cout << "Batch size: \t" << batch_size << endl;
cout << "Learning rate: \t" << lr << endl;
cout << "Alpha: \t" << alpha << endl;
cout << "Momentum: \t" << momentum << endl;
cout << endl;

DeviceType device_type;
Expand Down Expand Up @@ -146,9 +145,13 @@ int main(int argc, char ** argv) {
dataloader_options
);

optim::RMSprop optimizer(
// optim::RMSprop optimizer(
// comp_net.parameters(),
// optim::RMSpropOptions(lr).alpha(alpha)
// );
optim::SGD optimizer(
comp_net.parameters(),
optim::RMSpropOptions(lr).alpha(alpha)
optim::SGDOptions(lr).momentum(momentum)
);

if (access(saved_model_folder.c_str(), 0) == -1) {
Expand All @@ -162,9 +165,9 @@ int main(int argc, char ** argv) {
symbols_w = symbols_w.to(device);

for (int epoch = 1; epoch <= epochs; epoch++) {
print_timestamp();
train(epoch, comp_net, *train_loader, optimizer, device);
test(comp_net, *test_loader, device);
cout << endl;
}

return 0;
Expand Down
9 changes: 9 additions & 0 deletions src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,13 @@ bool cross_seam(const cv::Rect & bbox, int seam_x) {
return false;
}

}

void print_timestamp() {

auto now = chrono::system_clock::now();
time_t cur_time = chrono::system_clock::to_time_t(now);

cout << endl << "Current timestamp: " << ctime(&cur_time) << endl;

}

0 comments on commit 96454ff

Please sign in to comment.