Skip to content

Commit

Permalink
add nll weight
Browse files Browse the repository at this point in the history
  • Loading branch information
yq committed Nov 29, 2018
1 parent 360f079 commit c5b75ec
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/train_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ using namespace std;

const int symbols_n = 63;
const string saved_model_folder = "data/saved_models/";
Tensor symbols_w = torch::empty({symbols_n + 1}, kFloat32);

#endif
2 changes: 1 addition & 1 deletion src/evaluator/compatibility_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Tensor CompatibilityNet::forward(Tensor x) {
x = avg_pool2d(x, 4);
x = x.view({-1, 256});
x = relu(fc1->forward(x));

x = fc2->forward(x);
x = log_softmax(x, 1);

return x;
Expand Down
10 changes: 8 additions & 2 deletions src/evaluator/train_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void train( int epoch,

Tensor output = comp_net.forward(data);

Tensor loss = nll_loss(output, target);
Tensor loss = nll_loss(output, target, symbols_w);
optimizer.zero_grad();
loss.backward();
optimizer.step();
Expand Down Expand Up @@ -59,7 +59,7 @@ void test( CompatibilityNet & comp_net,

Tensor output = comp_net.forward(data);

test_loss += nll_loss(output, target, /*weight=*/{}, Reduction::Sum).template item<float>();
test_loss += nll_loss(output, target, symbols_w, Reduction::Sum).template item<float>();
auto pred = output.argmax(1);
correct_n += pred.eq(target).sum().template item<int64_t>();
total_n += batch.data.size(0);
Expand Down Expand Up @@ -154,9 +154,15 @@ int main(int argc, char ** argv) {
mkdir(saved_model_folder.c_str(), S_IRWXU|S_IRWXG|S_IROTH|S_IXOTH);
}

for (int i = 0; i < symbols_n; i++) {
symbols_w[i] = 1;
}
symbols_w[symbols_n] = 1.0 / symbols_n / (symbols_n - 1);

for (int epoch = 1; epoch <= epochs; epoch++) {
train(epoch, comp_net, *train_loader, optimizer, device);
test(comp_net, *test_loader, device);
cout << endl;
}

return 0;
Expand Down

0 comments on commit c5b75ec

Please sign in to comment.