Skip to content

Commit

Permalink
add test_batch in Seq2Seq Attention
Browse files Browse the repository at this point in the history
  • Loading branch information
graykode committed Mar 29, 2019
1 parent 7981756 commit ab5865c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion 4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def get_att_score(self, dec_output, enc_output): # enc_outputs [batch_size, num
optimizer.step()

# Test
predict, trained_attn = model(input_batch, hidden, output_batch)
test_batch = [np.eye(n_class)[[word_dict[n] for n in 'SPPPP']]]
test_batch = Variable(torch.Tensor(test_batch))
predict, trained_attn = model(input_batch, hidden, test_batch)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])

Expand Down

0 comments on commit ab5865c

Please sign in to comment.