diff --git a/lab3/lab3.py b/lab3/lab3.py index f15b058..d37cb98 100644 --- a/lab3/lab3.py +++ b/lab3/lab3.py @@ -25,39 +25,40 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# Load regular model -# Create regular pytorch model +if not trt: -print("loading model...") + print("loading model...") -timest = time.time() - -# model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda() -model = alexnet(pretrained = True).eval().cuda() + timest = time.time() -print("model loaded in {}s".format(round(time.time() - timest, 3))) + # model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda() + model = alexnet(pretrained = True).eval().cuda() + print("model loaded in {}s".format(round(time.time() - timest, 3))) # Load TRT -if trt: +else: print("loding trt model...") timesttrt = time.time() - try: - model = torch.load(MODEL_TRT_PATH) + try: # Load from file + model_trt = TRTModule() + model_trt.load_state_dict(torch.load(MODEL_TRT_PATH)) - except FileNotFoundError: + except FileNotFoundError: # Convert from regular print("converting torch to trt...") x = torch.ones((1, 3, 224, 224)).cuda() timest = time.time() - model_trt = torch2trt(model, [x]) + model_trt = torch2trt(alexnet(pretrained = True).eval().cuda(), [x]) - torch.save(model_trt, MODEL_TRT_PATH) + torch.save(model_trt.state_dict(), MODEL_TRT_PATH) print("converted in {}s".format(round(time.time() - timest, 3)))