Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
a0405u committed Jan 26, 2021
1 parent b3db136 commit e68e4e9
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions lab3/lab3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,40 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load regular model

# Create regular pytorch model
if not trt:

print("loading model...")
print("loading model...")

timest = time.time()

# model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda()
model = alexnet(pretrained = True).eval().cuda()
timest = time.time()

print("model loaded in {}s".format(round(time.time() - timest, 3)))
# model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda()
model = alexnet(pretrained = True).eval().cuda()

print("model loaded in {}s".format(round(time.time() - timest, 3)))

# Load TRT

if trt:
else:

print("loding trt model...")
timesttrt = time.time()

try:
model = torch.load(MODEL_TRT_PATH)
try: # Load from file
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(MODEL_TRT_PATH))

except FileNotFoundError:
except FileNotFoundError: # Convert from regular

print("converting torch to trt...")

x = torch.ones((1, 3, 224, 224)).cuda()

timest = time.time()
model_trt = torch2trt(model, [x])
model_trt = torch2trt(alexnet(pretrained = True).eval().cuda(), [x])

torch.save(model_trt, MODEL_TRT_PATH)
torch.save(model_trt.state_dict(), MODEL_TRT_PATH)

print("converted in {}s".format(round(time.time() - timest, 3)))

Expand Down

0 comments on commit e68e4e9

Please sign in to comment.