
Commit 7704dec

add dmmiller612 greedy decoder for colab

1 parent 2eb3317 commit 7704dec

3 files changed: +371 −42 lines changed


5-1.Transformer/Transformer(Greedy_decoder)-Torch.py: +24 −41
@@ -1,5 +1,5 @@
 '''
-  code by Tae Hwan Jung(Jeff Jung) @graykode
+  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612
   Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
               https://github.com/JayParks/transformer
 '''
@@ -32,14 +32,12 @@
 n_layers = 6 # number of Encoder of Decoder Layer
 n_heads = 8 # number of heads in Multi-Head Attention
 
-
 def make_batch(sentences):
     input_batch = [[src_vocab[n] for n in sentences[0].split()]]
     output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]
     target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]
     return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))
 
-
 def get_sinusoid_encoding_table(n_position, d_model):
     def cal_angle(position, hid_idx):
         return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
@@ -51,14 +49,12 @@ def get_posi_angle_vec(position):
     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
     return torch.FloatTensor(sinusoid_table)
 
-
 def get_attn_pad_mask(seq_q, seq_k):
     batch_size, len_q = seq_q.size()
     batch_size, len_k = seq_k.size()
     pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q)
     return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
 
-
 class ScaledDotProductAttention(nn.Module):
 
     def __init__(self):
@@ -72,7 +68,6 @@ def forward(self, Q, K, V, attn_mask=None):
         context = torch.matmul(attn, V)
         return context, attn
 
-
 class MultiHeadAttention(nn.Module):
 
     def __init__(self):
@@ -97,7 +92,6 @@ def forward(self, Q, K, V, attn_mask=None):
         output = nn.Linear(n_heads * d_v, d_model)(context)
         return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]
 
-
 class PoswiseFeedForwardNet(nn.Module):
 
     def __init__(self):
@@ -111,7 +105,6 @@ def forward(self, inputs):
         output = self.conv2(output).transpose(1, 2)
         return nn.LayerNorm(d_model)(output + residual)
 
-
 class EncoderLayer(nn.Module):
 
     def __init__(self):
@@ -124,7 +117,6 @@ def forward(self, enc_inputs):
         enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
         return enc_outputs, attn
 
-
 class DecoderLayer(nn.Module):
 
     def __init__(self):
@@ -139,7 +131,6 @@ def forward(self, dec_inputs, enc_outputs, enc_attn_mask, dec_attn_mask=None):
         dec_outputs = self.pos_ffn(dec_outputs)
         return dec_outputs, dec_self_attn, dec_enc_attn
 
-
 class Encoder(nn.Module):
 
     def __init__(self):
@@ -156,7 +147,6 @@ def forward(self, enc_inputs): # enc_inputs : [batch_size x source_len]
             enc_self_attns.append(enc_self_attn)
         return enc_outputs, enc_self_attns
 
-
 class Decoder(nn.Module):
 
     def __init__(self):
@@ -178,7 +168,6 @@ def forward(self, dec_inputs, enc_inputs, enc_outputs, dec_attn_mask=None): # de
             dec_enc_attns.append(dec_enc_attn)
         return dec_outputs, dec_self_attns, dec_enc_attns
 
-
 class Transformer(nn.Module):
 
     def __init__(self):
@@ -193,7 +182,6 @@ def forward(self, enc_inputs, dec_inputs, decoder_mask=None):
         dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size]
         return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
 
-
 def greedy_decoder(model, enc_input, start_symbol):
     """
     For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
@@ -218,7 +206,6 @@ def greedy_decoder(model, enc_input, start_symbol):
         next_symbol = next_word[0]
     return dec_input
 
-
 def showgraph(attn):
     attn = attn[-1].squeeze(0)[0]
     attn = attn.squeeze(0).data.numpy()
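The truncated docstring above captures the idea of the added helper: greedy decoding is beam search with K=1, building the target one token at a time because the full target sequence is unknown at inference. The sketch below is only an illustrative version of such a loop, not the helper added in this commit; the step function, the toy start/end symbols, and max_len are assumptions made for the example.

import torch

def greedy_decode_sketch(step_fn, start_symbol, end_symbol, max_len=10):
    """Greedy decoding = beam search with K = 1.
    step_fn(prefix) -> logits over the target vocabulary for the next token,
    where prefix is a 1 x t LongTensor of the tokens generated so far."""
    prefix = torch.LongTensor([[start_symbol]])                      # begin with the start symbol
    for _ in range(max_len):
        logits = step_fn(prefix)                                     # 1 x vocab_size
        next_token = int(logits.argmax(dim=-1))                      # keep only the single best token
        prefix = torch.cat([prefix, torch.LongTensor([[next_token]])], dim=1)
        if next_token == end_symbol:                                 # stop once the end symbol appears
            break
    return prefix

# Toy usage: a fake step function that emits token 3 twice, then the end symbol 6.
def fake_step(prefix):
    logits = torch.zeros(1, 7)
    logits[0, 3 if prefix.size(1) < 3 else 6] = 1.0
    return logits

print(greedy_decode_sketch(fake_step, start_symbol=5, end_symbol=6))  # tensor([[5, 3, 3, 6]])

In the commit's helper the role of the step function is presumably played by the Transformer model itself, called on the encoder input together with the partially built decoder input, as the test block in the last hunk suggests.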
@@ -229,36 +216,32 @@ def showgraph(attn):
     ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14})
     plt.show()
 
+model = Transformer()
 
-if __name__ == '__main__':
-
-    model = Transformer()
-
-    criterion = nn.CrossEntropyLoss()
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-    for epoch in range(100):
-        optimizer.zero_grad()
-        enc_inputs, dec_inputs, target_batch = make_batch(sentences)
-        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
-        loss = criterion(outputs, target_batch.contiguous().view(-1))
-        if (epoch + 1) % 20 == 0:
-            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
-        loss.backward()
-        optimizer.step()
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.Adam(model.parameters(), lr=0.001)
 
+for epoch in range(100):
+    optimizer.zero_grad()
+    enc_inputs, dec_inputs, target_batch = make_batch(sentences)
+    outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
+    loss = criterion(outputs, target_batch.contiguous().view(-1))
+    if (epoch + 1) % 20 == 0:
+        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
+    loss.backward()
+    optimizer.step()
 
-    # Test
-    greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
-    predict, _, _, _ = model(enc_inputs, greedy_dec_input)
-    predict = predict.data.max(1, keepdim=True)[1]
-    print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
+# Test
+greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
+predict, _, _, _ = model(enc_inputs, greedy_dec_input)
+predict = predict.data.max(1, keepdim=True)[1]
+print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
 
-    print('first head of last state enc_self_attns')
-    showgraph(enc_self_attns)
+print('first head of last state enc_self_attns')
+showgraph(enc_self_attns)
 
-    print('first head of last state dec_self_attns')
-    showgraph(dec_self_attns)
+print('first head of last state dec_self_attns')
+showgraph(dec_self_attns)
 
-    print('first head of last state dec_enc_attns')
-    showgraph(dec_enc_attns)
+print('first head of last state dec_enc_attns')
+showgraph(dec_enc_attns)
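As a side note on the test block above: Transformer.forward returns logits flattened to two dimensions with the target-vocabulary size last (the view(-1, dec_logits.size(-1)) call shown earlier in this diff), so predict.data.max(1, keepdim=True)[1] picks the arg-max word index for each target position, which is then looked up in number_dict. A tiny standalone illustration of that indexing, with made-up logits rather than the model's real output:

import torch

# Made-up logits: 5 target positions over a 7-word target vocabulary, mimicking
# the flattened two-dimensional output of the model's forward pass.
logits = torch.randn(5, 7)
predict = logits.max(1, keepdim=True)[1]   # arg-max word index per position, shape (5, 1)
print(predict.squeeze().tolist())          # five indices, ready to look up in a word dictionary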
