-
Notifications
You must be signed in to change notification settings - Fork 4
/
gen_fg.py
executable file
·93 lines (70 loc) · 3.23 KB
/
gen_fg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import glob
import argparse
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image
from model import MattingRefine
def parse_args():
    """Build and parse the command-line options for the foreground generator.

    Returns:
        argparse.Namespace with model configuration, dataset root and device.
    """
    parser = argparse.ArgumentParser(description='Process foreground dataset')
    add = parser.add_argument
    # Matting model configuration.
    add('--model-backbone', type=str, required=True,
        choices=['resnet101', 'resnet50', 'mobilenetv2'])
    add('--model-backbone-scale', type=float, default=0.25)
    add('--model-checkpoint', type=str, required=True)
    add('--model-refine-mode', type=str, default='sampling',
        choices=['full', 'sampling', 'thresholding'])
    add('--model-refine-sample-pixels', type=int, default=80_000)
    add('--model-refine-threshold', type=float, default=0.7)
    add('--model-refine-kernel-size', type=int, default=3)
    # Dataset location and compute device.
    add('--root_dir', type=str, required=True)
    add('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    return parser.parse_args()
# Worker function: runs on a background Thread so disk writes don't stall the GPU loop.
def writer(img, path):
    """Save the first image of batch tensor *img* to *path* via PIL."""
    to_pil_image(img[0].cpu()).save(path)
def proc_fg(root_dir):
    """Extract foreground + alpha composites for every camera folder under *root_dir*.

    Relies on the module-global ``args`` produced by ``parse_args()``. For each
    ``cam*`` folder, loads the matching background image from
    ``<root_dir>/background`` and writes RGBA PNGs (premultiplied-ish: RGB is
    zeroed where alpha == 0) to ``<root_dir>/foreground/<cam>``.

    Raises:
        FileNotFoundError: if a camera folder has no matching background image.
    """
    device = torch.device(args.device)

    # Build the matting model, then load checkpoint weights onto the target device.
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)
    model = model.to(device).eval()
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

    extensions = ['.jpg', '.png']
    folders = sorted(glob.glob(os.path.join(root_dir, 'cam*')))
    for folder in folders:
        # Background image shares the camera folder's name (any extension).
        bg_name = os.path.split(folder)[-1]
        bg_matches = glob.glob(os.path.join(root_dir, 'background', bg_name + '*'))
        if not bg_matches:
            # Fail with a clear message instead of a bare IndexError.
            raise FileNotFoundError(
                'No background image for %s under %s' % (bg_name, root_dir))
        # Fix: force 3-channel RGB — RGBA or grayscale PNGs would otherwise
        # produce a tensor with the wrong channel count for the model.
        bg = np.array(Image.open(bg_matches[0]).convert('RGB')) / 255.
        bgr = torch.FloatTensor(bg).permute([2, 0, 1])[None, :]
        bgr = bgr.to(device, non_blocking=True)

        frames = []
        for extension in extensions:
            frames += sorted(glob.glob(os.path.join(folder, '*' + extension)))

        out_dir = os.path.join(root_dir, 'foreground', bg_name)
        os.makedirs(out_dir, exist_ok=True)

        with torch.no_grad():
            for frame in tqdm(frames):
                # Fix: splitext instead of name[:-4] — robust to extensions
                # that are not exactly four characters long.
                new_name = os.path.splitext(os.path.basename(frame))[0]
                # Fix: convert frames to RGB for the same reason as the background.
                im = Image.open(frame).convert('RGB')
                src = torch.FloatTensor(np.array(im) / 255.).permute([2, 0, 1])[None, :]
                src = src.to(device, non_blocking=True)
                pha, fgr, _, _, _, _ = model(src, bgr)
                # Zero the foreground where alpha == 0 to avoid color bleed,
                # then pack RGB + alpha into a single 4-channel tensor.
                com = torch.cat([fgr * pha.ne(0), pha], dim=1)
                # Offload the PNG write to a thread so the GPU loop keeps running.
                Thread(target=writer, args=(com, os.path.join(out_dir, new_name + '.png'))).start()
if __name__ == '__main__':
    # `args` is intentionally module-global: proc_fg() reads it directly.
    args = parse_args()
    # Each scene directory is processed independently with the same model config.
    scenes = sorted(glob.glob(os.path.join(args.root_dir, 'scene*')))
    print(scenes)
    for scene_dir in scenes:
        proc_fg(scene_dir)