@@ -52,48 +52,50 @@ def get_posi_angle_vec(position):
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
-    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q)
+    # eq(zero) is PAD token
+    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

-class ScaledDotProductAttention(nn.Module):
+def get_attn_subsequent_mask(seq):
+    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
+    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
+    subsequent_mask = torch.from_numpy(subsequent_mask).byte()
+    return subsequent_mask
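A minimal sketch of what the new helper produces, assuming the helpers above and a toy 1 x 3 target batch; np.triu with k=1 keeps the diagonal at zero, so each position can still attend to itself:

toy = torch.LongTensor([[4, 1, 2]])          # made-up token indices
print(get_attn_subsequent_mask(toy))
# tensor([[[0, 1, 1],
#          [0, 0, 1],
#          [0, 0, 0]]], dtype=torch.uint8)   # 1 marks the future positions to be masked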
+class ScaledDotProductAttention(nn.Module):

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

-    def forward(self, Q, K, V, attn_mask=None):
+    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
-        if attn_mask is not None:
-            scores.masked_fill_(attn_mask, -1e9)
+        scores.masked_fill_(attn_mask, -1e9)  # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn
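The masking step in a nutshell: a minimal self-contained sketch with made-up scores, showing that a position marked True (one) in the mask ends up with near-zero attention weight after the softmax:

import torch
import torch.nn as nn

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[False, False, True]])   # True marks a position to hide (e.g. PAD)
scores.masked_fill_(mask, -1e9)
print(nn.Softmax(dim=-1)(scores))             # tensor([[0.2689, 0.7311, 0.0000]])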
class MultiHeadAttention(nn.Module):
-
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
-
-    def forward(self, Q, K, V, attn_mask=None):
+    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x len_k x d_v]

-        if attn_mask is not None:  # attn_mask : [batch_size x len_q x len_k]
-            attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]
+        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]
+
        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
-        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask=attn_mask)
+        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual), attn  # output: [batch_size x len_q x d_model]
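A minimal sketch of the (B, S, D) -split-> (B, H, S, W) step above, assuming the hyperparameters d_model = 512, n_heads = 8, d_k = 64 (they are defined earlier in the file, outside this hunk):

import torch

x = torch.randn(1, 5, 512)                     # (B, S, D): batch 1, 5 tokens, d_model 512
heads = x.view(1, -1, 8, 64).transpose(1, 2)   # (B, H, S, W): 8 heads, each of width 64
print(heads.shape)                             # torch.Size([1, 8, 5, 64])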
class PoswiseFeedForwardNet(nn.Module):
-
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
@@ -106,33 +108,30 @@ def forward(self, inputs):
        return nn.LayerNorm(d_model)(output + residual)

class EncoderLayer(nn.Module):
-
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

-    def forward(self, enc_inputs):
-        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)  # enc_inputs to same Q,K,V
+    def forward(self, enc_inputs, enc_self_attn_mask):
+        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn
class DecoderLayer(nn.Module):
-
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

-    def forward(self, dec_inputs, enc_outputs, enc_attn_mask, dec_attn_mask=None):
-        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_attn_mask)
-        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, enc_attn_mask)
+    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
+        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
+        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn
class Encoder(nn.Module):
-
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
@@ -141,44 +140,44 @@ def __init__(self):

    def forward(self, enc_inputs):  # enc_inputs : [batch_size x source_len]
        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]]))
+        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        enc_self_attns = []
        for layer in self.layers:
-            enc_outputs, enc_self_attn = layer(enc_outputs)
+            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns
class Decoder(nn.Module):
-
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1, d_model), freeze=True)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

-    def forward(self, dec_inputs, enc_inputs, enc_outputs, dec_attn_mask=None):  # dec_inputs : [batch_size x target_len]
+    def forward(self, dec_inputs, enc_inputs, enc_outputs):  # dec_inputs : [batch_size x target_len]
        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]]))
-        dec_enc_attn_pad_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
+        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
+        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
+        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
+
+        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
-            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs,
-                                                             enc_attn_mask=dec_enc_attn_pad_mask,
-                                                             dec_attn_mask=dec_attn_mask)
+            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns
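How the two decoder masks combine via torch.gt: a minimal sketch, assuming the helpers above and a toy target batch [[4, 1, 0]] in which index 0 is the PAD token:

import torch

dec_inputs = torch.LongTensor([[4, 1, 0]])
pad = get_attn_pad_mask(dec_inputs, dec_inputs)   # every row is [0, 0, 1]  (the PAD column)
sub = get_attn_subsequent_mask(dec_inputs)        # rows [0, 1, 1], [0, 0, 1], [0, 0, 0]
print(torch.gt(pad + sub, 0))                     # rows [F, T, T], [F, F, T], [F, F, T]: future tokens and PAD are both blocked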
class Transformer(nn.Module):
-
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
-
-    def forward(self, enc_inputs, dec_inputs, decoder_mask=None):
+    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
-        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs, decoder_mask)
+        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
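Why the logits are flattened before being returned: a minimal sketch with made-up sizes (batch_size 1, tgt_len 5, tgt_vocab_size 7), showing the shape that lines up with a flattened target vector for nn.CrossEntropyLoss:

import torch

dec_logits = torch.randn(1, 5, 7)                 # [batch_size, tgt_len, tgt_vocab_size]
flat = dec_logits.view(-1, dec_logits.size(-1))
print(flat.shape)                                 # torch.Size([5, 7]) -- one row of scores per target position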
@@ -192,18 +191,16 @@ def greedy_decoder(model, enc_input, start_symbol):
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
    :return: The target input
    """
-    memory, attention = model.encoder(enc_input)
-    dec_input = torch.ones(1, 5).fill_(0).type_as(enc_input.data)
-    dec_mask = torch.from_numpy(np.triu(np.ones((1, 5, 5)), 1).astype('uint8')) == 0
+    enc_outputs, enc_self_attns = model.encoder(enc_input)
+    dec_input = torch.zeros(1, 5).type_as(enc_input.data)
    next_symbol = start_symbol
    for i in range(0, 5):
        dec_input[0][i] = next_symbol
-        out = model.decoder(Variable(dec_input), enc_input, memory, dec_mask)
-        projected = model.projection(out[0])
-        prob = projected.view(-1, projected.size(-1))
-        prob = prob.data.max(1, keepdim=True)[1]
+        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
+        projected = model.projection(dec_outputs)
+        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[i]
-        next_symbol = next_word[0]
+        next_symbol = next_word.item()
    return dec_input

def showgraph(attn):
@@ -232,7 +229,7 @@ def showgraph(attn):
    optimizer.step()

# Test
-greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
+greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=tgt_vocab["S"])
predict, _, _, _ = model(enc_inputs, greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])