Working on fern, but horrible issues very close to camera
sxyu committed Oct 28, 2021
1 parent: 5452b17 · commit: cfa0b46
Showing 10 changed files with 151 additions and 129 deletions.
54 changes: 30 additions & 24 deletions opt/opt.py
@@ -41,7 +41,7 @@
help='FINAL grid resolution')
group.add_argument('--init_reso', type=int, default=512,
help='INITIAL grid resolution')
group.add_argument('--ref_reso', type=int, default=1024,
group.add_argument('--ref_reso', type=int, default=512,
help='reference grid resolution (for adjusting lr)')
group.add_argument('--z_reso_factor', type=float, default=192/1024,
help='z dimension resolution factor')
@@ -60,9 +60,9 @@
# TODO: make the lr higher near the end
group.add_argument('--sigma_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="Density optimizer")
group.add_argument('--lr_sigma', type=float, default=
2e1,
3e1,
help='SGD/rmsprop lr for sigma')
group.add_argument('--lr_sigma_final', type=float, default=1e-2)
group.add_argument('--lr_sigma_final', type=float, default=5e-2)
group.add_argument('--lr_sigma_decay_steps', type=int, default=250000)
group.add_argument('--lr_sigma_delay_steps', type=int, default=15000,
help="Reverse cosine steps (0 means disable)")
@@ -71,11 +71,11 @@

group.add_argument('--sh_optim', choices=['sgd', 'rmsprop'], default='rmsprop', help="SH optimizer")
group.add_argument('--lr_sh', type=float, default=
1e-3,
1e-2,
help='SGD/rmsprop lr for SH')
group.add_argument('--lr_sh_final', type=float,
default=
5e-6
5e-5
)
group.add_argument('--lr_sh_decay_steps', type=int, default=250000)
group.add_argument('--lr_sh_delay_steps', type=int, default=0, help="Reverse cosine steps (0 means disable)")
@@ -118,7 +118,7 @@
help='sample by permutation of rays (true epoch) instead of '
'uniformly random rays')
group.add_argument('--sigma_thresh', type=float,
default=2.5,
default=1.0,
help='Resample (upsample to 512) sigma threshold')
group.add_argument('--weight_thresh', type=float,
default=0.0005,
@@ -134,11 +134,12 @@
group.add_argument('--rms_beta', type=float, default=0.9)
group.add_argument('--lambda_tv', type=float, default=1e-3)
group.add_argument('--tv_sparsity', type=float, default=0.01)
group.add_argument('--tv_logalpha', action='store_true', default=False, help='Use log(1-exp(-delta * sigma)) as in neural volumes')

group.add_argument('--lambda_sparsity', type=float, default=0.0)#1e-5)
group.add_argument('--sparsity_sparsity', type=float, default=0.01)

group.add_argument('--lambda_tv_sh', type=float, default=1e-3)
group.add_argument('--lambda_tv_sh', type=float, default=1e-2)
group.add_argument('--tv_sh_sparsity', type=float, default=0.01)

group.add_argument('--lambda_tv_basis', type=float, default=0.0)
@@ -147,6 +148,7 @@
group.add_argument('--weight_decay_sh', type=float, default=1.0)

group.add_argument('--lr_decay', action='store_true', default=True)
group.add_argument('--last_sample_opaque', action='store_true', default=True)
args = parser.parse_args()

assert args.lr_sigma_final <= args.lr_sigma, "lr_sigma must be >= lr_sigma_final"
@@ -187,21 +189,18 @@
basis_reso=args.basis_reso,
use_learned_basis=False)

if hasattr(dset, 'z_bounds'):
print('Setting bounds', dset.z_bounds)
grid.set_frustum_bounds(dset.z_bounds[0], dset.z_bounds[1])
print(' ', grid._z_ratio)
grid.opt.last_sample_opaque = True
grid.opt.last_sample_opaque = args.last_sample_opaque

# DC -> gray; mind the SH scaling!
grid.sh_data.data[:] = 0.0
# grid.sh_data.data.normal_(mean=0.0, std=0.001)
grid.density_data.data[:] = args.init_sigma

# grid.sh_data.data[:, 0] = 4.0
# osh = grid.density_data.data.shape
# den = grid.density_data.data.view(grid.links.shape)
# den[:] = 0.01
# den[:, :, -1] = 1e9
# # den[:] = 0.00
# # den[:, :256, :] = 1e9
# # den[:, :, 0] = 1e9
# grid.density_data.data = den.view(osh)

if grid.use_learned_basis:
Expand All @@ -228,8 +227,9 @@
dset.intrins.fy,
dset.intrins.cx,
dset.intrins.cy,
dset.w,
dset.h) for c2w in dset.c2w
width=dset.w,
height=dset.h,
ndc_coeffs=dset.ndc_coeffs) for c2w in dset.c2w
] if args.use_weight_thresh else None
ckpt_path = path.join(args.train_dir, 'ckpt.npz')

@@ -279,8 +279,9 @@ def eval_step():
dset_test.intrins.fy,
dset_test.intrins.cx,
dset_test.intrins.cy,
dset_test.w,
dset_test.h)
width=dset_test.w,
height=dset_test.h,
ndc_coeffs=dset_test.ndc_coeffs)
rgb_pred_test = grid.volume_render_image(cam, use_kernel=True)
rgb_gt_test = dset_test.gt[img_id].to(device=device)
all_mses = ((rgb_gt_test - rgb_pred_test) ** 2).cpu()
@@ -289,6 +290,9 @@ def eval_step():
img_pred.clamp_max_(1.0)
summary_writer.add_image(f'test/image_{img_id:04d}',
img_pred, global_step=gstep_id_base, dataformats='HWC')
mse_img = all_mses / all_mses.max()
summary_writer.add_image(f'test/mse_map_{img_id:04d}',
mse_img, global_step=gstep_id_base, dataformats='HWC')

rgb_pred_test = rgb_gt_test = None
mse_num : float = all_mses.mean().item()
@@ -367,7 +371,7 @@ def train_step():
stats[stat_name] = 0.0
if args.lambda_tv > 0.0:
with torch.no_grad():
tv = grid.tv()
tv = grid.tv(logalpha=args.tv_logalpha)
summary_writer.add_scalar("loss_tv", tv, global_step=gstep_id)
if args.lambda_sparsity > 0.0:
with torch.no_grad():
@@ -393,7 +397,8 @@ def train_step():
if args.lambda_tv > 0.0:
grid.inplace_tv_grad(grid.density_data.grad,
scaling=args.lambda_tv,
sparse_frac=args.tv_sparsity)
sparse_frac=args.tv_sparsity,
logalpha=args.tv_logalpha)
if args.lambda_sparsity > 0.0:
# Overkill
grid.inplace_sparsity_grad(grid.density_data.grad,
@@ -437,15 +442,16 @@ def train_step():
reso, reso, int(reso * args.z_reso_factor)],
sigma_thresh=args.sigma_thresh if use_sparsify else 0.0,
weight_thresh=args.weight_thresh if use_sparsify else 0.0,
dilate=1, #use_sparsify,
dilate=2, #use_sparsify,
cameras=resample_cameras)
if non_final:
# if reso <= args.ref_reso:
# lr_sigma_factor *= 8
# else:
# lr_sigma_factor *= 4
lr_sh_factor *= args.lr_sh_upscale_factor
print('Increased lr to (sigma:)', args.lr_sigma, '(sh:)', args.lr_sh)
if args.lr_sh_upscale_factor > 1:
lr_sh_factor *= args.lr_sh_upscale_factor
print('Increased lr to (sigma:)', args.lr_sigma, '(sh:)', args.lr_sh)

if factor > 1 and reso < args.final_reso:
factor //= 2
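
A note on the learning-rate flags changed above: --lr_sigma/--lr_sh pair an initial and final rate with *_decay_steps and *_delay_steps, which matches the JaxNeRF-style schedule where the rate decays log-linearly over the decay window and the "reverse cosine" (sine) ramp suppresses it over the delay window. A minimal sketch under that assumption; the names get_expon_lr_func and lr_delay_mult are illustrative, not taken from this diff:

    import numpy as np

    def get_expon_lr_func(lr_init, lr_final, lr_delay_steps=0,
                          lr_delay_mult=1e-2, max_steps=250000):
        """Log-linear decay from lr_init to lr_final over max_steps.

        If lr_delay_steps > 0, the rate is additionally scaled by a sine ramp
        ("reverse cosine") climbing from lr_delay_mult to 1 over the delay
        window, so training starts at a reduced learning rate.
        """
        def lr_at(step):
            if lr_delay_steps > 0:
                ramp = np.clip(step / lr_delay_steps, 0.0, 1.0)
                delay = lr_delay_mult + (1.0 - lr_delay_mult) * np.sin(0.5 * np.pi * ramp)
            else:
                delay = 1.0
            t = np.clip(step / max_steps, 0.0, 1.0)
            # Interpolate in log space so the decay is exponential in the step count.
            log_lerp = np.exp(np.log(lr_init) * (1.0 - t) + np.log(lr_final) * t)
            return delay * log_lerp
        return lr_at

    # Mirroring the sigma defaults above: 3e1 -> 5e-2 over 250k steps,
    # with a 15k-step delayed start.
    lr_sigma_func = get_expon_lr_func(3e1, 5e-2, lr_delay_steps=15000,
                                      max_steps=250000)
    print(lr_sigma_func(0), lr_sigma_func(15000), lr_sigma_func(250000))

Each optimizer step would then look up lr_sigma_func(gstep_id) and write the result into the optimizer before stepping.
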
88 changes: 59 additions & 29 deletions opt/util/ff_dataset.py
@@ -30,6 +30,29 @@
from typing import Union


def convert_to_ndc(origins, directions, ndc_coeffs, near: float = 1.0):
"""Convert a set of rays to NDC coordinates."""
# Shift ray origins to near plane, not sure if needed
t = (near - origins[Ellipsis, 2]) / directions[Ellipsis, 2]
origins = origins + t[Ellipsis, None] * directions

dx, dy, dz = directions.unbind(-1)
ox, oy, oz = origins.unbind(-1)

# Projection
o0 = ndc_coeffs[0] * (ox / oz)
o1 = ndc_coeffs[1] * (oy / oz)
o2 = 1 - 2 * near / oz

d0 = ndc_coeffs[0] * (dx / dz - ox / oz)
d1 = ndc_coeffs[1] * (dy / dz - oy / oz)
    d2 = 2 * near / oz

origins = torch.stack([o0, o1, o2], -1)
directions = torch.stack([d0, d1, d2], -1)
return origins, directions


class LLFFDataset(Dataset):
def __init__(
self,
@@ -46,8 +69,8 @@ def __init__(
invz : int= 0,
transform=None,
render_style="",
offset=200,
hold_every=8,
offset=250,
):
self.scale = scale
self.dataset = root
@@ -89,6 +112,8 @@ def __init__(
self.sfm.ref_cam['px'],
self.sfm.ref_cam['py'])

self.ndc_coeffs = (2 * self.intrins_full.fx / self.w_full,
2 * self.intrins_full.fy / self.h_full)
if self.split == "train":
self.gen_rays(factor=factor)
else:
@@ -99,15 +124,6 @@ def __init__(

def _load_images(self):
scale = self.scale
img_dir_name = "images_4"
use_integral_scaling = False
scaled_img_dir = ''
if scale != 1 and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9:
# Integral scaling
scaled_img_dir = "images_" + str(round(1.0 / scale))
if os.path.isdir(os.path.join(self.dataset, scaled_img_dir)):
use_integral_scaling = True
print('Using pre-scaled images from', os.path.join(self.dataset, scaled_img_dir))

all_gt = []
all_c2w = []
@@ -123,8 +139,6 @@ def _load_images(self):
c2w = global_w2rc @ c2w
all_c2w.append(torch.from_numpy(c2w.astype(np.float32)))

if use_integral_scaling:
img_path = scaled_img_dir + '/' + '/'.join(img_path.split('/')[1:])
img_path = os.path.join(self.dataset, img_path)
if not os.path.exists(img_path):
path_noext = os.path.splitext(img_path)[0]
@@ -133,7 +147,7 @@ def _load_images(self):
if os.path.exists(path_noext + '.png'):
img_path = path_noext + '.png'
img = imageio.imread(img_path)
if scale != 1 and not use_integral_scaling:
if scale != 1 and not self.sfm.use_integral_scaling:
h, w = img.shape[:2]
if self.sfm.dataset_type == "deepview":
newh = int(h * scale) # always floor down height
@@ -148,31 +162,26 @@ def _load_images(self):
# Apply alpha channel
self.gt = self.gt[..., :3] * self.gt[..., 3:] + (1.0 - self.gt[..., 3:])
self.c2w = torch.stack(all_c2w)
self.z_bounds = [self.sfm.dmin, self.sfm.dmax]

bds_scale = 1.0 / (self.sfm.dmin * 0.75) # 0.9
# bds_scale = 1.0
print('scene rescale', bds_scale)
self.z_bounds = [self.sfm.dmin * bds_scale, self.sfm.dmax * bds_scale]
self.c2w[:, :3, 3] *= bds_scale
fx = self.sfm.ref_cam['fx']
fy = self.sfm.ref_cam['fy']
width = self.sfm.ref_cam['width']
height = self.sfm.ref_cam['height']

print('z_bounds from LLFF:', self.z_bounds)


zmid = (self.z_bounds[0] + self.z_bounds[1]) * 0.5
zrad = (self.z_bounds[1] - self.z_bounds[0]) * 0.5
radx = 1 + 2 * self.sfm.offset / self.gt.size(2)
rady = 1 + 2 * self.sfm.offset / self.gt.size(1)

z_max = 1.0 # 0.5
scene_scale = z_max / zrad
zmid *= scene_scale
x_max = zmid * (width + 2 * self.sfm.offset) / (2 * fx)
y_max = zmid * (height + 2 * self.sfm.offset) / (2 * fy)

self.scene_center = [0.0, 0.0, zmid]
self.scene_radius = [x_max, y_max, z_max]
self.z_bounds = [self.z_bounds[0] * scene_scale, self.z_bounds[1] * scene_scale]
self.scene_center = [0.0, 0.0, 0.0]
self.scene_radius = [radx, rady, 1.0]
print('scene_radius', self.scene_radius)
self.use_sphere_bound = False

self.c2w[:, :3, 3] *= scene_scale

def gen_rays(self, factor=1):
print(" Generating rays, scaling factor", factor)
@@ -208,6 +217,13 @@ def gen_rays(self, factor=1):
dirs = dirs.view(-1, 3)
gt = gt.reshape(-1, 3)

# To NDC
origins, dirs = convert_to_ndc(
origins,
dirs,
self.ndc_coeffs)
dirs /= torch.norm(dirs, dim=-1, keepdim=True)

self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt)
self.rays = self.rays_init

@@ -354,6 +370,20 @@ def readLLFF(self, dataset, ref_img=""):
image_dir = os.path.join(dataset, "images")
if not os.path.exists(image_dir) and not os.path.isdir(image_dir):
return False

self.use_integral_scaling = False
scaled_img_dir = ''
scale = self.scale
if scale != 1 and abs((1.0 / scale) - round(1.0 / scale)) < 1e-9:
# Integral scaling
scaled_img_dir = "images_" + str(round(1.0 / scale))
if os.path.isdir(os.path.join(self.dataset, scaled_img_dir)):
self.use_integral_scaling = True
image_dir = os.path.join(self.dataset, scaled_img_dir)
print('Using pre-scaled images from', image_dir)
else:
scaled_img_dir = "images"

# load R,T
(
reference_depth,
@@ -373,7 +403,7 @@ def nsvf_sort_key(x):
return x

# get all image of this dataset
images_path = [os.path.join("images", f) for f in sorted(os.listdir(image_dir), key=nsvf_sort_key)]
images_path = [os.path.join(scaled_img_dir, f) for f in sorted(os.listdir(image_dir), key=nsvf_sort_key)]

# LLFF dataset has only single camera in dataset
if len(intrinsic) == 3:
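
The convert_to_ndc helper added above is the standard LLFF forward-facing parameterization: ray origins are slid along their directions onto the near plane, then projected so the near plane maps to z = -1 and infinity to z = +1, with ndc_coeffs = (2*fx/W, 2*fy/H) as set in __init__ above. A minimal usage sketch under those assumptions; the intrinsics and rays are made up for illustration and the import path may need adjusting:

    import torch
    from ff_dataset import convert_to_ndc  # the helper shown in the diff above

    # Hypothetical intrinsics for a 1008x756 LLFF-style image.
    fx, fy, W, H = 1000.0, 1000.0, 1008, 756
    ndc_coeffs = (2.0 * fx / W, 2.0 * fy / H)

    # Two example rays; this formulation assumes rays travel toward +z.
    origins = torch.zeros(2, 3)
    dirs = torch.tensor([[0.0, 0.0, 1.0],
                         [0.1, 0.0, 1.0]])

    ndc_origins, ndc_dirs = convert_to_ndc(origins, dirs, ndc_coeffs)
    # gen_rays() above renormalizes directions after the conversion.
    ndc_dirs = ndc_dirs / torch.norm(ndc_dirs, dim=-1, keepdim=True)
    print(ndc_origins)  # origins land on the near plane (z = -1) in NDC
    print(ndc_dirs)
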
1 change: 1 addition & 0 deletions opt/util/obj_dataset.py
@@ -86,6 +86,7 @@ def __init__(
# Hardcoded; adjust scene_scale to make sure the scene fits in a unit sphere
self.scene_center = [0.0, 0.0, 0.0]
self.scene_radius = 1.0
self.ndc_coeffs = (-1.0, -1.0) # disable
self.use_sphere_bound = True

def gen_rays(self, factor=1):
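
For non-forward-facing data, the (-1.0, -1.0) pair above presumably acts as a sentinel meaning "NDC disabled", letting downstream code branch on the sign of the coefficients. A tiny sketch of that check; the use_ndc name is illustrative, not from this diff:

    def use_ndc(ndc_coeffs) -> bool:
        # Negative coefficients signal that no NDC warping should be applied.
        return ndc_coeffs[0] > 0.0 and ndc_coeffs[1] > 0.0

    assert not use_ndc((-1.0, -1.0))                        # obj_dataset: disabled
    assert use_ndc((2 * 1000.0 / 1008, 2 * 1000.0 / 756))   # LLFF-style dataset
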
4 changes: 3 additions & 1 deletion svox2/csrc/include/data_spec.hpp
@@ -11,7 +11,6 @@ struct SparseGridSpec {
Tensor links;
Tensor _offset;
Tensor _scaling;
float _z_ratio;

int basis_dim;
bool use_learned_basis;
@@ -46,6 +45,9 @@ struct CameraSpec {
int width;
int height;

float ndc_coeffx;
float ndc_coeffy;

inline void check() {
CHECK_INPUT(c2w);
TORCH_CHECK(c2w.is_floating_point());
10 changes: 6 additions & 4 deletions svox2/csrc/include/data_spec_packed.cuh
@@ -27,8 +27,7 @@ struct PackedSparseGridSpec {
spec._offset.data_ptr<float>()[2]},
_scaling{spec._scaling.data_ptr<float>()[0],
spec._scaling.data_ptr<float>()[1],
spec._scaling.data_ptr<float>()[2]},
_z_ratio{spec._z_ratio} {
spec._scaling.data_ptr<float>()[2]} {
}

float* __restrict__ density_data;
@@ -42,15 +41,15 @@
const int basis_dim, sh_data_dim, basis_reso;
const float _offset[3];
const float _scaling[3];
const float _z_ratio;
};

struct PackedCameraSpec {
PackedCameraSpec(CameraSpec& cam) :
c2w(cam.c2w.packed_accessor32<float, 2, torch::RestrictPtrTraits>()),
fx(cam.fx), fy(cam.fy),
cx(cam.cx), cy(cam.cy),
width(cam.width), height(cam.height) {}
width(cam.width), height(cam.height),
ndc_coeffx(cam.ndc_coeffx), ndc_coeffy(cam.ndc_coeffy) {}
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits>
c2w;
float fx;
@@ -59,6 +58,9 @@
float cy;
int width;
int height;

float ndc_coeffx;
float ndc_coeffy;
};

struct PackedRaysSpec {
Expand Down