From ab5865c77abb2b029c5e4a3d7c180d0a1eed63b4 Mon Sep 17 00:00:00 2001 From: graykode Date: Fri, 29 Mar 2019 19:32:56 +0900 Subject: [PATCH] add test_batch in Seq2Seq Attention --- 4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py b/4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py index 54510ba..2bbe1af 100644 --- a/4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py +++ b/4-2.Seq2Seq(Attention)/Seq2Seq(Attention)-Torch.py @@ -105,7 +105,9 @@ def get_att_score(self, dec_output, enc_output): # enc_outputs [batch_size, num optimizer.step() # Test -predict, trained_attn = model(input_batch, hidden, output_batch) +test_batch = [np.eye(n_class)[[word_dict[n] for n in 'SPPPP']]] +test_batch = Variable(torch.Tensor(test_batch)) +predict, trained_attn = model(input_batch, hidden, test_batch) predict = predict.data.max(1, keepdim=True)[1] print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])