Skip to content

Commit 528e2b7

Browse files
author
Theodore Zhao
committed
Bug fix image resizing in inference
1 parent c970f9c commit 528e2b7

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

inference_utils/inference.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import random
2020

2121
t = []
22-
t.append(transforms.Resize(1024, interpolation=Image.BICUBIC))
22+
t.append(transforms.Resize((1024, 1024), interpolation=Image.BICUBIC))
2323
transform = transforms.Compose(t)
2424
#metadata = MetadataCatalog.get('coco_2017_train_panoptic')
2525
all_classes = ['background'] + [name.replace('-other','').replace('-merged','')
@@ -37,12 +37,11 @@
3737
@torch.no_grad()
3838
def interactive_infer_image(model, image, prompts):
3939

40-
image_ori = transform(image)
41-
#mask_ori = image['mask']
42-
width = image_ori.size[0]
43-
height = image_ori.size[1]
44-
image_ori = np.asarray(image_ori)
45-
image = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
40+
image_resize = transform(image)
41+
width = image.size[0]
42+
height = image.size[1]
43+
image_resize = np.asarray(image_resize)
44+
image = torch.from_numpy(image_resize.copy()).permute(2,0,1).cuda()
4645

4746
data = {"image": image, 'text': prompts, "height": height, "width": width}
4847

@@ -72,7 +71,8 @@ def interactive_infer_image(model, image, prompts):
7271
pred_class = results['pred_logits'][0][matched_id].max(dim=-1)[1]
7372

7473
# interpolate mask to ori size
75-
pred_mask_prob = F.interpolate(pred_masks_pos[None,], image_size[-2:], mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
74+
pred_mask_prob = F.interpolate(pred_masks_pos[None,], (data['height'], data['width']),
75+
mode='bilinear')[0,:,:data['height'],:data['width']].sigmoid().cpu().numpy()
7676
pred_masks_pos = (1*(pred_mask_prob > 0.5)).astype(np.uint8)
7777

7878
return pred_mask_prob

0 commit comments

Comments
 (0)