 import torch.distributed
 import pytorch_lightning as pl
 from scheduler import create_scheduler
+import random


 class AttrDict(dict):
@@ -16,6 +17,7 @@ def __init__(self, *args, **kwargs):
 class SPMM(pl.LightningModule):
     def __init__(self, tokenizer=None, config=None, loader_len=0, no_train=False):
         super().__init__()
+        self.save_hyperparameters()
         self.automatic_optimization = False
         self.config = config
         self.tokenizer = tokenizer
@@ -82,13 +84,16 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=
         property_feature = self.property_embed(property_original.unsqueeze(2))

         unk_tokens = self.property_mask.expand(property_original.size(0), property_original.size(1), -1)
-        mpm_mask = torch.bernoulli(torch.ones_like(property_original) * 0.5)
+        if random.random() < 0.05:
+            mpm_mask = torch.ones_like(property_original)    # all mask
+        else:
+            mpm_mask = torch.bernoulli(torch.ones_like(property_original) * 0.5)    # 1 for mask, 0 for keep
         mpm_mask_expand = mpm_mask.unsqueeze(2).repeat(1, 1, unk_tokens.size(2))
         property_masked = property_feature * (1 - mpm_mask_expand) + unk_tokens * mpm_mask_expand
-        property = torch.cat([self.property_cls.expand(property_original.size(0), -1, -1), property_masked], dim=1)
+        properties = torch.cat([self.property_cls.expand(property_original.size(0), -1, -1), property_masked], dim=1)

-        prop_embeds = self.property_encoder(inputs_embeds=property, return_dict=True).last_hidden_state
-        prop_atts = torch.ones(prop_embeds.size()[:-1], dtype=torch.long).to(property.device)
+        prop_embeds = self.property_encoder(inputs_embeds=properties, return_dict=True).last_hidden_state
+        prop_atts = torch.ones(prop_embeds.size()[:-1], dtype=torch.long).to(properties.device)
         prop_feat = F.normalize(self.property_proj(prop_embeds[:, 0, :]), dim=-1)

         text_embeds = self.text_encoder.bert(text_input_ids, attention_mask=text_attention_mask, return_dict=True, mode='text').last_hidden_state
@@ -97,7 +102,7 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=

         with torch.no_grad():
             self._momentum_update()
-            prop_embeds_m = self.property_encoder_m(inputs_embeds=property, return_dict=True).last_hidden_state
+            prop_embeds_m = self.property_encoder_m(inputs_embeds=properties, return_dict=True).last_hidden_state
             prop_feat_m = F.normalize(self.property_proj_m(prop_embeds_m[:, 0, :]), dim=-1)
             prop_feat_all = torch.cat([prop_feat_m.t(), self.prop_queue.clone().detach()], dim=1)

@@ -110,7 +115,7 @@ def forward(self, property_original, text_input_ids, text_attention_mask, alpha=
             sim_i2i_m = prop_feat_m @ prop_feat_all / self.temp
             sim_t2t_m = text_feat_m @ text_feat_all / self.temp

-            sim_targets = torch.zeros(sim_i2t_m.size()).to(property.device)
+            sim_targets = torch.zeros(sim_i2t_m.size()).to(properties.device)
             sim_targets.fill_diagonal_(1)

             sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
@@ -268,8 +273,8 @@ def _momentum_update(self):

     @torch.no_grad()
     def _dequeue_and_enqueue(self, img_feat, text_feat):
-        img_feats = img_feat
-        text_feats = text_feat
+        img_feats = concat_all_gather(img_feat)
+        text_feats = concat_all_gather(text_feat)

         batch_size = img_feats.shape[0]

@@ -354,3 +359,16 @@ def on_train_epoch_end(self):    # outputs: collection of returns from 'training
         if self.global_rank == 0:
             print(f'\n mean loss: {tmp[0]:.4f}, {tmp[1]:.4f}, {tmp[2]:.4f}, {tmp[3]:.4f}')
         self.training_step_outputs.clear()
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
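
For context on the masking change in forward(): with probability 0.05 the entire property vector is replaced by the mask token, otherwise each property is masked independently with probability 0.5. A minimal standalone sketch of that sampling step is shown below; the function name sample_mpm_mask and the example shapes are illustrative, not part of the commit.

import random
import torch

def sample_mpm_mask(property_original: torch.Tensor, p_all: float = 0.05, p_each: float = 0.5) -> torch.Tensor:
    # Float mask with the same shape as the input: 1 = replace with the mask token, 0 = keep the value.
    if random.random() < p_all:
        return torch.ones_like(property_original)                        # mask every property in the batch
    return torch.bernoulli(torch.ones_like(property_original) * p_each)  # mask each property independently

# Example: a batch of 2 molecules with 4 properties each.
props = torch.randn(2, 4)
print(sample_mpm_mask(props))   # e.g. tensor([[1., 0., 1., 1.], [0., 0., 1., 0.]])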
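
A caveat on the new _dequeue_and_enqueue path: concat_all_gather calls torch.distributed.get_world_size(), which raises if the default process group has not been initialized (for example, a single-GPU run launched without DDP). A guarded variant is sketched below under that assumption; it is not part of the commit, and the name concat_all_gather_safe is illustrative.

import torch
import torch.distributed

@torch.no_grad()
def concat_all_gather_safe(tensor: torch.Tensor) -> torch.Tensor:
    # Gather the tensor from all ranks; torch.distributed.all_gather does not propagate gradients.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return tensor  # single-process run: nothing to gather
    tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)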