forked from mit-han-lab/gan-compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_real_stat.py
85 lines (72 loc) · 3.59 KB
/
get_real_stat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import warnings
import numpy as np
import torch
import tqdm
from data import create_dataloader
from metric.fid_score import _compute_statistics_of_ims
from metric.inception import InceptionV3
from utils import util
def main(opt):
    """Compute dataset statistics with InceptionV3 and save them as ``mu``/``sigma``.

    The dataloader built from ``opt`` is iterated once; from every batch the
    entry keyed by the last character of ``opt.direction`` is collected. The
    collected tensors are concatenated along dim 0, converted with
    ``util.tensor2im`` and passed to ``_compute_statistics_of_ims``. The
    returned ``mu`` and ``sigma`` are written to ``opt.output_path`` via
    ``np.savez``.

    Args:
        opt: Parsed options. This function reads ``direction``, ``gpu_ids``
            (a list; empty means CPU) and ``output_path``; the whole object
            is also handed to ``create_dataloader``.
    """
    loader = create_dataloader(opt)

    # Use the first listed GPU; an empty gpu_ids list selects the CPU.
    if opt.gpu_ids:
        device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
    else:
        device = torch.device('cpu')

    # The same dimensionality picks the Inception block and is forwarded to
    # the statistics helper below.
    feature_dim = 2048
    inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[feature_dim]])
    inception_model.to(device)
    inception_model.eval()

    # 'AtoB' -> key 'B', 'BtoA' -> key 'A': the batch entry is chosen by the
    # last letter of the direction string.
    batches = [batch[opt.direction[-1]] for batch in tqdm.tqdm(loader)]
    stacked = torch.cat(batches, dim=0)
    images = util.tensor2im(stacked).astype(float)

    # NOTE(review): the positional 32 is presumably the batch size used by
    # _compute_statistics_of_ims -- confirm against metric.fid_score.
    mu, sigma = _compute_statistics_of_ims(
        images, inception_model, 32, feature_dim, device, use_tqdm=True)
    np.savez(opt.output_path, mu=mu, sigma=sigma)
def parse_gpu_ids(str_ids):
    """Convert a comma-separated GPU id string into a list of non-negative ints.

    Negative ids (e.g. ``-1``) are dropped, so ``'-1'`` yields an empty list.
    When at least one id remains, the first one is made the current CUDA
    device through ``torch.cuda.set_device``.

    Args:
        str_ids: Value of ``--gpu_ids``, e.g. ``'0'``, ``'0,1,2'`` or ``'-1'``.

    Returns:
        The non-negative ids, in the order they were given.

    Raises:
        ValueError: If a token is not an integer literal (raised by ``int``).
    """
    gpu_ids = [gpu_id for gpu_id in map(int, str_ids.split(',')) if gpu_id >= 0]
    if gpu_ids:
        torch.cuda.set_device(gpu_ids[0])
    return gpu_ids


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract some statistical information of a dataset to compute FID')
    parser.add_argument('--input_nc', type=int, default=3,
                        help='# of input image channels: 3 for RGB and 1 for grayscale')
    parser.add_argument('--output_nc', type=int, default=3,
                        help='# of output image channels: 3 for RGB and 1 for grayscale')
    parser.add_argument('--dataroot', required=True,
                        help='path to images (should have subfolders trainA, trainB, valA, valB, train, val, etc)')
    parser.add_argument('--dataset_mode', type=str, default='aligned',
                        help='chooses how datasets are loaded. [aligned | single]')
    parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
    parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
    parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
    parser.add_argument('--preprocess', type=str, default='resize_and_crop',
                        help='scaling and cropping of images at load time '
                             '[resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
    parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
    parser.add_argument('--output_path', type=str, required=True,
                        help='the path to save the statistical information.')
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    opt = parser.parse_args()

    # Options that are not exposed on the command line; they are fixed here
    # before ``opt`` is handed to ``main`` (and from there to
    # ``create_dataloader``).
    opt.num_threads = 0
    opt.batch_size = 1
    opt.serial_batches = True
    opt.no_flip = True
    opt.max_dataset_size = -1
    opt.load_in_memory = False

    # ``main`` indexes each batch by the last letter of the direction, so for
    # the single-dataset mode the default 'AtoB' is switched to 'BtoA'.
    if opt.dataset_mode == 'single' and opt.direction == 'AtoB':
        warnings.warn('Dataset mode [single] only supports direction BtoA. '
                      'We will change the direction to BtoA.')
        opt.direction = 'BtoA'

    # Replace the raw string with the parsed list; an empty list means CPU.
    opt.gpu_ids = parse_gpu_ids(opt.gpu_ids)

    if not opt.output_path.endswith('.npz'):
        warnings.warn('The output is a numpy npz file, but the output path doesn\'t end with ".npz".')
    if len(opt.gpu_ids) > 1:
        warnings.warn('The code only supports single GPU. Only gpu [%d] will be used.' % opt.gpu_ids[0])

    main(opt)