from torch.autograd import Variable

# BERT Parameters
- maxlen = 512
+ maxlen = 30
batch_size = 6
- max_pred = 20 # max tokens of prediction
- n_layers = 12
+ max_pred = 5 # max tokens of prediction
+ n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768 * 4 # 4*d_model, FeedForward dimension
    arr = [word_dict[s] for s in sentence.split()]
    token_list.append(arr)

+ # sample IsNext and NotNext so they are balanced within the small batch
+ def make_batch():
+     batch = []
+     positive = negative = 0
+     while positive != batch_size/2 or negative != batch_size/2:
+         tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
+         tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
+         input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
+         segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
+
+         # MASK LM
+         n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15% of tokens in one sentence
+         cand_maked_pos = [i for i, token in enumerate(input_ids)
+                           if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
+         shuffle(cand_maked_pos)
+         masked_tokens, masked_pos = [], []
+         for pos in cand_maked_pos[:n_pred]:
+             masked_pos.append(pos)
+             masked_tokens.append(input_ids[pos])
+             if random() < 0.8:  # 80%
+                 input_ids[pos] = word_dict['[MASK]'] # make mask
+             elif random() < 0.5:  # 10%
+                 index = randint(0, vocab_size - 1) # random index in vocabulary
+                 input_ids[pos] = word_dict[number_dict[index]] # replace
+
+         # Zero Paddings
+         n_pad = maxlen - len(input_ids)
+         input_ids.extend([0] * n_pad)
+         segment_ids.extend([0] * n_pad)
+
+         # Zero Padding (100% - 15%) tokens
+         if max_pred > n_pred:
+             n_pad = max_pred - n_pred
+             masked_tokens.extend([0] * n_pad)
+             masked_pos.extend([0] * n_pad)
+
+         if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
+             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
+             positive += 1
+         elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
+             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
+             negative += 1
+     return batch
+ # Preprocessing Finished
+
+ def get_attn_pad_mask(seq_q, seq_k):
+     batch_size, len_q = seq_q.size()
+     batch_size, len_k = seq_k.size()
+     # eq(zero) is PAD token
+     pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
+     return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k
+
def gelu(x):
    "Implementation of the gelu activation function by Hugging Face"
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
@@ -55,21 +107,21 @@ def __init__(self):
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
-         self.drop = nn.Dropout(0.1)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
-         return self.drop(self.norm(embedding))
+         return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

-     def forward(self, Q, K, V, attn_mask=None):
+     def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
+         scores.masked_fill_(attn_mask, -1e9)  # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn
@@ -80,19 +132,18 @@ def __init__(self):
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
-
-     def forward(self, Q, K, V, attn_mask=None):
+     def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x len_k x d_v]

-         if attn_mask is not None:  # attn_mask : [batch_size x len_q x len_k]
-             attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]
+         attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]
+
        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
-         context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask=attn_mask)
+         context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual), attn  # output: [batch_size x len_q x d_model]
@@ -113,8 +164,8 @@ def __init__(self):
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

-     def forward(self, enc_inputs):
-         enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)  # enc_inputs to same Q,K,V
+     def forward(self, enc_inputs, enc_self_attn_mask):
+         enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

@@ -138,9 +189,11 @@ def __init__(self):
    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
+         enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
-             output, enc_self_attn = layer(output)
+             output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, len, len]
+         # sentence classification is decided by the first token ([CLS])
        h_pooled = self.activ1(self.fc(output[:, 0]))  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 2]
@@ -151,80 +204,35 @@ def forward(self, input_ids, segment_ids, masked_pos):
        return logits_lm, logits_clsf

- # sample IsNext and NotNext to be same in small batch size
- def make_batch():
-     batch = []
-     positive = negative = 0
-     while positive != batch_size/2 or negative != batch_size/2:
-         tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
-         tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
-         input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
-         segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
-
-         # MASK LM
-         n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15% of tokens in one sentence
-         cand_maked_pos = [i for i, token in enumerate(input_ids)]
-         shuffle(cand_maked_pos)
-         masked_tokens, masked_pos = [], []
-         for pos in cand_maked_pos[:n_pred]:
-             masked_pos.append(pos)
-             masked_tokens.append(input_ids[pos])
-             if random() < 0.8: # 80%
-                 input_ids[pos] = word_dict['[MASK]'] # make mask
-             elif random() < 0.5: # 10%
-                 index = randint(0, vocab_size - 1) # random index in vocabulary
-                 input_ids[pos] = word_dict[number_dict[index]] # replace
-
-         # Zero Paddings
-         n_pad = maxlen - len(input_ids)
-         input_ids.extend([0] * n_pad)
-         segment_ids.extend([0] * n_pad)
-
-         # Zero Padding (100% - 15%) tokens
-         if max_pred > n_pred:
-             n_pad = max_pred - n_pred
-             masked_tokens.extend([0] * n_pad)
-             masked_pos.extend([0] * n_pad)
-
-         if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
-             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
-             positive += 1
-         elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
-             batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
-             negative += 1
-     return batch
- # Proprecessing Finished
-
model = BERT()
- criterion1 = nn.CrossEntropyLoss(reduction='none')
- criterion2 = nn.CrossEntropyLoss()
- optimizer = optim.Adam(model.parameters(), lr=1e-4)
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.001)

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
- input_ids = Variable(torch.LongTensor(input_ids))
- segment_ids = Variable(torch.LongTensor(segment_ids))
- masked_pos = Variable(torch.LongTensor(masked_pos))
- masked_tokens = Variable(torch.LongTensor(masked_tokens))
- isNext = Variable(torch.LongTensor(isNext))
+ input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
+     torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
+     torch.LongTensor(masked_pos), torch.LongTensor(isNext)

- for epoch in range(25):
+ for epoch in range(100):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
-     loss_lm = criterion1(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
+     loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
-     loss_clsf = criterion2(logits_clsf, isNext) # for sentence classification
+     loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
    loss = loss_lm + loss_clsf
-     print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
+     if (epoch + 1) % 10 == 0:
+         print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

# Predict masked tokens and isNext
- input_ids, segment_ids, masked_tokens, masked_pos, isNext = make_batch()[0]
+ input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]
print(text)
print([number_dict[w] for w in input_ids if number_dict[w] != '[PAD]'])

- logits_lm, logits_clsf = model(Variable(torch.LongTensor([input_ids])), Variable(torch.LongTensor([segment_ids])), Variable(torch.LongTensor([masked_pos])))
+ logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
+     torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])