Skip to content

Commit

Permalink
Added LPIPS
Browse files Browse the repository at this point in the history
  • Loading branch information
sxyu committed Nov 7, 2021
1 parent 05d0c14 commit dcb76c1
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions opt/render_imgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import imageio
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('ckpt', type=str)

Expand All @@ -39,10 +38,18 @@
action='store_true',
default=False,
help="Force a black BG (behind BG model) color; useful for debugging 'clouds'")
parser.add_argument('--no_lpips',
action='store_true',
default=False,
help="Disable LPIPS")
args = parser.parse_args()
config_util.maybe_merge_config_file(args)
device = 'cuda:0'

if not args.no_lpips:
import lpips
lpips_vgg = lpips.LPIPS(net="vgg").eval().to(device)

render_dir = path.join(path.dirname(args.ckpt),
'train_renders' if args.train else 'test_renders')
if args.render_path:
Expand Down Expand Up @@ -85,6 +92,7 @@
img_eval_interval = max(n_images // args.n_eval, 1)
avg_psnr = 0.0
avg_ssim = 0.0
avg_lpips = 0.0
n_images_gen = 0
cam = svox2.Camera(torch.tensor(0), dset.intrins.fx, dset.intrins.fy,
dset.intrins.cx, dset.intrins.cy,
Expand All @@ -100,10 +108,16 @@
mse = (im - im_gt) ** 2
mse_num : float = mse.mean().item()
psnr = -10.0 * math.log10(mse_num)
ssim = compute_ssim(im_gt, im)
ssim = compute_ssim(im_gt, im).item()
avg_psnr += psnr
avg_ssim += ssim
print(img_id, 'PSNR', psnr, 'SSIM', ssim)
if not args.no_lpips:
lpips_i = lpips_vgg(im_gt.permute([2, 0, 1]).contiguous(),
im.permute([2, 0, 1]).contiguous(), normalize=True).item()
avg_lpips += lpips_i
print(img_id, 'PSNR', psnr, 'SSIM', ssim, 'LPIPS', lpips_i)
else:
print(img_id, 'PSNR', psnr, 'SSIM', ssim)
# all_rgbs = []
# all_mses = []
# for batch_begin in range(0, im_size, args.eval_batch_size):
Expand Down Expand Up @@ -132,3 +146,8 @@
f.write(str(avg_psnr))
with open(path.join(render_dir, 'ssim.txt'), 'w') as f:
f.write(str(avg_ssim))
if not args.no_lpips:
avg_lpips /= n_images_gen
print('average LPIPS', avg_lpips)
with open(path.join(render_dir, 'lpips.txt'), 'w') as f:
f.write(str(avg_lpips))

0 comments on commit dcb76c1

Please sign in to comment.