Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
radekd91 committed Jan 25, 2024
1 parent 43692d0 commit efa04d0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion inferno/models/mica/MicaInputProcessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _dirty_image_preprocessing(self, input_image):
kps = kpss[bb_i]

face = Face(bbox=bbox, kps=kps, det_score=det_score)
blob, aimg = get_arcface_input(face, img)
blob, aimg = get_arcface_input(face, img, image_is_bgr=False)
aligned_image_list.append(aimg)
aligned_images = np.array(aligned_image_list)
# b,h,w,c to b,c,h,w
Expand Down
16 changes: 16 additions & 0 deletions inferno/models/temporal/Preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def forward(self, batch, input_key, *args, output_prefix="gt_", test_time=False,
for key in batch['landmarks'].keys():
batch_['landmarks'][key] = batch['landmarks'][key].view(B*T, -1, 2)

if 'mica_video' in batch:
batch_['mica_video'] = batch['mica_video'].view(B*T, *batch['mica_video'].shape[-3:])

values = self.model(batch_, training=False, validation=False)
else:
outputs = []
Expand All @@ -270,8 +273,21 @@ def forward(self, batch, input_key, *args, output_prefix="gt_", test_time=False,
batch_['landmarks'] = {}
for key in batch['landmarks'].keys():
batch_['landmarks'][key] = batch['landmarks'][key].view(B*T, -1, 2)[i:i+self.max_b]
if 'mica_video' in batch:
batch_['mica_video'] = batch['mica_video'].view(B*T, *batch['mica_video'].shape[-3:])[i:i+self.max_b]
out = self.model(batch_, training=False, validation=False)
outputs.append(out)

if 'image' in out:
del out['image']
if 'mica_images' in out:
del out['mica_images']
if 'predicted_image' in out:
del out['predicted_image']
if 'predicted_mask' in out:
del out['predicted_mask']
if 'albedo' in out:
del out['albedo']

# combine into a single output
values = cat_tensor_or_dict(outputs, dim=0)
Expand Down

0 comments on commit efa04d0

Please sign in to comment.