# transform_to_onnx.py
# Forked from clovaai/deep-text-recognition-benchmark.

import os
import string
import argparse

import torch
import torch.backends.cudnn as cudnn

from utils import CTCLabelConverter, AttnLabelConverter
from model import Model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Post_Model(torch.nn.Module):
    """Wrap the base model and permute its output so that the exported ONNX
    graph matches the layout expected by the OpenCV text-recognition sample."""

    def __init__(self, model):
        super(Post_Model, self).__init__()
        self.model = model

    def forward(self, input, is_train=True):
        output = self.model(input, is_train)
        # swap the first two axes of the prediction for the OpenCV sample
        output = output.permute(1, 0, 2)
        return output
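

# A quick shape sanity check for the wrapper above (a hedged sketch, not part
# of the original script; `_DummyBackbone` is hypothetical and stands in for
# the repository's Model). It only demonstrates that the wrapper swaps the
# first two axes of whatever prediction tensor the backbone returns.
def _check_post_model_shape():
    class _DummyBackbone(torch.nn.Module):
        def forward(self, input, is_train=True):
            # stand-in prediction of shape (batch, seq_len, num_class)
            return torch.zeros(input.size(0), 26, 37)

    wrapped = Post_Model(_DummyBackbone())
    out = wrapped(torch.zeros(2, 1, 32, 100))
    assert out.shape == (26, 2, 37)  # axes 0 and 1 swapped by permute(1, 0, 2)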


def transform_to_onnx(opt):
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt).to(device)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # wrap the model so its output layout matches what the OpenCV sample expects
    post_model = Post_Model(model)
    model = post_model

    # transform to onnx
    model.eval()
    batch_size = 1
    channel = 3 if opt.rgb else 1
    h_size = opt.imgH
    w_size = opt.imgW
    # dummy input matching the configured input shape
    dummy_input = torch.randn(batch_size, channel, h_size, w_size, requires_grad=True).to(device)

    output_path = "./onnx_output"
    if not os.path.exists(output_path):
        os.mkdir(output_path)
    onnx_output = os.path.join(output_path, opt.onnx_name)
    print("output path :", onnx_output)

    dummy_output = model(dummy_input)
    print("output size :", dummy_output.size())
    torch.onnx.export(model, dummy_input, onnx_output, verbose=True, export_params=True,
                      opset_version=11, do_constant_folding=True,
                      input_names=["input"], output_names=["output"])
    print("Transformed successfully!")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--saved_model', required=True, help='path to the saved model to load')
    parser.add_argument('--onnx_name', required=True, help='name of the output onnx model')
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=25, help='maximum label length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, default="None",
                        help='Transformation stage. None (OpenCV does not support the Transformation module)')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet|DenseNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channels of the feature extractor')
    parser.add_argument('--output_channel', type=int, default=512,
                        help='the number of output channels of the feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
    opt = parser.parse_args()

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same as the ASTER setting (use 94 chars)

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    transform_to_onnx(opt)
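
# Example invocation (flags defined in the parser above; paths and file names
# are placeholders):
#
#   python transform_to_onnx.py \
#       --image_folder demo_image/ \
#       --saved_model saved_models/None-ResNet-BiLSTM-CTC.pth \
#       --onnx_name crnn.onnx \
#       --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC
#
# The exported file lands in ./onnx_output/ and can then be loaded with
# OpenCV's dnn module, e.g. net = cv2.dnn.readNetFromONNX('./onnx_output/crnn.onnx').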