Skip to content

Commit

Permalink
Update mmg.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhojsk515 authored Jan 7, 2023
1 parent 909e367 commit 4215fd0
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions mmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,9 @@ def generate_with_property(model, property, tokenizer, device,n_sample,prop_mask
mpm_mask_expand = prop_mask.unsqueeze(0).unsqueeze(2).repeat(property_unk.size(0), 1, property_unk.size(2)).to(device)
property_masked = property1 * (1 - mpm_mask_expand) + property_unk * mpm_mask_expand

#property_masked=property1
#mpm_mask_expand = prop_mask.unsqueeze(0).repeat(property1.size(0), 1)
#mpm_mask_expand = torch.cat([torch.ones((property1.size(0),1)),mpm_mask_expand],dim=1).to(device)


property = torch.cat([model.property_cls.expand(property_masked.size(0),-1,-1),property_masked],dim=1)
prop_embeds = model.property_encoder(inputs_embeds=property,return_dict=True).last_hidden_state #batch*len(=patch**2+1)*feature

#text_input = torch.tensor([2,4]).expand(prop.size(0),2).to(device) #batch*2
text_input = torch.tensor([2]).expand(prop.size(0), 1).to(device) # batch*1

for _ in range(100):
Expand Down

0 comments on commit 4215fd0

Please sign in to comment.