Commit 88734e5

Add files via upload
1 parent 8165522 commit 88734e5

5 files changed: 73 additions, 21 deletions

SPMM_models.py

26 additions, 8 deletions

@@ -5,6 +5,7 @@
 import torch.distributed
 import pytorch_lightning as pl
 from scheduler import create_scheduler
+import random


 class AttrDict(dict):
@@ -16,6 +17,7 @@ def __init__(self, *args, **kwargs):
 class SPMM(pl.LightningModule):
     def __init__(self, tokenizer=None, config=None, loader_len=0, no_train=False):
         super().__init__()
+        self.save_hyperparameters()
         self.automatic_optimization = False
         self.config = config
         self.tokenizer = tokenizer
@@ -82,13 +84,16 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=
         property_feature = self.property_embed(property_original.unsqueeze(2))

         unk_tokens = self.property_mask.expand(property_original.size(0), property_original.size(1), -1)
-        mpm_mask = torch.bernoulli(torch.ones_like(property_original) * 0.5)
+        if random.random() < 0.05:
+            mpm_mask = torch.ones_like(property_original)  # all mask
+        else:
+            mpm_mask = torch.bernoulli(torch.ones_like(property_original) * 0.5)  # 1 for mask, 0 for keep
         mpm_mask_expand = mpm_mask.unsqueeze(2).repeat(1, 1, unk_tokens.size(2))
         property_masked = property_feature * (1 - mpm_mask_expand) + unk_tokens * mpm_mask_expand
-        property = torch.cat([self.property_cls.expand(property_original.size(0), -1, -1), property_masked], dim=1)
+        properties = torch.cat([self.property_cls.expand(property_original.size(0), -1, -1), property_masked], dim=1)

-        prop_embeds = self.property_encoder(inputs_embeds=property, return_dict=True).last_hidden_state
-        prop_atts = torch.ones(prop_embeds.size()[:-1], dtype=torch.long).to(property.device)
+        prop_embeds = self.property_encoder(inputs_embeds=properties, return_dict=True).last_hidden_state
+        prop_atts = torch.ones(prop_embeds.size()[:-1], dtype=torch.long).to(properties.device)
         prop_feat = F.normalize(self.property_proj(prop_embeds[:, 0, :]), dim=-1)

         text_embeds = self.text_encoder.bert(text_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text').last_hidden_state
@@ -97,7 +102,7 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=

         with torch.no_grad():
             self._momentum_update()
-            prop_embeds_m = self.property_encoder_m(inputs_embeds=property, return_dict=True).last_hidden_state
+            prop_embeds_m = self.property_encoder_m(inputs_embeds=properties, return_dict=True).last_hidden_state
             prop_feat_m = F.normalize(self.property_proj_m(prop_embeds_m[:, 0, :]), dim=-1)
             prop_feat_all = torch.cat([prop_feat_m.t(), self.prop_queue.clone().detach()], dim=1)

@@ -110,7 +115,7 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=
             sim_i2i_m = prop_feat_m @ prop_feat_all / self.temp
             sim_t2t_m = text_feat_m @ text_feat_all / self.temp

-            sim_targets = torch.zeros(sim_i2t_m.size()).to(property.device)
+            sim_targets = torch.zeros(sim_i2t_m.size()).to(properties.device)
             sim_targets.fill_diagonal_(1)

             sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
@@ -268,8 +273,8 @@ def _momentum_update(self):

     @torch.no_grad()
     def _dequeue_and_enqueue(self, img_feat, text_feat):
-        img_feats = img_feat
-        text_feats = text_feat
+        img_feats = concat_all_gather(img_feat)
+        text_feats = concat_all_gather(text_feat)

         batch_size = img_feats.shape[0]

@@ -354,3 +359,16 @@ def on_train_epoch_end(self):  # outputs: collection of returns from 'training
         if self.global_rank == 0:
             print(f'\n mean loss: {tmp[0]:.4f}, {tmp[1]:.4f}, {tmp[2]:.4f}, {tmp[3]:.4f}')
         self.training_step_outputs.clear()
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
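Note: because _dequeue_and_enqueue now routes the features through concat_all_gather, every rank enqueues the full global batch into the momentum queues instead of only its local shard. A minimal sketch of the gather behaviour, assuming a 2-process CPU "gloo" group (the port, shapes, and values are illustrative, not from the repository):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


@torch.no_grad()
def concat_all_gather(tensor):
    # Same logic as the helper added in this commit: collect the tensor from every
    # rank and concatenate along the batch dimension (no gradients flow through it).
    tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)


def worker(rank, world_size):
    # "gloo" on CPU keeps the demo runnable without GPUs; the port is arbitrary.
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                            rank=rank, world_size=world_size)
    local_feat = torch.full((2, 4), float(rank))   # each rank contributes a batch of 2
    global_feat = concat_all_gather(local_feat)    # shape (4, 4): both ranks' batches stacked
    print(f'rank {rank}: gathered {tuple(global_feat.shape)}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)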

SPMM_pretrain.py

5 additions, 4 deletions

@@ -12,7 +12,7 @@
 def main(args, config):
     # data
     print("Creating dataset")
-    dataset = SMILESDataset_pretrain(args.data_path)
+    dataset = SMILESDataset_pretrain(args.data_path, data_length=[0, 10000])
     print('#data:', len(dataset))
     data_loader = DataLoader(dataset, batch_size=config['batch_size'], num_workers=8, shuffle=True, pin_memory=True, drop_last=True)
     tokenizer = BertTokenizer(vocab_file=args.vocab_filename, do_lower_case=False, do_basic_tokenize=False)
@@ -27,17 +27,18 @@ def main(args, config):
     # training
     checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=args.output_dir, filename='checkpoint_{epoch}',
                                                        save_top_k=config['schedular']['epochs'], monitor='loss_mlm')
-    trainer = pl.Trainer(accelerator='gpu', devices=[0], precision=16, max_epochs=config['schedular']['epochs'],
+    trainer = pl.Trainer(accelerator='gpu', devices=[0, 1], precision=16, max_epochs=config['schedular']['epochs'],
                         callbacks=[checkpoint_callback], strategy=DDPStrategy(find_unused_parameters=True), limit_val_batches=0.)
     trainer.fit(model, data_loader, None, ckpt_path=args.checkpoint if args.checkpoint else None)


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--checkpoint', default='')
-    parser.add_argument('--data_path', default='./data/1_Pretrain/pretrain_20m.txt')
+    # parser.add_argument('--data_path', default='./data/1_Pretrain/pretrain_20m.txt')
+    parser.add_argument('--data_path', default='../VLP_chem/data/pubchem-100m-simple-shuffle.txt')
     parser.add_argument('--resume', default=False, type=bool)
-    parser.add_argument('--output_dir', default='./checkpoints')
+    parser.add_argument('--output_dir', default='./Pretrain')
     parser.add_argument('--vocab_filename', default='./vocab_bpe_300.txt')
     parser.add_argument('--seed', default=42, type=int)
     args = parser.parse_args()
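Note: with data_length=[0, 10000] and two DDP ranks, one epoch covers far fewer optimizer steps than before. A back-of-the-envelope sketch, assuming a hypothetical per-GPU batch size of 64 (the real value lives in the config file, not this diff):

batch_size = 64            # assumed per-GPU batch size from the config file
world_size = 2             # devices=[0, 1]
n_samples = 10_000         # data_length=[0, 10000]
steps_per_epoch = n_samples // (batch_size * world_size)   # drop_last=True
print(steps_per_epoch)     # 78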

d_pv2smiles_stochastic.py

5 additions, 3 deletions

@@ -14,7 +14,9 @@


 @torch.no_grad()
-def generate_with_property(model, property, tokenizer, device, n_sample, prop_mask):
+def generate_with_property(model, property, n_sample, prop_mask, stochastic=True):
+    device = model.device
+    tokenizer = model.tokenizer
     # test
     model.eval()
     print("PV-to-SMILES generation in stochastic manner...")
@@ -40,7 +42,7 @@ def generate_with_property(model, property, tokenizer, device, n_sample, prop_ma
         text_input = torch.tensor([tokenizer.cls_token_id]).expand(prop.size(0), 1).to(device)
         end_count = torch.zeros_like(text_input).to(bool)
         for _ in range(100):
-            output = generate(model, prop_embeds, text_input, stochastic=True)
+            output = generate(model, prop_embeds, text_input, stochastic=stochastic)
            end_count = torch.logical_or(end_count, (output == tokenizer.sep_token_id))
            if end_count.all():
                break
@@ -160,7 +162,7 @@ def main(args, config):
     # prop_input = torch.zeros(53)

     print("=" * 50)
-    samples = generate_with_property(model, prop_input, tokenizer, device, args.n_generate, prop_mask)
+    samples = generate_with_property(model, prop_input, args.n_generate, prop_mask)
     metric_eval(prop_input, samples, prop_mask)
     print("=" * 50)
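Note: the refactored function pulls device and tokenizer from the LightningModule itself, and the new stochastic flag allows switching to greedy decoding. A hedged usage sketch (model is assumed to be a loaded SPMM checkpoint; the values are placeholders, only the 53-dimensional property vector is taken from the repository):

import torch
prop_input = torch.zeros(53)    # target 53-dimensional property vector (placeholder values)
prop_mask = torch.zeros(53)     # which properties to condition on, built the same way as in main()
samples = generate_with_property(model, prop_input, n_sample=100, prop_mask=prop_mask,
                                 stochastic=True)   # stochastic=False for greedy decoding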

d_smiles2pv.py

25 additions, 2 deletions

@@ -28,10 +28,33 @@ def generate(model, prop_input, text_embeds, text_atts):


 @torch.no_grad()
-def evaluate(model, data_loader, tokenizer, device):
+def pv_generate(model, data_loader):
     # test
+    with open('./normalize.pkl', 'rb') as w:
+        mean, std = pickle.load(w)
+    device = model.device
+    tokenizer = model.tokenizer
     model.eval()
     print("SMILES-to-PV generation...")
+    # convert list of string to dataloader
+    if isinstance(data_loader, list):
+        gather = []
+        text_input = tokenizer(data_loader, padding='longest', truncation=True, max_length=100, return_tensors="pt").to(device)
+        text_embeds = model.text_encoder.bert(text_input.input_ids[:, 1:], attention_mask=text_input.attention_mask[:, 1:],
+                                              return_dict=True, mode='text').last_hidden_state
+        prop_input = model.property_cls.expand(len(data_loader), -1, -1)
+        prediction = []
+        for _ in range(53):
+            output = generate(model, prop_input, text_embeds, text_input.attention_mask[:, 1:])
+            prediction.append(output)
+            output = model.property_embed(output.unsqueeze(2))
+            prop_input = torch.cat([prop_input, output], dim=1)
+
+        prediction = torch.stack(prediction, dim=-1)
+        for i in range(len(data_loader)):
+            gather.append(prediction[i].cpu()*std + mean)
+        return gather
+
     reference, candidate = [], []
     for (prop, text) in data_loader:
         text_input = tokenizer(text, padding='longest', truncation=True, max_length=100, return_tensors="pt").to(device)
@@ -139,7 +162,7 @@ def main(args, config):
     model = model.to(device)

     print("=" * 50)
-    r_test, c_test = evaluate(model, test_loader, tokenizer, device)
+    r_test, c_test = pv_generate(model, test_loader)
     metric_eval(r_test, c_test)
     print("=" * 50)
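Note: pv_generate now accepts either the original DataLoader or a plain list of SMILES strings; for a list it returns one de-normalized property prediction per input, using the mean/std stored in ./normalize.pkl. A hedged usage sketch (model is assumed to be a loaded SPMM checkpoint; the SMILES strings are illustrative):

smiles = ['CCO', 'c1ccccc1', 'CC(=O)Oc1ccccc1C(=O)O']
predictions = pv_generate(model, smiles)        # list with one tensor per input SMILES
for s, p in zip(smiles, predictions):
    print(s, p.squeeze().tolist()[:5], '...')   # first few of the 53 predicted properties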

dataset.py

12 additions, 4 deletions

@@ -12,15 +12,23 @@

 class SMILESDataset_pretrain(Dataset):
     def __init__(self, data_path, data_length=None, shuffle=False):
-        with open(data_path, 'r') as f:
-            lines = f.readlines()
+        if data_length is not None:
+            with open(data_path, 'r') as f:
+                for _ in range(data_length[0]):
+                    f.readline()
+                lines = []
+                for _ in range(data_length[1] - data_length[0]):
+                    lines.append(f.readline())
+        else:
+            with open(data_path, 'r') as f:
+                lines = f.readlines()
         self.data = [l.strip() for l in lines]
         with open('./normalize.pkl', 'rb') as w:
             norm = pickle.load(w)
         self.property_mean, self.property_std = norm

-        if shuffle: random.shuffle(self.data)
-        if data_length is not None: self.data = self.data[data_length[0]:data_length[1]]
+        if shuffle:
+            random.shuffle(self.data)

     def __len__(self):
         return len(self.data)
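Note: data_length now selects a line range while reading, so only that slice is ever held in memory, whereas the old code read the whole file before slicing. A hedged sketch of the reworked constructor (the path is illustrative and stands in for any one-SMILES-per-line corpus):

train_set = SMILESDataset_pretrain('pretrain.txt', data_length=[0, 10000])      # lines 0-9999
val_set   = SMILESDataset_pretrain('pretrain.txt', data_length=[10000, 11000])  # next 1000 lines
full_set  = SMILESDataset_pretrain('pretrain.txt')                              # whole file, as before
print(len(train_set), len(val_set), len(full_set))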
