# -*- coding: utf-8 -*-
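"""Evaluation script for temporal action proposal results on THUMOS-14.

Computes average recall (AR) versus average number of proposals (AN) for each
ablation-study result file and plots the AR@AN curves side by side.
"""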
import os
import requests
import pickle
import io
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from scipy.interpolate import interp1d
from evaluation_thumos import prop_eval


def run_evaluation(proposal_filename, groundtruth_filename='../datasets/thumos14/thumos14_test_groundtruth.csv'):
    with open("evaluation_thumos/frm_num.pkl", 'rb') as fobj:
        frm_nums = pickle.load(fobj)
    rows = prop_eval.pkl2dataframe(frm_nums, 'evaluation_thumos/movie_fps.pkl', proposal_filename)
    aen_results = pd.DataFrame(rows, columns=['f-end', 'f-init', 'score', 'video-frames', 'video-name'])
    # Retrieves the THUMOS-14 test-set ground truth and caches it on first use.
    if not os.path.isfile(groundtruth_filename):
        ground_truth_url = ('https://gist.githubusercontent.com/cabaf/'
                            'ed34a35ee4443b435c36de42c4547bd7/raw/'
                            '952f17b9cdc6aa4e6d696315ba75091224f5de97/'
                            'thumos14_test_groundtruth.csv')
        s = requests.get(ground_truth_url).content
        groundtruth = pd.read_csv(io.StringIO(s.decode('utf-8')), sep=' ')
        # index=False keeps the cached CSV identical to the downloaded frame,
        # so re-reading it in the else-branch yields the same columns.
        groundtruth.to_csv(groundtruth_filename, index=False)
    else:
        groundtruth = pd.read_csv(groundtruth_filename)
    # Computes recall for different tIoU thresholds at a fixed average number
    # of proposals (disabled; requires numpy as np and a nr_proposals value).
    '''
    recall, tiou_thresholds = prop_eval.recall_vs_tiou_thresholds(aen_results, groundtruth,
                                                                  nr_proposals=nr_proposals,
                                                                  tiou_thresholds=np.linspace(0.5, 1.0, 11))
    recall = np.mean(recall)
    '''
    average_recall, average_nr_proposals = prop_eval.average_recall_vs_nr_proposals(aen_results, groundtruth)
    return average_recall, average_nr_proposals
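

# Usage sketch: the proposal pickle is assumed to be one produced by this
# repo's inference step; the path below matches the ablation files in main().
#   avg_recall, avg_nr_props = run_evaluation('results/ablation_study/full_arch.pkl')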


def evaluate_proposals(cfg, nr_proposals_list=(50, 100, 200, 500, 1000)):
    average_recall, average_nr_proposals = run_evaluation(cfg.DATA.RESULT_PATH)
    # Interpolate the AR-vs-AN curve so AR can be read off at exact proposal counts.
    f = interp1d(average_nr_proposals, average_recall, axis=0, bounds_error=False, fill_value='extrapolate')
    ar_results = {}
    for nr_prop in nr_proposals_list:
        ar_results[nr_prop] = float(f(nr_prop))
        print("AR@{} is {}\n".format(nr_prop, ar_results[nr_prop]))
    # AR@100 is the headline metric, so nr_proposals_list must include 100.
    return ar_results[100]
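

# Call-site sketch: `cfg` is assumed to be this repo's config object; only
# cfg.DATA.RESULT_PATH (the path to a proposals pickle) is read here.
#   ar_at_100 = evaluate_proposals(cfg)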


def plot_metric(average_nr_proposals, recalls, labels, colors, linestyles, figure_file):
    fn_size = 25
    plt.figure(num=None, figsize=(30, 12))
    # colors = ['#2CBDFE', '#47DBCD', '#F3A0F2', '#9D2EC5', '#661D98', '#F5B14C']

    def plotting(sub_ax, recs, lbs, lnstls, clrs):
        for idx, rec in enumerate(recs):
            sub_ax.plot(average_nr_proposals, rec, color=clrs[idx],
                        label=lbs[idx],
                        linewidth=6, linestyle=lnstls[idx], marker=None)
        handles, leg_labels = sub_ax.get_legend_handles_labels()
        sub_ax.legend(handles, leg_labels, loc='lower right', fontsize=fn_size)
        plt.ylabel('Average Recall', fontsize=fn_size)
        plt.xlabel('Average Number of Proposals per Video', fontsize=fn_size)
        plt.grid(True, which="both")  # the `b=` keyword was removed in Matplotlib 3.6
        # plt.ylim([.35, .6])
        plt.setp(sub_ax.get_xticklabels(), fontsize=fn_size)
        plt.setp(sub_ax.get_yticklabels(), fontsize=fn_size)

    # Left panel: first four curves; right panel: remaining three, with the
    # full model reusing colors[0] so it looks the same in both panels.
    ax = plt.subplot(1, 2, 1)
    plotting(ax, recalls[:4], labels[:4], linestyles[:4], colors[:4])
    ax = plt.subplot(1, 2, 2)
    plotting(ax, recalls[4:], labels[4:], linestyles[4:], [colors[0]] + colors[4:])
    # plt.show()
    plt.savefig(figure_file, dpi=300)


def main_evaluate_proposals(result_file, nr_proposals_list):
    average_recall, average_nr_proposals = run_evaluation(result_file)
    f = interp1d(average_nr_proposals, average_recall, axis=0, bounds_error=False, fill_value='extrapolate')
    ar_results = []
    for nr_prop in nr_proposals_list:
        ar_results.append(float(f(nr_prop)))
    return ar_results


def main():
    result_dir = 'results/ablation_study/'
    # 'full_arch.pkl' is listed twice on purpose: it anchors the left panel
    # (spectator ablation) and the right panel (actor-selection/fusion ablation).
    result_files = [
        'full_arch.pkl',
        'act_only.pkl',
        'env_only.pkl',
        'no_interaction.pkl',
        'full_arch.pkl',
        'env+hard_attn_only.pkl',
        'env+self_attn_only.pkl',
    ]
    labels = [
        'AEI (all spectators)',
        'Actors spectator only',
        'Environment spectator only',
        'W/o interaction spectator',
        'AEI (main actor selection and feature fusion)',
        'W/o feature fusion',
        'W/o main actor selection',
    ]
    # colors = ['#2f4858', '#55dde0', '#33658a', '#f6ae2d', '#f26419']
    # colors = ['#390099', '#9e0059', '#ff0054', '#ff5400', '#ffbd00']
    # Six colors cover seven curves: the repeated full model reuses colors[0].
    colors = ['tab:red', 'tab:purple', 'tab:green', 'tab:pink', 'tab:blue', 'tab:orange']
    linestyles = ['-'] * 7
    nr_props = list(range(50, 1000))
    ar_results = []
    for res_file in tqdm(result_files):
        ar_results.append(main_evaluate_proposals(os.path.join(result_dir, res_file), nr_props))
    print('Finished evaluating, start plotting!')
    plot_metric(nr_props, ar_results, labels, colors, linestyles, 'ablation_study.png')


if __name__ == '__main__':
    main()