-
Notifications
You must be signed in to change notification settings - Fork 68
/
matching.py
122 lines (91 loc) · 4.84 KB
/
matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import os, pickle
from PIL import Image, ImageDraw, ImageFont
from viz_hand_obj import *
def calculate_center(bb):
    """Return the midpoint ``[cx, cy]`` of a bounding box.

    Only the first four entries of *bb* are read and are treated as
    ``(x1, y1, x2, y2)``: the x-centre averages entries 0 and 2, the
    y-centre averages entries 1 and 3. Extra trailing entries (e.g. a
    score column) are ignored.
    """
    x_first, x_second = bb[0], bb[2]
    y_first, y_second = bb[1], bb[3]
    center_x = (x_first + x_second) / 2
    center_y = (y_first + y_second) / 2
    return [center_x, center_y]
def filter_object(obj_dets, hand_dets, offset_scale=10000):
    """Match each detected hand to the nearest object along its offset vector.

    For every hand in contact, a target point is projected from the centre
    of the hand box along the predicted offset, and the object whose box
    centre is nearest to that point (squared Euclidean distance) is chosen.

    Args:
        obj_dets: 2-D array of object detections, one row per object;
            columns 0-3 are read as the box (x1, y1, x2, y2).
        hand_dets: 2-D array of hand detections, one row per hand. Columns
            0-3 are read as the box, column 5 as the contact state
            (``<= 0`` means not in contact), column 6 as the offset
            magnitude and columns 7 / 8 as the offset direction (dx, dy).
        offset_scale: multiplier applied to ``magnitude * direction`` when
            projecting the target point. Defaults to 10000, the value that
            was previously hard-coded.

    Returns:
        A list with one entry per row of ``hand_dets``: the row index into
        ``obj_dets`` of the matched object, or -1 if the hand is not in
        contact or ``obj_dets`` has no rows to match against.
    """
    obj_dets = np.asarray(obj_dets)
    hand_dets = np.asarray(hand_dets)

    # Centres of all object boxes at once, shape (num_objects, 2).
    obj_boxes = obj_dets[:, :4]
    object_centers = (obj_boxes[:, 0:2] + obj_boxes[:, 2:4]) / 2
    # With zero objects, the distance/argmin below would fail, so such hands get -1.
    has_objects = object_centers.shape[0] > 0

    img_obj_id = []  # matching list, aligned with the rows of hand_dets
    for hand in hand_dets:
        if hand[5] <= 0 or not has_objects:  # non-contact hand, or nothing to match
            img_obj_id.append(-1)
            continue
        hand_center = (hand[0:2] + hand[2:4]) / 2
        # Extend from the hand centre along the direction (cols 7, 8)
        # scaled by the predicted magnitude (col 6).
        target = hand_center + hand[6] * offset_scale * hand[7:9]
        dist = np.sum((object_centers - target) ** 2, axis=1)
        img_obj_id.append(int(np.argmin(dist)))  # nearest object
    return img_obj_id
if __name__ == '__main__':
    ##############################################################
    # Expects the detection results saved in a pickle file as a dictionary:
    # {"input_image_path_1": {
    #      "hand_dets": hand_dets,
    #      "obj_dets": obj_dets,
    #  },
    #  "input_image_path_2": {
    #      "hand_dets": hand_dets,
    #      "obj_dets": obj_dets,
    #  },
    #  ...
    # }
    # filter_object() turns each (obj_dets, hand_dets) pair into img_obj_id,
    # which maps every hand row to the index of its matched object (-1 = none).
    ##############################################################
    pickle_path = 'path/to/saved/pickle/file'
    save_dir = 'path/to/save'
    os.makedirs(save_dir, exist_ok=True)
    thresh_hand = 0.5
    thresh_obj = 0.5
    max_dets = 10  # at most this many hands / objects are visualised per image
    viz = True

    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by a trusted detection run.
    with open(pickle_path, 'rb') as f:
        pickle_info = pickle.load(f)

    # The font does not depend on the image, so load it once up front.
    font = ImageFont.truetype('times_b.ttf', size=30) if viz else None

    for image_path, image_info in pickle_info.items():
        hand_dets = image_info['hand_dets']
        obj_dets = image_info['obj_dets']
        image_name = os.path.split(image_path)[-1]
        if viz:
            image = Image.open(image_path).convert("RGBA")
            draw = ImageDraw.Draw(image)
            width, height = image.size

        img_obj_id = None  # stays None when there is nothing to match
        if (obj_dets is not None) and (hand_dets is not None):
            img_obj_id = filter_object(obj_dets, hand_dets)
            # objects: draw if above threshold and matched with some hand
            for i in range(min(max_dets, obj_dets.shape[0])):
                obj_bbox = [int(np.round(x)) for x in obj_dets[i, :4]]
                obj_score = obj_dets[i, 4]
                if viz and obj_score > thresh_obj and i in img_obj_id:
                    image = draw_obj_mask(image, draw, i, obj_bbox, obj_score, width, height, font)

        if hand_dets is not None:
            for i in range(min(max_dets, hand_dets.shape[0])):
                hand_bbox = [int(np.round(x)) for x in hand_dets[i, :4]]
                hand_score = hand_dets[i, 4]
                hand_state = hand_dets[i, 5]
                hand_lr = hand_dets[i, -1]
                if not (viz and hand_score > thresh_hand):
                    continue
                image = draw_hand_mask(image, draw, i, hand_bbox, hand_score, hand_lr, hand_state, width, height, font)
                # Link an in-contact hand to its matched object. The lookup happens
                # only here: a -1 id (no match) would otherwise index the last
                # object, or raise IndexError when obj_dets is empty.
                if img_obj_id is not None and hand_state > 0 and img_obj_id[i] >= 0:
                    matched_obj = obj_dets[img_obj_id[i], :4]
                    obj_cc, hand_cc = calculate_center(matched_obj), calculate_center(hand_bbox)
                    side_idx = int(hand_lr)
                    draw_line_point(draw, side_idx, (int(hand_cc[0]), int(hand_cc[1])), (int(obj_cc[0]), int(obj_cc[1])))

        if viz:
            # splitext handles extensions of any length (e.g. ".jpeg"), unlike [:-4].
            stem = os.path.splitext(image_name)[0]
            save_path = os.path.join(save_dir, stem + '_det.png')
            image.save(save_path)