Skip to content

Commit 48922d4

Browse files
committed
Save partial sample images during sampling
1 parent 8c79f9c commit 48922d4

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

main.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ def sample(model, generation_idx, mask_init, mask_undilated, mask_dilated, batch
529529
data = torch.zeros(sample_batch_size, obs[0], obs[1], obs[2])
530530
data = data.cuda()
531531
sample_idx = generation_idx
532+
context = None
533+
batch_to_complete = None
532534
else:
533535
if args.sample_region == "center":
534536
offset1 = -args.sample_size_h // 2
@@ -571,20 +573,29 @@ def sample(model, generation_idx, mask_init, mask_undilated, mask_dilated, batch
571573
print("batch_to_complete", type(batch_to_complete), batch_to_complete.shape, "data", type(data), data.shape)
572574
data[:, :, sample_idx[:, 0], sample_idx[:, 1]] = 0
573575

574-
context = rescaling_inv(data).cpu()
575-
batch_to_complete = rescaling_inv(batch_to_complete).cpu()
576+
context = rescaling_inv(data).cpu()
577+
batch_to_complete = rescaling_inv(batch_to_complete).cpu()
578+
579+
logger.info(f"Example context: {context.numpy()}")
576580

577581
logger.info(f"Before sampling, data has range {data.min().item()}-{data.max().item()} (mean {data.mean().item()}), dtype={data.dtype} {type(data)}")
578-
logger.info(f"Example context: {context.numpy()}")
579-
for i, j in tqdm.tqdm(sample_idx, desc="Sampling pixels"):
582+
for n_pix, (i, j) in enumerate(tqdm.tqdm(sample_idx, desc="Sampling pixels")):
580583
data_v = Variable(data)
584+
t1 = time.time()
581585
out = model(data_v, sample=True, mask_init=mask_init, mask_undilated=mask_undilated, mask_dilated=mask_dilated)
586+
t2 = time.time()
582587
out_sample = sample_op(out).data[:, :, i, j]
588+
logger.info("%d %d,%d Time to infer logits=%f s, sample=%f s", n_pix, i, j, t2-t1, time.time()-t2)
583589
data[:, :, i, j] = out_sample
584590
logger.info(f"Sampled pixel {i},{j}, with batchwise range {out_sample.min().item()}-{out_sample.max().item()} (mean {out_sample.mean().item()}), dtype={out_sample.dtype} {type(out_sample)}")
591+
592+
if (n_pix <= 256 and n_pix % 32 == 0) or n_pix % 256 == 0:
593+
sample_save_path = os.path.join(run_dir, f'{args.mode}_{args.sample_region}_{args.sample_size_h}x{args.sample_size_w}_o1{args.sample_offset1}_o2{args.sample_offset2}_obs{obs2str(obs)}_ep{checkpoint_epochs}_order{sample_order_i}_{n_pix}of{len(sample_idx)}pix.png')
594+
utils.save_image(rescaling_inv(data), sample_save_path, nrow=4, padding=5, pad_value=1, scale_each=False)
595+
wandb.log({sample_save_path: wandb.Image(sample_save_path)}, step=n_pix)
585596
data = rescaling_inv(data).cpu()
586597

587-
if batch_to_complete is not None:
598+
if batch_to_complete is not None and context is not None:
588599
# Interleave along batch dimension to visualize GT images
589600
difference = torch.abs(data - batch_to_complete)
590601
logger.info(f"Context range {context.min()}-{context.max()}. Data range {data.min()}-{data.max()}. batch_to_complete range {batch_to_complete.min()}-{batch_to_complete.max()}")

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def sample_from_discretized_mix_logistic(l, nr_mix, mixture_temperature=1.0, log
286286
temp.uniform_(1e-5, 1. - 1e-5)
287287
temp = logit_probs.data - torch.log(- torch.log(temp))
288288
_, argmax = temp.max(dim=3)
289-
289+
290290
one_hot = to_one_hot(argmax, nr_mix)
291291
sel = one_hot.view(xs[:-1] + [1, nr_mix])
292292
# select logistic parameters

0 commit comments

Comments
 (0)