-
Notifications
You must be signed in to change notification settings - Fork 10
/
run_validate_separate.py
66 lines (52 loc) · 1.84 KB
/
run_validate_separate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import sys
import torch, cv2, os
import numpy as np
import arg, models
from learning.validation import compute_metric
from learning.visualization import visualize_image
from data import PairDataset, PortraitDataset
# Data Loader
# argv[2] is the evaluation mode. It selects the data directory
# (<base_path>/data/blender_<mode>); in 'view' mode no light_size is passed
# to PairDataset.
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} <method> <mode>')
mode = sys.argv[2]
light_size = None if mode == 'view' else arg.model_args['light_size']
dataset = PairDataset(
    PortraitDataset(
        f'{arg.base_path}/data/blender_{mode}', arg.val_data_names,
        ['source_image', 'target_image'], arg.light_ext,
    ),
    shuffle=False, light_size=light_size, rotate_ratio=arg.rotate_ratio,
)
# batch_size=None disables DataLoader's automatic batching: each dataset item
# is yielded as-is, without an added batch dimension.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=None,
    num_workers=2, prefetch_factor=1,
    pin_memory=True,
)
# Load model
# The network is wrapped in DataParallel and moved to CUDA. The checkpoint
# weights are loaded into the inner module (model.module), so the stored
# state dict must match the unwrapped model's parameter names.
model = torch.nn.DataParallel(
    models.get_model(arg.model_name)(**arg.model_args)
).cuda()
# map_location='cpu' deserializes onto the CPU. Without it, tensors return to
# the GPU index they were saved from, which fails if that device is absent.
# load_state_dict then copies the values into the CUDA parameters.
checkpoint = torch.load(f'{arg.ckpt_path}/{arg.train_step}.pth', map_location='cpu')
model.module.load_state_dict(checkpoint['model_state_dict'])
del checkpoint  # release the CPU copy; only the model's parameters are needed from here on
# Run validation
# Render every item from data_loader and save its visualization as a JPEG
# under <arg.val_path>_<mode>.
val_path = f'{arg.val_path}_{mode}'
os.makedirs(val_path, exist_ok=True)
# Loop-invariant: whether the SIPR-specific slicing of source inputs applies.
slice_sources = mode == 'relight' and sys.argv[1] == 'sipr'
for data in data_loader:
    if mode == 'view':
        # Placeholder target_light for view mode.
        # NOTE(review): presumably unused by render in this mode; confirm.
        data['target_light'] = torch.Tensor([0.0])
    if slice_sources:
        # Keep only index 0 along the leading dimension of each source input.
        # NOTE(review): presumably the relight data carries several source
        # entries and SIPR consumes a single one; confirm.
        for key in ('source_shape', 'source_image', 'source_mask'):
            data[key] = data[key][0]
    # Only tensor entries are moved to the GPU and forwarded to the model.
    data_cuda = {k: v.cuda() for k, v in data.items() if isinstance(v, torch.Tensor)}
    output = model.module.render(False, model, **data_cuda)
    if mode == 'view':
        # Drop the 'light' output before visualization in view mode.
        del output['light']
    image_name = f'{data["data_name"]}_{data["source_image_id"]}_{data["target_image_id"]}'
    image_path = f'{val_path}/{image_name}.jpg'
    # cv2.imwrite reports failure (bad path, unsupported image) by returning
    # False instead of raising. Fail loudly so a missing image is not
    # silently left out of the metrics computed from val_path.
    if not cv2.imwrite(image_path, visualize_image(output, **data_cuda)):
        raise RuntimeError(f'Failed to write {image_path}')
# Validation
# Score the images written to val_path, then print the experiment name
# followed by the PSNR and SSIM values.
psnr_value, ssim_value = compute_metric(val_path)
print(arg.exp_name)
print('PSNR: {:.4f}'.format(psnr_value))
print('SSIM: {:.6f}'.format(ssim_value))