Commit d7aa35e

graykode#22 edited issue BERT
1 parent 53d5c28 commit d7aa35e

2 files changed: +380 −379 lines changed


5-2.BERT/BERT-Torch.py

+80 −72
@@ -13,10 +13,10 @@
 from torch.autograd import Variable
 
 # BERT Parameters
-maxlen = 512
+maxlen = 30
 batch_size = 6
-max_pred = 20 # max tokens of prediction
-n_layers = 12
+max_pred = 5 # max tokens of prediction
+n_layers = 6
 n_heads = 12
 d_model = 768
 d_ff = 768*4 # 4*d_model, FeedForward dimension
@@ -44,6 +44,58 @@
     arr = [word_dict[s] for s in sentence.split()]
     token_list.append(arr)
 
+# sample IsNext and NotNext to be same in small batch size
+def make_batch():
+    batch = []
+    positive = negative = 0
+    while positive != batch_size/2 or negative != batch_size/2:
+        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
+        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
+        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
+        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
+
+        # MASK LM
+        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence
+        cand_maked_pos = [i for i, token in enumerate(input_ids)
+                          if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
+        shuffle(cand_maked_pos)
+        masked_tokens, masked_pos = [], []
+        for pos in cand_maked_pos[:n_pred]:
+            masked_pos.append(pos)
+            masked_tokens.append(input_ids[pos])
+            if random() < 0.8: # 80%
+                input_ids[pos] = word_dict['[MASK]'] # make mask
+            elif random() < 0.5: # 10%
+                index = randint(0, vocab_size - 1) # random index in vocabulary
+                input_ids[pos] = word_dict[number_dict[index]] # replace
+
+        # Zero Paddings
+        n_pad = maxlen - len(input_ids)
+        input_ids.extend([0] * n_pad)
+        segment_ids.extend([0] * n_pad)
+
+        # Zero Padding (100% - 15%) tokens
+        if max_pred > n_pred:
+            n_pad = max_pred - n_pred
+            masked_tokens.extend([0] * n_pad)
+            masked_pos.extend([0] * n_pad)
+
+        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
+            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
+            positive += 1
+        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
+            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
+            negative += 1
+    return batch
+# Preprocessing Finished
+
+def get_attn_pad_mask(seq_q, seq_k):
+    batch_size, len_q = seq_q.size()
+    batch_size, len_k = seq_k.size()
+    # eq(zero) is PAD token
+    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking
+    return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
+
 def gelu(x):
     "Implementation of the gelu activation function by Hugging Face"
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
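
As a quick sanity check on the newly added get_attn_pad_mask helper, here is a minimal standalone sketch. It repeats the helper verbatim so it runs on its own; the toy id tensor is invented for illustration and is not part of this commit. True marks key positions that hold the padding id 0 and should be ignored by attention.

import torch

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

# toy batch of one sequence of length 5 with two trailing [PAD] (= 0) tokens
ids = torch.LongTensor([[4, 7, 9, 0, 0]])
mask = get_attn_pad_mask(ids, ids)
print(mask.shape)  # torch.Size([1, 5, 5])
print(mask[0, 0])  # tensor([False, False, False,  True,  True])

Every query row of the returned mask is identical: each real token is told to ignore the same two padded key positions.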
@@ -55,21 +107,21 @@ def __init__(self):
         self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
         self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
         self.norm = nn.LayerNorm(d_model)
-        self.drop = nn.Dropout(0.1)
 
     def forward(self, x, seg):
         seq_len = x.size(1)
         pos = torch.arange(seq_len, dtype=torch.long)
         pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
         embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
-        return self.drop(self.norm(embedding))
+        return self.norm(embedding)
 
 class ScaledDotProductAttention(nn.Module):
     def __init__(self):
         super(ScaledDotProductAttention, self).__init__()
 
-    def forward(self, Q, K, V, attn_mask=None):
+    def forward(self, Q, K, V, attn_mask):
         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
+        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
         attn = nn.Softmax(dim=-1)(scores)
         context = torch.matmul(attn, V)
         return context, attn
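
The scores.masked_fill_(attn_mask, -1e9) line added above is what turns the pad mask into (near) zero attention weight: filling masked positions with a large negative value before the softmax pushes their weights toward zero. A small self-contained sketch with made-up numbers, not taken from the commit:

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.3]])           # raw attention scores for one query
attn_mask = torch.tensor([[False, False, True, True]])   # True = padded key positions
scores = scores.masked_fill(attn_mask, -1e9)             # out-of-place variant of masked_fill_
attn = torch.softmax(scores, dim=-1)
print(attn)  # weights on the two padded keys are ~0; the rest renormalize over real tokens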
@@ -80,19 +132,18 @@ def __init__(self):
         self.W_Q = nn.Linear(d_model, d_k * n_heads)
         self.W_K = nn.Linear(d_model, d_k * n_heads)
         self.W_V = nn.Linear(d_model, d_v * n_heads)
-
-    def forward(self, Q, K, V, attn_mask=None):
+    def forward(self, Q, K, V, attn_mask):
         # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
         residual, batch_size = Q, Q.size(0)
         # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
         q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
         k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
         v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]
 
-        if attn_mask is not None: # attn_mask : [batch_size x len_q x len_k]
-            attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]
+        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]
+
         # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
-        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask=attn_mask)
+        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
         context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]
         output = nn.Linear(n_heads * d_v, d_model)(context)
         return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]
@@ -113,8 +164,8 @@ def __init__(self):
         self.enc_self_attn = MultiHeadAttention()
         self.pos_ffn = PoswiseFeedForwardNet()
 
-    def forward(self, enc_inputs):
-        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs) # enc_inputs to same Q,K,V
+    def forward(self, enc_inputs, enc_self_attn_mask):
+        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
         enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
         return enc_outputs, attn
 
@@ -138,9 +189,11 @@ def __init__(self):
 
     def forward(self, input_ids, segment_ids, masked_pos):
         output = self.embedding(input_ids, segment_ids)
+        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
         for layer in self.layers:
-            output, enc_self_attn = layer(output)
+            output, enc_self_attn = layer(output, enc_self_attn_mask)
         # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
+        # it will be decided by first token(CLS)
         h_pooled = self.activ1(self.fc(output[:, 0])) # [batch_size, d_model]
         logits_clsf = self.classifier(h_pooled) # [batch_size, 2]
 
@@ -151,80 +204,35 @@ def forward(self, input_ids, segment_ids, masked_pos):
 
         return logits_lm, logits_clsf
 
-# sample IsNext and NotNext to be same in small batch size
-def make_batch():
-    batch = []
-    positive = negative = 0
-    while positive != batch_size/2 or negative != batch_size/2:
-        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
-        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
-        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]
-        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
-
-        # MASK LM
-        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence
-        cand_maked_pos = [i for i, token in enumerate(input_ids)]
-        shuffle(cand_maked_pos)
-        masked_tokens, masked_pos = [], []
-        for pos in cand_maked_pos[:n_pred]:
-            masked_pos.append(pos)
-            masked_tokens.append(input_ids[pos])
-            if random() < 0.8: # 80%
-                input_ids[pos] = word_dict['[MASK]'] # make mask
-            elif random() < 0.5: # 10%
-                index = randint(0, vocab_size - 1) # random index in vocabulary
-                input_ids[pos] = word_dict[number_dict[index]] # replace
-
-        # Zero Paddings
-        n_pad = maxlen - len(input_ids)
-        input_ids.extend([0] * n_pad)
-        segment_ids.extend([0] * n_pad)
-
-        # Zero Padding (100% - 15%) tokens
-        if max_pred > n_pred:
-            n_pad = max_pred - n_pred
-            masked_tokens.extend([0] * n_pad)
-            masked_pos.extend([0] * n_pad)
-
-        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
-            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
-            positive += 1
-        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
-            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
-            negative += 1
-    return batch
-# Preprocessing Finished
-
 model = BERT()
-criterion1 = nn.CrossEntropyLoss(reduction='none')
-criterion2 = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=1e-4)
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.Adam(model.parameters(), lr=0.001)
 
 batch = make_batch()
 input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
-input_ids = Variable(torch.LongTensor(input_ids))
-segment_ids = Variable(torch.LongTensor(segment_ids))
-masked_pos = Variable(torch.LongTensor(masked_pos))
-masked_tokens = Variable(torch.LongTensor(masked_tokens))
-isNext = Variable(torch.LongTensor(isNext))
+input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
+    torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
+    torch.LongTensor(masked_pos), torch.LongTensor(isNext)
 
-for epoch in range(25):
+for epoch in range(100):
     optimizer.zero_grad()
     logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
-    loss_lm = criterion1(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
+    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
     loss_lm = (loss_lm.float()).mean()
-    loss_clsf = criterion2(logits_clsf, isNext) # for sentence classification
+    loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
     loss = loss_lm + loss_clsf
-    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
+    if (epoch + 1) % 10 == 0:
+        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
     loss.backward()
     optimizer.step()
 
 # Predict mask tokens and isNext
-input_ids, segment_ids, masked_tokens, masked_pos, isNext = make_batch()[0]
+input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]
 print(text)
 print([number_dict[w] for w in input_ids if number_dict[w] != '[PAD]'])
 
-logits_lm, logits_clsf = model(Variable(torch.LongTensor([input_ids])), Variable(torch.LongTensor([segment_ids])), Variable(torch.LongTensor([masked_pos])))
+logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
+                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
 logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
 print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
 print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])
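
One detail of the training loop worth spelling out: criterion(logits_lm.transpose(1, 2), masked_tokens) relies on nn.CrossEntropyLoss accepting class scores of shape [batch, n_classes, d1] against targets of shape [batch, d1], so the transpose moves vocab_size into the class dimension and the loss is averaged over all max_pred slots. A minimal shape-only sketch with made-up sizes (batch of 2, max_pred = 5, a 29-word vocabulary), independent of the data above:

import torch
import torch.nn as nn

batch, max_pred, vocab_size = 2, 5, 29
logits_lm = torch.randn(batch, max_pred, vocab_size)             # one score per vocab word per masked slot
masked_tokens = torch.randint(0, vocab_size, (batch, max_pred))  # target token id for each masked slot

criterion = nn.CrossEntropyLoss()
loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens)    # scores reshaped to [batch, vocab_size, max_pred]
print(loss_lm)  # scalar, averaged over every (batch, masked-position) pair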
