# test_hyper.py
from data.dataloader import ImageDataset
import torch
from torchvision import transforms
from models.elic import TestModel as ELICModel
from models.fid import fid_pytorch, cal_psnr
from models.utils import print_avgs
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import compressai
import compressai.zoo
from models.gg18 import ScaleHyperpriorSTE
from guided_diffusion.measurements import gg18_paths
# Evaluation script: for every quality level in `lambdas`, run the
# ScaleHyperpriorSTE codec over the images under `test_path` and report
# bitrate (bpp), distortion (MSE / PSNR) and patch-level FID.

test_path = 'data/ffhq_samples/'
lambdas = [1, 2, 3, 5, 6]
# Constructor arguments passed to ScaleHyperpriorSTE, indexed by `lam - 1`.
Ns, Ms = [128, 128, 128, 128, 128, 192, 192, 192], [192, 192, 192, 192, 192, 320, 320, 320]
# Side length of the square, non-overlapping patches handed to the FID computer.
PATCH_SIZE = 64

# Built once: the patch extractor has no state that depends on the image.
_unfold = nn.Unfold(kernel_size=(PATCH_SIZE, PATCH_SIZE), stride=(PATCH_SIZE, PATCH_SIZE))


def _to_patches(img):
    """Cut a (B, C, H, W) batch into non-overlapping PATCH_SIZE x PATCH_SIZE patches.

    Returns a tensor of shape (B * L, C, PATCH_SIZE, PATCH_SIZE), where L is
    the number of patches per image. Batch size and channel count are read
    from `img` instead of being hard-coded to 1 and 3.
    """
    b, c = img.shape[:2]
    cols = _unfold(img).reshape(b, c, PATCH_SIZE, PATCH_SIZE, -1)
    return torch.permute(cols, (0, 4, 1, 2, 3)).reshape(-1, c, PATCH_SIZE, PATCH_SIZE)


def _bits_per_pixel(likelihoods, num_pix):
    """Return the batch-averaged rate, in bits per pixel, of one latent tensor."""
    bits_per_image = torch.sum(-torch.log2(likelihoods), dim=(1, 2, 3))
    return torch.mean(bits_per_image, dim=0) / num_pix


# The dataset and loader do not depend on the quality level, so they are
# created once; rebuilding them per level would start a fresh set of
# persistent worker processes on every iteration.
test_transforms = transforms.Compose(
    [transforms.ToTensor()]
)
test_dataset = ImageDataset(root=test_path, transform=test_transforms)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)

for lam in lambdas:
    level = lam - 1
    codec = ScaleHyperpriorSTE(Ns[level], Ms[level])
    # Load on CPU first so the checkpoint does not have to be restored onto
    # the device it was saved from; the codec is moved to the GPU right after.
    codec.load_state_dict_gg18(torch.load(gg18_paths[level], map_location='cpu'))
    codec = codec.cuda()
    codec.eval()

    fid_computer = fid_pytorch()
    avgs = {
        "bpp": [], "y_bpp": [], "z_bpp": [],
        "mse": [], "psnr": [],
        "fid": []
    }

    with torch.no_grad():
        fid_computer.clear_pools()
        # Wrapping the loader itself (not enumerate) lets tqdm show the total.
        for x, _x_name in tqdm(test_dataloader):
            x = x.cuda()
            num_pix = x.shape[-2] * x.shape[-1]

            # encode
            out = codec(x)
            y_bpp = _bits_per_pixel(out["likelihoods"]["y"], num_pix).item()
            z_bpp = _bits_per_pixel(out["likelihoods"]["z"], num_pix).item()
            x_hat = out["x_bar"]

            # FID is accumulated over 64x64 patches of the reference and the
            # reconstruction so that enough samples are available.
            fid_computer.add_ref_img(_to_patches(x))
            fid_computer.add_dis_img(_to_patches(x_hat))

            # statistics
            avgs['bpp'].append(y_bpp + z_bpp)
            avgs['y_bpp'].append(y_bpp)
            avgs['z_bpp'].append(z_bpp)
            avgs['mse'].append(torch.mean((x - x_hat)**2).item())
            avgs['psnr'].append(cal_psnr(x, x_hat))

        # NOTE(review): nesting reconstructed from an indentation-less copy —
        # FID is summarized once per quality level, after all images have
        # been pooled; confirm against the original layout.
        avgs['fid'].append(fid_computer.summary_pools())

    print_avgs(avgs)