import os

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

from utils import read_pfm

def colorjitter(img, factor):
    """Apply saturation/hue jitter. `factor` holds four factors
    (brightness, contrast, saturation, hue); brightness and contrast
    are currently disabled."""
    # img = F.adjust_brightness(img, factor[0])
    # img = F.adjust_contrast(img, factor[1])
    img = F.adjust_saturation(img, factor[2])
    # adjust_hue expects a value in [-0.5, 0.5]; factor[3] is sampled around 1.0
    img = F.adjust_hue(img, factor[3] - 1.0)
    return img
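
# A minimal usage sketch for colorjitter, mirroring the commented-out call in
# __getitem__ below (the sampling range is taken from that line):
#
#   factor = 1.0 + (torch.rand(4) * 2 - 1.0) * 0.5  # factors in [0.5, 1.5]
#   img = colorjitter(img, factor)                   # factor[3] - 1.0 lands in [-0.5, 0.5]
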
class MVSDatasetDTU(Dataset):
def __init__(self, root_dir, split, n_views=3, levels=1, img_wh=None, downSample=1.0, max_len=-1):
"""
img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
"""
self.root_dir = root_dir
self.split = split
assert self.split in ['train', 'val', 'test'], \
'split must be either "train", "val" or "test"!'
self.img_wh = img_wh
self.downSample = downSample
        self.scale_factor = 1.0 / 200  # rescale scene units so depths fall in a small numeric range
self.max_len = max_len
if img_wh is not None:
assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
'img_wh must both be multiples of 32!'
self.build_metas()
self.n_views = n_views
self.levels = levels # FPN levels
self.build_proj_mats()
self.define_transforms()
        print(f'==> image downsample factor: {self.downSample}')

    def define_transforms(self):
        # ImageNet normalization statistics, matching pretrained 2D backbones
        self.transform = T.Compose([T.ToTensor(),
                                    T.Normalize(mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225]),
                                    ])
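
    # To visualize images from a batch, the ImageNet normalization above can be
    # undone; a minimal sketch (this inverse transform is standard algebra, not
    # part of this codebase):
    #
    #   unnormalize = T.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    #                             std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
    #   img_vis = unnormalize(sample['images'][0])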
def build_metas(self):
self.metas = []
with open(f'configs/lists/dtu_{self.split}_all.txt') as f:
self.scans = [line.rstrip() for line in f.readlines()]
        # light conditions 0-6 are used for training;
        # light condition 3 is used for val/test
        light_idxs = [3] if 'train' != self.split else range(7)
self.id_list = []
for scan in self.scans:
            with open('configs/dtu_pairs.txt') as f:
num_viewpoint = int(f.readline())
# viewpoints (49)
for _ in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
for light_idx in light_idxs:
self.metas += [(scan, light_idx, ref_view, src_views)]
self.id_list.append([ref_view] + src_views)
self.id_list = np.unique(self.id_list)
self.build_remap()
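
    # Layout of configs/dtu_pairs.txt as assumed by build_metas above (MVSNet-style
    # pair file; matching scores are interleaved with view ids, which is why the
    # [1::2] slice keeps only the ids):
    #   line 0: number of viewpoints (49 for DTU)
    #   then, per viewpoint:
    #     <ref_view_id>
    #     <num_src> <src_id_0> <score_0> <src_id_1> <score_1> ...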
def build_proj_mats(self):
proj_mats, intrinsics, world2cams, cam2worlds = [], [], [], []
for vid in self.id_list:
proj_mat_filename = os.path.join(self.root_dir,
f'Cameras/train/{vid:08d}_cam.txt')
            intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
            # the stored intrinsics appear to be at quarter resolution, hence the x4;
            # the user-chosen downsampling is then applied on top
            intrinsic[:2] *= 4
            extrinsic[:3, 3] *= self.scale_factor
            intrinsic[:2] = intrinsic[:2] * self.downSample
            intrinsics += [intrinsic.copy()]
            # multiply intrinsics and extrinsics to get the projection matrix; scale
            # back down by 4 so it matches the quarter-resolution feature maps
            proj_mat_l = np.eye(4)
            intrinsic[:2] = intrinsic[:2] / 4
            proj_mat_l[:3, :4] = intrinsic @ extrinsic[:3, :4]
            proj_mats += [(proj_mat_l, near_far)]
world2cams += [extrinsic]
cam2worlds += [np.linalg.inv(extrinsic)]
self.proj_mats, self.intrinsics = np.stack(proj_mats), np.stack(intrinsics)
self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds)
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: lines [7, 10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))
        # depth_min & depth_interval: line 11 (DTU sampling uses 192 depth planes)
        depth_min, depth_interval = [float(x) for x in lines[11].split()[:2]]
        self.depth_interval = depth_interval
        depth_min *= self.scale_factor
        depth_max = depth_min + depth_interval * 192 * self.scale_factor
return intrinsics, extrinsics, [depth_min, depth_max]
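
    # Layout of a Cameras/train/{vid:08d}_cam.txt file as assumed by the parsing
    # above (MVSNet-style; the header keywords on lines 0 and 6 are never read):
    #   line 0     : "extrinsic"
    #   lines 1-4  : 4x4 world-to-camera matrix, one row per line
    #   line 6     : "intrinsic"
    #   lines 7-9  : 3x3 intrinsics matrix, one row per line
    #   line 11    : "<depth_min> <depth_interval>"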

    def read_depth(self, filename):
        depth_h = np.array(read_pfm(filename)[0], dtype=np.float32)  # (1200, 1600)
        depth_h = cv2.resize(depth_h, None, fx=0.5, fy=0.5,
                             interpolation=cv2.INTER_NEAREST)  # (600, 800)
        depth_h = depth_h[44:556, 80:720]  # center crop to (512, 640)
        # resize to match the user-chosen image downsampling
        depth_h = cv2.resize(depth_h, None, fx=self.downSample, fy=self.downSample,
                             interpolation=cv2.INTER_NEAREST)
        # quarter-resolution depth, matching the feature-map scale of proj_mats
        depth = cv2.resize(depth_h, None, fx=1.0 / 4, fy=1.0 / 4,
                           interpolation=cv2.INTER_NEAREST)
        mask = depth > 0
        return depth, mask, depth_h

    def build_remap(self):
        # map raw DTU view ids (sparse) to contiguous row indices of the stacked
        # per-view arrays, e.g. id_list [1, 5, 9] gives remap[5] == 1
        self.remap = np.zeros(np.max(self.id_list) + 1).astype('int')
        for i, item in enumerate(self.id_list):
            self.remap[item] = i
def __len__(self):
return len(self.metas) if self.max_len <= 0 else self.max_len
def __getitem__(self, idx):
sample = {}
scan, light_idx, target_view, src_views = self.metas[idx]
        if self.split == 'train':
            # NOTE: randomly pick 3 of the first 5 source views
            # (hardcoded, independent of self.n_views)
            ids = torch.randperm(5)[:3]
            view_ids = [src_views[i] for i in ids] + [target_view]
        else:
            view_ids = [src_views[i] for i in range(3)] + [target_view]
affine_mat, affine_mat_inv = [], []
imgs, depths_h = [], []
proj_mats, intrinsics, w2cs, c2ws, near_fars = [], [], [], [], [] # record proj mats between views
for i, vid in enumerate(view_ids):
# NOTE that the id in image file names is from 1 to 49 (not 0~48)
img_filename = os.path.join(self.root_dir,
f'Rectified/{scan}_train/rect_{vid + 1:03d}_{light_idx}_r5000.png')
depth_filename = os.path.join(self.root_dir,
f'Depths/{scan}/depth_map_{vid:04d}.pfm')
img = Image.open(img_filename)
img_wh = np.round(np.array(img.size) * self.downSample).astype('int')
img = img.resize(img_wh, Image.BILINEAR)
img = self.transform(img)
imgs += [img]
index_mat = self.remap[vid]
proj_mat_ls, near_far = self.proj_mats[index_mat]
intrinsics.append(self.intrinsics[index_mat])
w2cs.append(self.world2cams[index_mat])
c2ws.append(self.cam2worlds[index_mat])
affine_mat.append(proj_mat_ls)
affine_mat_inv.append(np.linalg.inv(proj_mat_ls))
            if i == 0:  # reference view
                ref_proj_inv = np.linalg.inv(proj_mat_ls)
                proj_mats += [np.eye(4)]
            else:
                # express the source-view projection relative to the reference view
                proj_mats += [proj_mat_ls @ ref_proj_inv]
if os.path.exists(depth_filename):
depth, mask, depth_h = self.read_depth(depth_filename)
depth_h *= self.scale_factor
depths_h.append(depth_h)
else:
                depths_h.append(np.zeros((1, 1)))  # placeholder when no GT depth exists
near_fars.append(near_far)
imgs = torch.stack(imgs).float()
# if self.split == 'train':
# imgs = colorjitter(imgs, 1.0+(torch.rand((4,))*2-1.0)*0.5)
# imgs = F.normalize(imgs,mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
depths_h = np.stack(depths_h)
proj_mats = np.stack(proj_mats)[:, :3]
affine_mat, affine_mat_inv = np.stack(affine_mat), np.stack(affine_mat_inv)
intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(w2cs), np.stack(c2ws), np.stack(near_fars)
        if isinstance(src_views[0], list):
            view_ids_all = [j for sub in src_views for j in sub]
        else:
            view_ids_all = [target_view] + list(src_views)
c2ws_all = self.cam2worlds[self.remap[view_ids_all]]
        sample['images'] = imgs  # (V, 3, H, W)
sample['depths_h'] = depths_h.astype(np.float32) # (V, H, W)
sample['w2cs'] = w2cs.astype(np.float32) # (V, 4, 4)
sample['c2ws'] = c2ws.astype(np.float32) # (V, 4, 4)
        sample['near_fars'] = near_fars.astype(np.float32)  # (V, 2)
sample['proj_mats'] = proj_mats.astype(np.float32)
sample['intrinsics'] = intrinsics.astype(np.float32) # (V, 3, 3)
sample['view_ids'] = np.array(view_ids)
sample['light_id'] = np.array(light_idx)
sample['affine_mat'] = affine_mat
sample['affine_mat_inv'] = affine_mat_inv
sample['scan'] = scan
sample['c2ws_all'] = c2ws_all.astype(np.float32)
return sample
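

# A minimal usage sketch (hedged: the root path below is illustrative; the loader
# expects Rectified/, Depths/ and Cameras/train/ under root_dir, plus the
# configs/lists/dtu_{split}_all.txt and configs/dtu_pairs.txt files):
#
#   from torch.utils.data import DataLoader
#
#   dataset = MVSDatasetDTU('/data/DTU', split='train', n_views=3, downSample=1.0)
#   loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   print(batch['images'].shape)  # e.g. torch.Size([1, 4, 3, 512, 640]): 3 src views + target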