'''
- code by Tae Hwan Jung(Jeff Jung) @graykode
+ code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
              https://github.com/JayParks/transformer
'''
n_layers = 6  # number of layers in both the Encoder and the Decoder
n_heads = 8   # number of heads in Multi-Head Attention

-
def make_batch(sentences):
    input_batch = [[src_vocab[n] for n in sentences[0].split()]]
    output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]
    target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]
    return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))

-
def get_sinusoid_encoding_table(n_position, d_model):
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
@@ -51,14 +49,12 @@ def get_posi_angle_vec(position):
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table)
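
For orientation, a minimal usage sketch (not a line from this commit): the table above implements the fixed sinusoidal positional encoding from "Attention Is All You Need" and is typically wrapped in a frozen embedding so the encodings are never trained.

# Sketch only: wrap the sinusoid table in a non-trainable embedding layer.
pos_table = get_sinusoid_encoding_table(6, d_model)             # e.g. 6 positions -> [6, d_model]
pos_emb = nn.Embedding.from_pretrained(pos_table, freeze=True)  # positional encodings stay fixed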

-
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q)
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k
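
To make the mask shape concrete, a small illustrative example, assuming (as the `.eq(0)` above implies) that index 0 is the padding token; the token ids are arbitrary and this snippet is a sketch, not code from the commit.

demo_seq = torch.LongTensor([[4, 7, 0]])           # one sentence of 3 tokens, the last one is padding
demo_mask = get_attn_pad_mask(demo_seq, demo_seq)  # shape [1, 3, 3]; key position 2 is masked for every query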

-
class ScaledDotProductAttention(nn.Module):

    def __init__(self):
@@ -72,7 +68,6 @@ def forward(self, Q, K, V, attn_mask=None):
        context = torch.matmul(attn, V)
        return context, attn
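
The hunk above only shows the tail of `forward`; as a reference point, scaled dot-product attention computes softmax(Q·Kᵀ/√d_k)·V. A standalone sketch of that computation, reusing the d_k defined earlier (the helper name is illustrative and not part of the file, and it assumes a boolean attn_mask):

def scaled_dot_product_attention_sketch(Q, K, V, attn_mask=None):
    scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # [batch_size, n_heads, len_q, len_k]
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, -1e9)              # masked keys get ~zero attention weight
    attn = nn.Softmax(dim=-1)(scores)                             # normalize over the key dimension
    context = torch.matmul(attn, V)                               # weighted sum of the values
    return context, attn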

-
class MultiHeadAttention(nn.Module):

    def __init__(self):
@@ -97,7 +92,6 @@ def forward(self, Q, K, V, attn_mask=None):
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual), attn  # output: [batch_size x len_q x d_model]

-
class PoswiseFeedForwardNet(nn.Module):

    def __init__(self):
@@ -111,7 +105,6 @@ def forward(self, inputs):
        output = self.conv2(output).transpose(1, 2)
        return nn.LayerNorm(d_model)(output + residual)

-
class EncoderLayer(nn.Module):

    def __init__(self):
@@ -124,7 +117,6 @@ def forward(self, enc_inputs):
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

-
class DecoderLayer(nn.Module):

    def __init__(self):
@@ -139,7 +131,6 @@ def forward(self, dec_inputs, enc_outputs, enc_attn_mask, dec_attn_mask=None):
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn

-
class Encoder(nn.Module):

    def __init__(self):
@@ -156,7 +147,6 @@ def forward(self, enc_inputs): # enc_inputs : [batch_size x source_len]
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

-
class Decoder(nn.Module):

    def __init__(self):
@@ -178,7 +168,6 @@ def forward(self, dec_inputs, enc_inputs, enc_outputs, dec_attn_mask=None): # de
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

-
class Transformer(nn.Module):

    def __init__(self):
@@ -193,7 +182,6 @@ def forward(self, enc_inputs, dec_inputs, decoder_mask=None):
        dec_logits = self.projection(dec_outputs)  # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

-
def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
@@ -218,7 +206,6 @@ def greedy_decoder(model, enc_input, start_symbol):
        next_symbol = next_word[0]
    return dec_input

-
def showgraph(attn):
    attn = attn[-1].squeeze(0)[0]
    attn = attn.squeeze(0).data.numpy()
@@ -229,36 +216,32 @@ def showgraph(attn):
    ax.set_yticklabels([''] + sentences[2].split(), fontdict={'fontsize': 14})
    plt.show()

+ model = Transformer()

- if __name__ == '__main__':
-
-     model = Transformer()
-
-     criterion = nn.CrossEntropyLoss()
-     optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-     for epoch in range(100):
-         optimizer.zero_grad()
-         enc_inputs, dec_inputs, target_batch = make_batch(sentences)
-         outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
-         loss = criterion(outputs, target_batch.contiguous().view(-1))
-         if (epoch + 1) % 20 == 0:
-             print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
-         loss.backward()
-         optimizer.step()
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.001)

+ for epoch in range(100):
+     optimizer.zero_grad()
+     enc_inputs, dec_inputs, target_batch = make_batch(sentences)
+     outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
+     loss = criterion(outputs, target_batch.contiguous().view(-1))
+     if (epoch + 1) % 20 == 0:
+         print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
+     loss.backward()
+     optimizer.step()

-     # Test
-     greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
-     predict, _, _, _ = model(enc_inputs, greedy_dec_input)
-     predict = predict.data.max(1, keepdim=True)[1]
-     print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
+ # Test
+ greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
+ predict, _, _, _ = model(enc_inputs, greedy_dec_input)
+ predict = predict.data.max(1, keepdim=True)[1]
+ print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])

-     print('first head of last state enc_self_attns')
-     showgraph(enc_self_attns)
+ print('first head of last state enc_self_attns')
+ showgraph(enc_self_attns)

-     print('first head of last state dec_self_attns')
-     showgraph(dec_self_attns)
+ print('first head of last state dec_self_attns')
+ showgraph(dec_self_attns)

-     print('first head of last state dec_enc_attns')
-     showgraph(dec_enc_attns)
+ print('first head of last state dec_enc_attns')
+ showgraph(dec_enc_attns)
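
One shape detail worth noting in the training loop above: nn.CrossEntropyLoss expects 2-D logits of shape [N, num_classes] and 1-D integer targets of shape [N], which is why the model returns dec_logits.view(-1, dec_logits.size(-1)) and the target batch is flattened with .contiguous().view(-1) before the loss is computed. The shapes below describe the generic case and are not copied from the commit.

# outputs:                              [batch_size * tgt_len, tgt_vocab_size]  -> 2-D logits
# target_batch.contiguous().view(-1):   [batch_size * tgt_len]                  -> 1-D class indices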