Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
a0405u committed Dec 29, 2020
1 parent 4880dd8 commit 0bc2346
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions lab3/lab3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch2trt import torch2trt
from torch2trt import TRTModule
# from torchvision.models.alexnet import alexnet
from torchvision.models.alexnet import alexnet
from torch.autograd import Variable
from torchvision import transforms
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -30,8 +30,8 @@

timest = time.time()

model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda()
# model = alexnet(pretrained = True).eval().cuda()
# model = torch.hub.load('pytorch/vision:v0.8.0', 'wide_resnet101_2', pretrained=True).eval().cuda()
model = alexnet(pretrained = True).eval().cuda()

print("model loaded in {}s".format(round(time.time() - timest, 2)))

Expand Down Expand Up @@ -93,9 +93,11 @@ def predict(image):
#output = model_trt(input)
output = model(input)

probability = torch.nn.Softmax(output)

print("image processed in {}s".format(round(time.time() - timest, 3)))

return output.data.cpu().numpy().argmax()
return output.data.cpu().numpy().argmax(), torch.topk(probability, 1)


# Process image
Expand All @@ -105,13 +107,13 @@ def process(image):
fig = plt.figure(figsize=(10, 10))
sub = fig.add_subplot(1,1,1)

index = predict(image)
index, probability = predict(image)

print(classes[index])

sub.set_title(classes[index])
sub.set_title(str(classes[index]) + str(probability))
plt.axis('off')
plt.imshow(image)
plt.imshow(image.thumbnail((320, 240)))
plt.savefig('out/' + str(index) + '.png')
# plt.show()

Expand Down

0 comments on commit 0bc2346

Please sign in to comment.