forked from jinhojsk515/SPMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path SPMM_pretrain.py
62 lines (55 loc) · 2.99 KB
/
SPMM_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from torch.utils.data import DataLoader
from dataset import SMILESDataset_pretrain
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
import torch.distributed
from SPMM_models import SPMM
import argparse
from pathlib import Path
from transformers import BertTokenizer, WordpieceTokenizer
def main(args, config, devices=(0, 1)):
    """Build the pretraining data/tokenizer/model and train SPMM with Lightning DDP.

    Args:
        args: Parsed CLI namespace. Reads ``data_path``, ``vocab_filename``,
            ``checkpoint`` and ``output_dir``; ``seed`` is used if present.
        config: Pretraining config dict. ``batch_size`` and
            ``schedular['epochs']`` are read here; the whole dict is also
            passed to ``SPMM``.
        devices: GPU indices handed to ``pl.Trainer``. Defaults to ``(0, 1)``,
            the previously hard-coded value. The same count is used to derive
            the per-device loader length, so the two can no longer disagree.

    Raises:
        ValueError: if ``devices`` is empty.
    """
    devices = list(devices)
    if not devices:
        raise ValueError("devices must contain at least one GPU index")

    # --seed was previously parsed but never applied; seed all RNGs when given.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    # data
    print("Creating dataset")
    # NOTE(review): only entries [0, 10000) are loaded; this looks like a
    # debugging limit. Confirm before a full pretraining run.
    dataset = SMILESDataset_pretrain(args.data_path, data_length=[0, 10000])
    print('#data:', len(dataset))
    data_loader = DataLoader(dataset, batch_size=config['batch_size'], num_workers=8,
                             shuffle=True, pin_memory=True, drop_last=True)

    tokenizer = BertTokenizer(vocab_file=args.vocab_filename, do_lower_case=False, do_basic_tokenize=False)
    # Raise the per-word character cap to 250. WordpieceTokenizer maps any
    # longer whitespace-delimited word to the unk token.
    tokenizer.wordpiece_tokenizer = WordpieceTokenizer(vocab=tokenizer.vocab, unk_token=tokenizer.unk_token,
                                                       max_input_chars_per_word=250)

    # model
    # Divide by the number of devices the Trainer actually uses, because under
    # DDP the loader is sharded across them. The old torch.cuda.device_count()
    # disagreed with devices=[0, 1] on hosts with other GPU counts and raised
    # ZeroDivisionError with no CUDA devices.
    loader_len = len(data_loader) // len(devices)
    model = SPMM(config=config, tokenizer=tokenizer, loader_len=loader_len)
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        # The same checkpoint is also passed to trainer.fit(ckpt_path=...) below.
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    # training
    # save_top_k = number of epochs, so up to one checkpoint per epoch is kept.
    # 'loss_mlm' is presumably logged by SPMM; confirm in SPMM_models.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=args.output_dir, filename='checkpoint_{epoch}',
                                                       save_top_k=config['schedular']['epochs'], monitor='loss_mlm')
    trainer = pl.Trainer(accelerator='gpu', devices=devices, precision=16,
                         max_epochs=config['schedular']['epochs'], callbacks=[checkpoint_callback],
                         strategy=DDPStrategy(find_unused_parameters=True), limit_val_batches=0.)
    trainer.fit(model, data_loader, None, ckpt_path=args.checkpoint or None)
def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--resume False`` used to yield True. An empty string
    still parses as False, and already-boolean values pass through.

    Raises:
        argparse.ArgumentTypeError: for unrecognised spellings.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', default='')
    # Alternative smaller corpus: './data/1_Pretrain/pretrain_20m.txt'
    parser.add_argument('--data_path', default='../VLP_chem/data/pubchem-100m-simple-shuffle.txt')
    # Accepts `--resume`, `--resume True` or `--resume False`.
    # Not read by main() in this file.
    parser.add_argument('--resume', default=False, type=str2bool, nargs='?', const=True)
    parser.add_argument('--output_dir', default='./Pretrain')
    parser.add_argument('--vocab_filename', default='./vocab_bpe_300.txt')
    parser.add_argument('--seed', default=42, type=int)
    args = parser.parse_args()

    # Hyper-parameters consumed by main() and SPMM. The 'schedular' key
    # spelling is kept as-is because downstream code looks it up by this name.
    pretrain_config = {
        'property_width': 768,
        'embed_dim': 256,
        'batch_size': 8,
        'temp': 0.07,
        'mlm_probability': 0.15,
        'queue_size': 32768,
        'momentum': 0.995,
        'alpha': 0.4,
        'bert_config_text': './config_bert.json',
        'bert_config_property': './config_bert_property.json',
        'schedular': {'sched': 'cosine', 'lr': 1e-4, 'epochs': 30, 'min_lr': 1e-5,
                      'decay_rate': 1, 'warmup_lr': 5e-5, 'warmup_epochs': 20, 'cooldown_epochs': 0},
        'optimizer': {'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02},
    }
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args, pretrain_config)