Skip to content

Commit 3b3a80d

Browse files
committed
fixed wrong comment
1 parent 664bf08 commit 3b3a80d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

4-1.Seq2Seq/Seq2Seq-Torch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ def forward(self, enc_input, enc_hidden, dec_input):
7575
# output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_class]
7676
# target_batch : [batch_size, max_len+1(=n_step, time step)], not one-hot
7777
output = model(input_batch, hidden, output_batch)
78-
# output : [max_len+1, batch_size, num_directions(=1) * n_hidden]
79-
output = output.transpose(0, 1) # [batch_size, max_len+1(=6), num_directions(=1) * n_hidden]
78+
# output : [max_len+1, batch_size, n_class]
79+
output = output.transpose(0, 1) # [batch_size, max_len+1(=6), n_class]
8080
loss = 0
8181
for i in range(0, len(target_batch)):
82-
# output[i] : [max_len+1, num_directions(=1) * n_hidden, target_batch[i] : max_len+1]
82+
# output[i] : [max_len+1, n_class, target_batch[i] : max_len+1]
8383
loss += criterion(output[i], target_batch[i])
8484
if (epoch + 1) % 1000 == 0:
8585
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))

0 commit comments

Comments
 (0)