Skip to content

Commit

Permalink
write torchscripted model
Browse files Browse the repository at this point in the history
  • Loading branch information
tyui592 committed Oct 9, 2020
1 parent 9f7a1c4 commit 0283d16
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,47 +99,24 @@ def nms_pytorch(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) ->
if __name__ == '__main__':
from train import load_network
from config import build_parser
import cv2

do_test = False # for debug
device = torch.device('cpu')
args = build_parser()
args.model_load = './WEIGHTS/check_point_30.pth'
args.nms_th = 0.1

normalizer = get_normalizer(pretrained=args.pretrained)
args = build_parser()

# load network
network, _, _, _ = load_network(args, device)
# perdict
#pre = Preprocess()
predictor = Export(network = network,
topk = 10,
topk = args.topk,
scale_factor = args.scale_factor,
conf_th = args.conf_th,
nms_th = args.nms_th,
normalized_coord = args.normalized_coord).to(device)
predictor.eval()

#######################################################################
x = cv2.cvtColor(cv2.imread('../0.jpg'), cv2.COLOR_BGR2RGB)
x = cv2.resize(x, dsize=(512, 512), interpolation=cv2.INTER_AREA)
x = torch.tensor(x) # x: HxWxC, 0.0 ~ 255.0
x = x.permute(2, 0, 1)/255.0
x = normalizer(x).unsqueeze(0)

box_lst, cls_lst, score_lst = predictor(x.to(device))
for i in range(box_lst.shape[0]):
print(', '.join(map(str, box_lst[i].tolist())), ',', cls_lst[i].item(), ',', score_lst[i].item())

############ check the output of python and traced models ################
x = torch.ones(1, 3, 512, 512)
box_lst, cls_lst, score_lst = predictor(x.to(device))

traced_model = torch.jit.trace(predictor, torch.randn(1, 3, 512, 512))
x = torch.ones(1, 3, 512, 512)
box_lst2, cls_lst2, score_lst2 = traced_model(x)
print('output python == output jit: ', torch.all(torch.eq(box_lst, box_lst2)))

##################### model save at cpu #####################
x = torch.randn(1, 3, 512, 512)
traced_model_cpu = torch.jit.trace(predictor.cpu(), x.cpu())
Expand All @@ -151,3 +128,25 @@ def nms_pytorch(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) ->
traced_model_cpu = torch.jit.trace(predictor.cuda(), x.cuda())
torch.jit.save(traced_model_cpu, "jit_traced_model_gpu.pth")
print("Model saved at gpu")

if do_test:
normalizer = get_normalizer(pretrained=args.pretrained)
x = cv2.cvtColor(cv2.imread('../0.jpg'), cv2.COLOR_BGR2RGB)
x = cv2.resize(x, dsize=(512, 512), interpolation=cv2.INTER_AREA)
x = torch.tensor(x) # x: HxWxC, 0.0 ~ 255.0
x = x.permute(2, 0, 1)/255.0
x = normalizer(x).unsqueeze(0)

box_lst, cls_lst, score_lst = predictor(x.to(device))
for i in range(box_lst.shape[0]):
print(', '.join(map(str, box_lst[i].tolist())), ',', cls_lst[i].item(), ',', score_lst[i].item())

############ check the output of python and traced models ################
x = torch.ones(1, 3, 512, 512)
box_lst, cls_lst, score_lst = predictor(x.to(device))

traced_model = torch.jit.trace(predictor, torch.randn(1, 3, 512, 512))
x = torch.ones(1, 3, 512, 512)
box_lst2, cls_lst2, score_lst2 = traced_model(x)
print('output python == output jit: ', torch.all(torch.eq(box_lst, box_lst2)))

0 comments on commit 0283d16

Please sign in to comment.