-
Notifications
You must be signed in to change notification settings - Fork 95
/
test.py
85 lines (72 loc) · 2.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
file - test.py
Simple quick script to evaluate model on test images.
Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
"""
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from tqdm import tqdm
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from model.model import *
# Command-line interface. The four path arguments are required: nothing below
# can run without them, and argparse reports a missing one with a clear usage
# message instead of an obscure failure later in torch/pandas/os calls.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, help='path to pretrained model')
parser.add_argument('--test_csv', type=str, required=True, help='test csv file')
parser.add_argument('--test_images', type=str, required=True, help='path to folder containing images')
# NOTE(review): --workers is accepted for compatibility but is not read by this script.
parser.add_argument('--workers', type=int, default=4, help='number of workers')
# The value is used as a directory: pred.txt is created/appended inside it.
parser.add_argument('--predictions', type=str, required=True,
                    help='output directory; predictions are appended to pred.txt inside it')
args = parser.parse_args()
# Build NIMA on top of a torchvision VGG-16 backbone, then restore the trained
# weights from --model. Any load error propagates unchanged.
base_model = models.vgg16(pretrained=True)
model = NIMA(base_model)
# map_location='cpu' lets a checkpoint that was saved from GPU tensors be
# restored on a CPU-only machine; the model is moved to `device` just below.
model.load_state_dict(torch.load(args.model, map_location='cpu'))
print('successfully loaded model')

# Seed PyTorch's global RNG so random test-time transforms are repeatable
# (recent torchvision versions draw random crop offsets from torch's RNG).
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Put modules such as dropout/batch-norm into inference mode.
model.eval()
# Test-time preprocessing: resize so the shorter side is 256 px, take a
# 224x224 crop, convert to a tensor and normalize with the ImageNet channel
# statistics used by torchvision's pretrained backbones.
# transforms.Resize replaces transforms.Scale, which torchvision deprecated and
# later removed; for an int size both resize identically.
# NOTE(review): RandomCrop makes predictions depend on the RNG seed;
# transforms.CenterCrop(224) would make evaluation fully deterministic.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Ground-truth table: one row per test image, the image id in column 0 followed
# by ten columns for the rating buckets 1..10 (presumably normalized vote
# fractions -- TODO confirm; the GT row is re-normalized below either way).
test_df = pd.read_csv(args.test_csv, header=None)

# Rating values 1..10 that the ten distribution entries correspond to.
_RATINGS = np.arange(1, 11, dtype=np.float64)


def _dist_mean_std(weights):
    """Return (mean, std) of ``weights`` treated as a distribution over ratings 1..10.

    Computes mean = sum_i i * w_i and std = sqrt(sum_i w_i * (i - mean)^2)
    without re-normalizing ``weights``. Raises ValueError unless ``weights``
    has exactly ten elements. A negative variance yields nan instead of a
    complex number.
    """
    w = np.asarray(weights, dtype=np.float64).reshape(10)
    mean = float(np.dot(w, _RATINGS))
    std = float(np.sqrt(np.dot(w, (_RATINGS - mean) ** 2)))
    return mean, std


# Results are appended (never overwritten) to <predictions>/pred.txt, one line
# per image. The directory and file are prepared once, not once per image.
os.makedirs(args.predictions, exist_ok=True)
pred_path = os.path.join(args.predictions, 'pred.txt')

with open(pred_path, 'a') as f, tqdm(total=len(test_df)) as pbar:
    # Walk the rows directly: re-filtering test_df by id for every image was
    # an O(n) scan per image (O(n^2) overall).
    for img, *gt in test_df.itertuples(index=False, name=None):
        with Image.open(os.path.join(args.test_images, str(img) + '.jpg')) as im:
            rgb = im.convert('RGB')
        imt = test_transform(rgb).unsqueeze(dim=0).to(device)
        with torch.no_grad():
            out = model(imt)
        # Presumably a (1, 10) score distribution; view(10) fails loudly if
        # the output has any other number of elements.
        mean, std = _dist_mean_std(out.view(10).cpu().numpy())

        # Normalize the GT row so its mean is correct whether the CSV stores
        # vote counts or fractions (a no-op for rows that already sum to 1).
        gt = np.asarray(gt, dtype=np.float64)
        total = gt.sum()
        if total > 0:
            gt = gt / total
        gt_mean, _ = _dist_mean_std(gt)

        f.write(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f\n' % (mean, std, gt_mean))
        pbar.update()