-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpredict.py
More file actions
139 lines (112 loc) · 4.88 KB
/
predict.py
File metadata and controls
139 lines (112 loc) · 4.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
import time
import PIL
from PIL import Image
import numpy as np
from utils import IOUtils, TypeUtils
def parseInput():
    """Parse the command-line arguments for a prediction run.

    Arguments are read from ``sys.argv`` via ``argparse``. ``--image`` and
    ``--checkpoint`` are mandatory: if either is missing or empty, the first
    missing name is reported through ``IOUtils.notify`` and the process exits
    with a non-zero status.

    Returns:
        dict: keyword arguments for ``Predictor`` with the keys ``'image'``,
        ``'checkpoint'``, ``'top_k'``, ``'category_names'`` and
        ``'shouldTryGPU'``. Values are exactly what argparse produced, so
        ``top_k`` is a string (or ``None`` when omitted) and ``shouldTryGPU``
        is a bool.
    """
    # Local import keeps this function self-contained; the module's import
    # block does not bring ``sys`` into scope.
    import sys

    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument('-i', '--image', action='store',
                        help="Path to image to predict")
    parser.add_argument('-c', '--checkpoint', action='store',
                        help="Model checkpoint to use for prediction")
    parser.add_argument('-tk', '--top_k', action='store',
                        help="Return number of most likely predictions")
    parser.add_argument('--category_names', action='store',
                        help="Dictionary to map category indices to names")
    parser.add_argument('-gpu', '--gpu', action='store_true',
                        help="Use GPU if it's available")
    args = parser.parse_args()

    required_arguments = ('image', 'checkpoint')
    for arg in required_arguments:
        if not getattr(args, arg):
            IOUtils.notify(f"You are missing a required argument: {arg}")
            # sys.exit instead of the interactive-only ``exit()`` helper, and a
            # non-zero status so shells and callers can detect the failure.
            sys.exit(1)

    return {
        'image': args.image,
        'checkpoint': args.checkpoint,
        'top_k': args.top_k,
        'category_names': args.category_names,
        'shouldTryGPU': args.gpu,
    }
class Predictor:
    """Load a saved checkpoint and report the top-k predictions for one image.

    Configuration is supplied as keyword arguments to ``__init__``;
    ``predictImage`` then runs the whole pipeline (defaults, model loading,
    prediction).
    """

    # Preprocessing constants used by ``process_image``.
    _RESIZE_TO = 256                      # target length of the shorter edge
    _CROP_HALF = 112                      # half of the 224-pixel centre crop
    _MEAN = [0.485, 0.456, 0.406]         # per-channel mean subtracted
    _STD = [0.229, 0.224, 0.225]          # per-channel std divided by
    _DEFAULT_TOP_K = 5

    def __init__(self, *args, **kwargs):
        """Store the run configuration and choose the torch device.

        Keyword Args:
            image: image to classify (handed to ``PIL.Image.open``).
            checkpoint: checkpoint file handed to ``torch.load``.
            top_k: number of predictions to report; a falsy value is replaced
                by 5 in ``preparePredictor``.
            category_names: optional argument for ``TypeUtils.cat_to_name``,
                used to translate predicted indexes into names.
            shouldTryGPU: when truthy, use CUDA if it is available; otherwise
                the CPU is used and a notice is emitted.
        """
        self.image = kwargs.get('image')
        self.checkpoint = kwargs.get('checkpoint')
        self.top_k = kwargs.get('top_k')
        self.category_names = kwargs.get('category_names')
        if kwargs.get('shouldTryGPU'):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            IOUtils.notify("ALERT: You are not predicting using GPU. Predictions may take a long time. To use GPU, call the function with the -gpu flag.")
            self.device = torch.device("cpu")

    def predictImage(self, *args, **kwargs):
        """Run the full pipeline: apply defaults, load the model, predict.

        Returns:
            The ``(probs_list, predictions_list)`` tuple produced by
            ``predict``.
        """
        self.preparePredictor()
        self.loadModel()
        return self.predict()

    def loadModel(self, *args, **kwargs):
        """Rebuild the model from ``self.checkpoint`` and move it to the device.

        The checkpoint must be a mapping with the keys ``'architecture'`` (a
        model object), ``'classifier'`` (assigned to ``model.classifier``) and
        ``'state_dict'``.

        Returns:
            The loaded model, which is also stored on ``self.model``.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects pickled module objects such as
        # 'architecture' - confirm against the torch version in use.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(self.checkpoint, map_location=self.device)
        self.model = checkpoint['architecture']
        # Inference only: no parameter needs gradients.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.classifier = checkpoint['classifier']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        return self.model

    def preparePredictor(self, *args, **kwargs):
        """Announce missing optional settings and default ``top_k`` to 5."""
        if not self.category_names:
            IOUtils.notify("category_names dictionary not provided. Predictions will show indexes unless you pass in a valid category_names dict at runtime.")
        if not self.top_k:
            IOUtils.notify("top_k not provided. Showing top 5 category predictions by default.")
            self.top_k = self._DEFAULT_TOP_K

    def process_image(self, image):
        """Turn an image file into a normalised, channel-first array.

        The image is converted to RGB, scaled so its shorter edge is 256
        pixels, centre-cropped to 224x224, scaled to [0, 1] and normalised
        with the per-channel mean/std constants of this class.

        Args:
            image: path or file object accepted by ``PIL.Image.open``.

        Returns:
            numpy.ndarray: float array of shape (3, 224, 224).
        """
        # The context manager releases the file handle; convert() returns an
        # independent in-memory copy. Forcing RGB guarantees three channels so
        # the normalisation below broadcasts (grayscale, palette and RGBA
        # inputs would otherwise fail).
        with Image.open(image) as opened:
            img = opened.convert("RGB")

        width, height = img.size
        shorter = min(width, height)
        if shorter > self._RESIZE_TO:
            # thumbnail() keeps the aspect ratio and only shrinks, so bounding
            # the shorter edge scales the longer edge proportionally.
            if width > height:
                img.thumbnail(size=(width, self._RESIZE_TO))
            else:
                img.thumbnail(size=(self._RESIZE_TO, height))
        elif shorter < self._RESIZE_TO:
            # thumbnail() cannot enlarge; upscale explicitly so the 224 crop
            # never reaches outside the image.
            scale = self._RESIZE_TO / shorter
            img = img.resize((max(1, round(width * scale)),
                              max(1, round(height * scale))))

        # Centre crop of 224x224.
        center_x = img.width / 2
        center_y = img.height / 2
        half = self._CROP_HALF
        img = img.crop((center_x - half, center_y - half,
                        center_x + half, center_y + half))

        np_img = np.array(img) / 255  # floats between 0 and 1
        np_img = (np_img - self._MEAN) / self._STD
        # PIL/NumPy store colour as the last axis; PyTorch expects it first.
        return np_img.transpose(2, 0, 1)

    def predict(self, *args, **kwargs):
        """Classify ``self.image`` and report the ``top_k`` predictions.

        Requires ``self.model`` to be set (see ``loadModel``). Each prediction
        is reported through ``IOUtils.notify``.

        Returns:
            tuple: ``(probs_list, predictions_list)``. ``probs_list`` is a
            NumPy array of the top probabilities; ``predictions_list`` is a
            NumPy array of class indexes, or a list of category names when
            ``category_names`` was provided.
        """
        self.model.to(self.device)
        self.model.eval()
        processed_img = self.process_image(self.image)
        # Add the batch dimension and match the model's float32 weights.
        torch_img = torch.from_numpy(processed_img).unsqueeze(0).float().to(self.device)
        with torch.no_grad():
            output = self.model(torch_img)
        # NOTE(review): exp() assumes the model emits log-probabilities
        # (e.g. a LogSoftmax output layer) - confirm against training code.
        prob = torch.exp(output)
        probs, indexes = prob.topk(int(self.top_k))
        # Move to host memory first: a CUDA tensor cannot be converted to a
        # NumPy array directly.
        probs_list = probs[0].cpu().numpy()
        predictions_list = indexes[0].cpu().numpy()
        if self.category_names:
            # Translate indexes into names using the provided mapping.
            # NOTE(review): assumes output index i corresponds to category key
            # str(i + 1); verify against the class-to-index mapping used
            # during training.
            category_indexes_dict = TypeUtils.cat_to_name(self.category_names)
            predictions_list = [
                category_indexes_dict[str(int(idx) + 1)]
                for idx in predictions_list
            ]
        for label, probability in zip(predictions_list, probs_list):
            IOUtils.notify(f"Prediction: {label} with probability {(probability * 100):.2f}%")
        return probs_list, predictions_list
if __name__ == "__main__":
    # Run the prediction only when executed as a script, so that importing
    # this module (e.g. to reuse Predictor) neither parses sys.argv nor
    # loads a model as a side effect.
    predictorInputs = parseInput()
    predictor = Predictor(**predictorInputs)
    predictor.predictImage()