@@ -75,11 +75,11 @@ def forward(self, enc_input, enc_hidden, dec_input):
75
75
# output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_class]
76
76
# target_batch : [batch_size, max_len+1(=n_step, time step)], not one-hot
77
77
output = model (input_batch , hidden , output_batch )
78
- # output : [max_len+1, batch_size, num_directions(=1) * n_hidden ]
79
- output = output .transpose (0 , 1 ) # [batch_size, max_len+1(=6), num_directions(=1) * n_hidden ]
78
+ # output : [max_len+1, batch_size, n_class ]
79
+ output = output .transpose (0 , 1 ) # [batch_size, max_len+1(=6), n_class ]
80
80
loss = 0
81
81
for i in range (0 , len (target_batch )):
82
- # output[i] : [max_len+1, num_directions(=1) * n_hidden , target_batch[i] : max_len+1]
82
+ # output[i] : [max_len+1, n_class , target_batch[i] : max_len+1]
83
83
loss += criterion (output [i ], target_batch [i ])
84
84
if (epoch + 1 ) % 1000 == 0 :
85
85
print ('Epoch:' , '%04d' % (epoch + 1 ), 'cost =' , '{:.6f}' .format (loss ))
0 commit comments