-
Notifications
You must be signed in to change notification settings - Fork 89
/
inference.py
129 lines (90 loc) · 4.1 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os, sys
sys.path.append("../")
import tensorflow as tf
import time
import cv2
import numpy as np
import argparse
from data.io.image_preprocess import short_side_resize_for_inference_data
from libs.configs import cfgs
from libs.networks import build_whole_network
from help_utils.tools import *
from libs.box_utils import draw_box_in_img
from help_utils import tools
def inference(det_net, data_dir):
    """Run the detector on every readable image in ``data_dir`` and save the results.

    Builds the inference graph once, restores the checkpoint supplied by
    ``det_net.get_restorer()`` (when there is one), then feeds each image
    through the network and writes a copy with the detected boxes drawn on it
    to ``<cfgs.INFERENCE_SAVE_PATH>/<cfgs.VERSION>/<image name>_r.jpg``.

    Args:
        det_net: detection network object; must provide
            ``build_whole_detection_network`` and ``get_restorer``.
        data_dir: directory whose entries are read with ``cv2.imread``.
            Entries OpenCV cannot decode are skipped with a message.
    """
    # 1. preprocess img
    img_plac = tf.placeholder(dtype=tf.uint8, shape=[None, None, 3])
    img_batch = tf.cast(img_plac, tf.float32)
    img_batch = img_batch - tf.constant(cfgs.PIXEL_MEAN)
    img_batch = short_side_resize_for_inference_data(img_tensor=img_batch,
                                                     target_shortside_len=cfgs.IMG_SHORT_SIDE_LEN)

    # 2. build the detection graph (no ground-truth boxes at inference time)
    det_boxes_r, det_scores_r, det_category_r = det_net.build_whole_detection_network(
        input_img_batch=img_batch,
        gtboxes_batch=None)

    init_op = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    )

    restorer, restore_ckpt = det_net.get_restorer()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The output directory does not depend on the image, so compute it once.
    save_dir = os.path.join(cfgs.INFERENCE_SAVE_PATH, cfgs.VERSION)

    with tf.Session(config=config) as sess:
        sess.run(init_op)
        if restorer is not None:
            restorer.restore(sess, restore_ckpt)
            print('restore model')

        tools.mkdir(save_dir)

        imgs = os.listdir(data_dir)
        for i, a_img_name in enumerate(imgs):
            raw_img = cv2.imread(os.path.join(data_dir, a_img_name))
            if raw_img is None:
                # cv2.imread returns None for anything it cannot decode
                # (sub-directories, non-image files, corrupt images); feeding
                # None to the placeholder would abort the whole run.
                print('skip {}: not a readable image'.format(a_img_name))
                continue

            # Only the forward pass is timed, not image loading or drawing.
            start = time.time()
            resized_img, det_boxes_r_, det_scores_r_, det_category_r_ = \
                sess.run(
                    [img_batch, det_boxes_r, det_scores_r, det_category_r],
                    feed_dict={img_plac: raw_img}
                )
            end = time.time()

            # Optional ICDAR-style text dump, kept for reference. It needs
            # `coordinate_convert`, which this file does not import.
            # with open('./res_icdar_r/res_{}.txt'.format(a_img_name.split('.jpg')[0]), 'w') as f:
            #     res_r = coordinate_convert.forward_convert(det_boxes_r_, False)
            #     res_r = np.array(res_r, np.int32)
            #     for r in res_r:
            #         f.write('{},{},{},{},{},{},{},{}\n'.format(r[0], r[1], r[2], r[3],
            #                                                    r[4], r[5], r[6], r[7]))

            # Drop the leading axis of the fetched image tensor before drawing.
            det_detections_r = draw_box_in_img.draw_rotate_box_cv(np.squeeze(resized_img, 0),
                                                                  boxes=det_boxes_r_,
                                                                  labels=det_category_r_,
                                                                  scores=det_scores_r_)
            cv2.imwrite(save_dir + '/' + a_img_name + '_r.jpg',
                        det_detections_r)

            view_bar('{} cost {}s'.format(a_img_name, (end - start)), i + 1, len(imgs))
def parse_args():
    """Parse the command-line arguments of the inference script.

    Recognised options:
        --data_dir  directory containing the images to run on.
        --gpu       value exported as CUDA_VISIBLE_DEVICES by the caller.

    When the script is started without any argument the help text is printed
    and the process exits with status 1.

    Returns:
        argparse.Namespace with the attributes ``data_dir`` and ``gpu``
        (both strings).
    """
    # The description used to say "Train ...", which was misleading in --help
    # output: this script only runs inference.
    parser = argparse.ArgumentParser(
        description='Run detection inference on a directory of images')
    parser.add_argument('--data_dir', dest='data_dir',
                        help='data path',
                        default='/mnt/USBC/gx/Detection/icdar2015/ch4_test_images/', type=str)
    parser.add_argument('--gpu', dest='gpu',
                        help='gpu index',
                        default='0', type=str)

    # Require at least one explicit argument so the script is never run by
    # accident against the hard-coded default data path.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
def main():
    """Script entry point: parse the CLI, select the GPU, build the net, run inference."""
    cli_args = parse_args()
    print('Called with args:')
    print(cli_args)

    # Restrict the devices visible to CUDA to the ones requested on the command line.
    os.environ["CUDA_VISIBLE_DEVICES"] = cli_args.gpu

    network = build_whole_network.DetectionNetwork(
        base_network_name=cfgs.NET_NAME,
        is_training=False,
    )
    inference(network, data_dir=cli_args.data_dir)


if __name__ == '__main__':
    main()