forked from graykode/nlp-tutorial
Showing 4 changed files with 215 additions and 5 deletions.
@@ -0,0 +1,16 @@
# code by Tae Hwan Jung(Jeff Jung) @graykode
import tensorflow as tf
import numpy as np

tf.reset_default_graph()
# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Padding symbol used when a sequence is shorter than the number of time steps
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']

# Transformer Parameters
word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
n_class = len(word_dict)  # vocabulary size
embedding_size = 2
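For illustration only (not part of the commit), here is a minimal sketch of how the shared vocabulary above turns the toy sentences into index batches; it repeats the definitions from the snippet so it runs on its own, independent of the TensorFlow graph.

# Hedged sketch (not in the commit): encode the toy sentences with the shared word_dict.
import numpy as np

sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
word_dict = {w: i for i, w in enumerate(set(" ".join(sentences).split()))}

enc_input = [word_dict[w] for w in sentences[0].split()]    # encoder input (source sentence)
dec_input = [word_dict[w] for w in sentences[1].split()]    # decoder input (starts with S)
dec_target = [word_dict[w] for w in sentences[2].split()]   # decoder target (ends with E)
print(np.array([enc_input]), np.array([dec_input]), np.array([dec_target]), sep='\n')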
@@ -0,0 +1,192 @@
'''
code by Tae Hwan Jung(Jeff Jung) @graykode
Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
            https://github.com/JayParks/transformer
'''
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

dtype = torch.FloatTensor
# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Padding symbol used when a sequence is shorter than n_step
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']

# Transformer Parameters
src_vocab = {w: i for i, w in enumerate(sentences[0].split())}
src_vocab_size = len(src_vocab)
tgt_vocab = {w: i for i, w in enumerate(set((sentences[1] + ' ' + sentences[2]).split()))}
tgt_vocab_size = len(tgt_vocab)

n_step = 5      # sequence length (number of time steps)
d_model = 512   # embedding size
d_inner = 2048  # feed-forward hidden size
d_k = d_v = 64  # dimension of K(=Q) and V per head
n_layers = 6    # number of encoder and decoder layers
n_heads = 8     # number of heads in Multi-Head Attention

def make_batch(sentences):
    input_batch = [[src_vocab[n] for n in sentences[0].split()]]
    output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]
    target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]
    return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))

def get_sinusoid_encoding_table(n_position, d_model):
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_model)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table)

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # mark padding positions in the key sequence; 'P' is the padding symbol of the source vocabulary
    pad_attn_mask = seq_k.data.eq(src_vocab['P']).unsqueeze(1)  # batch_size x 1 x len_k
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask=None):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores : [batch_size x n_heads x len_q x len_k]
        if attn_mask is not None:
            scores.masked_fill_(attn_mask, -1e9)  # mask out padded key positions before the softmax
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # output projection and layer norm live here (not in forward) so their weights are trained
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
    def forward(self, Q, K, V, attn_mask=None):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size x n_heads x len_k x d_v]

        if attn_mask is not None:  # attn_mask : [batch_size x len_q x len_k]
            attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)  # attn_mask : [batch_size x n_heads x len_q x len_k]
        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask=attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # context: [batch_size x len_q x n_heads * d_v]
        output = self.linear(context)
        return self.layer_norm(output + residual), attn  # output: [batch_size x len_q x d_model]

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_inner, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_inner, out_channels=d_model, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_model)  # defined here so its parameters are trained

    def forward(self, inputs):
        residual = inputs  # inputs : [batch_size, len_q, d_model]
        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
        output = self.conv2(output).transpose(1, 2)
        return self.layer_norm(output + residual)

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)  # enc_inputs to same Q, K, V
        enc_outputs = self.pos_ffn(enc_outputs)  # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, enc_attn_mask):
        # note: this simplified example applies no look-ahead (subsequent) mask to the decoder self-attention
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, None)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(n_step + 1, d_model), freeze=True)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):  # enc_inputs : [batch_size x source_len]
        # position indices are hard-coded to 1..5 because every toy sequence has length n_step = 5
        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1, 2, 3, 4, 5]]))
        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(n_step + 1, d_model), freeze=True)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):  # dec_inputs : [batch_size x target_len]
        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[1, 2, 3, 4, 5]]))
        dec_enc_attn_pad_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, enc_attn_mask=dec_enc_attn_pad_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

model = Transformer()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    optimizer.zero_grad()
    enc_inputs, dec_inputs, target_batch = make_batch(sentences)
    outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
    loss = criterion(outputs, target_batch.contiguous().view(-1))
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

# Test
predict, _, _, _ = model(enc_inputs, dec_inputs)
predict = predict.data.max(1, keepdim=True)[1]
# map each predicted target-vocabulary index back to its word
output = ''
for pre in predict:
    for key, value in tgt_vocab.items():
        if value == pre:
            output += ' ' + key
print(sentences[0], '->', output)
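As a follow-up, a minimal sketch (not part of the commit) of how one might inspect the learned encoder self-attention; it assumes it is appended to the script above so that enc_self_attns and sentences from the training loop are still in scope.

# Hedged sketch (not in the commit): print the last encoder layer's
# self-attention weights, averaged over the heads, for the single toy batch item.
last_attn = enc_self_attns[-1].mean(dim=1)[0]  # [len_q x len_k], heads averaged, batch item 0
src_tokens = sentences[0].split()
for i, row in enumerate(last_attn):
    weights = ' '.join('{:.2f}'.format(w.item()) for w in row)
    print(src_tokens[i], '->', weights)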