
Commit 6d68f49

Add warning about the weight sharing.
jadore801120 committed Jun 19, 2017
1 parent a3808ab commit 6d68f49
Showing 5 changed files with 18 additions and 9 deletions.
DataLoader.py: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def tgt_word2idx(self):
     def src_idx2word(self):
         ''' Property for index dictionary '''
         return self._src_idx2word
+
     @property
     def tgt_idx2word(self):
         ''' Property for index dictionary '''
README.md: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ python preprocess.py -train_src <train.src.txt> -train_tgt <train.tgt.txt> -vali
 
 ## 1) Training
 ```bash
-python train.py -data <output.pt>
+python train.py -data <output.pt> -embs_share_weight -proj_share_weight
 ```
 ## 2) Testing
 ### TODO
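Both new flags are opt-in switches. `-embs_share_weight` ties the encoder and decoder word-embedding tables, which only makes sense when the source and target vocabularies are identical; `-proj_share_weight`, judging by its name and the weight-sharing scheme of the original Transformer paper, ties the decoder embedding with the pre-softmax output projection. Before this commit both options were hard-coded to `True` in train.py (see the train.py hunk below).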
preprocess.py: 4 additions & 4 deletions
@@ -111,11 +111,11 @@ def main():
             'src': src_word2idx,
             'tgt': tgt_word2idx},
         'train': {
-            'src':train_src_insts,
-            'tgt':train_tgt_insts},
+            'src': train_src_insts,
+            'tgt': train_tgt_insts},
         'valid': {
-            'src':valid_tgt_insts,
-            'tgt':valid_src_insts}}
+            'src': valid_tgt_insts,
+            'tgt': valid_src_insts}}
 
     print('[Info] Dumping the processed data to pickle file', opt.output)
     torch.save(data, opt.output)
train.py: 10 additions & 3 deletions
@@ -141,6 +141,8 @@ def main():
     parser.add_argument('-n_warmup_steps', type=int, default=4000)
 
     parser.add_argument('-dropout', type=float, default=0.5)
+    parser.add_argument('-embs_share_weight', action='store_true')
+    parser.add_argument('-proj_share_weight', action='store_true')
 
     parser.add_argument('-no_cuda', action='store_true')
 
@@ -166,20 +168,25 @@ def main():
         batch_size=opt.batch_size)
 
     #========= Preparing Model =========#
+
+    if opt.embs_share_weight and training_data.src_word2idx != training_data.tgt_word2idx:
+        print('[Warning]',
+              'The src/tgt word2idx table are different but asked to share word embedding.')
+
     transformer = Transformer(
         training_data.src_vocab_size,
         training_data.tgt_vocab_size,
         data['setting'].max_seq_len,
-        proj_share_weight=True,
-        embs_share_weight=True,
+        proj_share_weight=opt.proj_share_weight,
+        embs_share_weight=opt.embs_share_weight,
         d_model=opt.d_model,
         d_word_vec=opt.d_word_vec,
         d_inner_hid=opt.d_inner_hid,
         n_layers=opt.n_layers,
         n_head=opt.n_head,
         dropout=opt.dropout)
 
-    print(transformer)
+    #print(transformer)
 
     optimizer = optim.Adam(
         transformer.get_trainable_parameters(),
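Two notes on the change above. First, `action='store_true'` makes both new options default to `False`, so weight sharing is now opt-in rather than hard-coded on, which is why the README command now passes the flags explicitly. Second, the mismatch check is only a warning, presumably because two word2idx tables of equal size can still share one embedding matrix mechanically, even though the same row would then stand for different words on the source and target side. A minimal, self-contained sketch of the store_true pattern (simplified, not the repository's code):

```python
import argparse

# Hypothetical, simplified reproduction of the flag handling added above:
# store_true flags default to False, so weight sharing stays off unless
# the user opts in on the command line.
parser = argparse.ArgumentParser()
parser.add_argument('-embs_share_weight', action='store_true')
parser.add_argument('-proj_share_weight', action='store_true')

opt = parser.parse_args(['-proj_share_weight'])  # e.g. only projection sharing

assert opt.proj_share_weight is True
assert opt.embs_share_weight is False  # flag not given, so no embedding sharing
```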
transformer/Models.py: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ def __init__(
         if embs_share_weight:
             # Share the weight matrix between src/tgt word embeddings
             # assume the src/tgt word vec size are the same
-            assert n_src_vocab == n_tgt_vocab
+            assert n_src_vocab == n_tgt_vocab, \
+            "To share word embedding table, the vocabulary size of src/tgt shall be the same."
             self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight
 
     def get_trainable_parameters(self):
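For context, the line `self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight` makes both embedding modules reference the very same `Parameter`, which is why the strengthened assertion insists on equal vocabulary sizes. A minimal standalone sketch of that tying mechanism (hypothetical toy sizes, not the repository's classes):

```python
import torch.nn as nn

n_vocab, d_word_vec = 1000, 512  # assumed toy sizes

src_word_emb = nn.Embedding(n_vocab, d_word_vec)
tgt_word_emb = nn.Embedding(n_vocab, d_word_vec)

# Tie the tables: after this assignment both modules reference one Parameter,
# so gradients flowing through either embedding update the same weights.
src_word_emb.weight = tgt_word_emb.weight

assert src_word_emb.weight is tgt_word_emb.weight
# The shapes must agree for the tie to be meaningful, hence the
# `assert n_src_vocab == n_tgt_vocab` guard in Models.py.
assert src_word_emb.weight.shape == (n_vocab, d_word_vec)
```

A lookup through either module then reads from, and backpropagates into, the same table.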
