-
Notifications
You must be signed in to change notification settings - Fork 49
/
base.py
77 lines (67 loc) · 2.79 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from ..builder import EMBEDDERS
@EMBEDDERS.register_module()
class BaseEmbedder(nn.Module):
    """Frequency (sin/cos) embedder for sample points and view directions.

    Two lists of embedding functions are built: one applied to ``data['pts']``
    and one applied to ``data['viewdirs']``. Each function maps a flattened
    ``(N, C)`` tensor to ``(N, C)`` features; the outputs of a list are
    concatenated along the last dimension.

    Args:
        i_embed (int): If ``-1``, no frequency encoding is applied and both
            inputs pass through unchanged (``nn.Identity``). Any other value
            enables the sin/cos encoding built by ``create_embedding_fn``.
        multires (int): Number of frequency bands for the point encoding.
        multires_dirs (int): Number of frequency bands for the direction
            encoding.
        input_ch (int): Size of the last dimension of both ``pts`` and
            ``viewdirs``.
        **kwargs: Accepted and ignored (presumably so shared configs can carry
            extra keys -- confirm against callers).
    """

    def __init__(self,
                 i_embed=0,
                 multires=10,
                 multires_dirs=4,
                 input_ch=3,
                 **kwargs):
        # nn.Module.__init__ must run before any submodule/parameter is set.
        super().__init__()
        if i_embed == -1:
            # Pass-through: output width equals input width.
            self.embed_fns, self.embed_ch = [nn.Identity()], input_ch
            self.embed_fns_dirs, self.embed_ch_dirs = [nn.Identity()], input_ch
        else:
            self.embed_fns, self.embed_ch = self.create_embedding_fn(
                multires, input_ch=input_ch)
            self.embed_fns_dirs, self.embed_ch_dirs = self.create_embedding_fn(
                multires_dirs, input_ch=input_ch)

    def create_embedding_fn(self,
                            multires,
                            input_ch=3,
                            cat_input=True,
                            log_sampling=True,
                            periodic_fns=(torch.sin, torch.cos)):
        """Build the list of embedding functions and their total output width.

        ``multires`` frequencies are used. With ``log_sampling`` they are
        ``2**k`` for ``k`` evenly spaced in ``[0, multires - 1]``; otherwise
        they are evenly spaced in ``[1, 2**(multires - 1)]``. For every
        frequency, each function in ``periodic_fns`` contributes
        ``p_fn(x * freq)``.

        Args:
            multires (int): Number of frequency bands.
            input_ch (int): Width of the input's last dimension.
            cat_input (bool): Prepend the raw input (identity) to the features.
            log_sampling (bool): Space the frequencies in log2 rather than
                linearly.
            periodic_fns (Sequence[Callable]): Functions applied at each
                frequency.

        Returns:
            tuple[list[Callable], int]: The functions, and the output width
            ``input_ch * (int(cat_input) + multires * len(periodic_fns))``.
        """
        max_freq = multires - 1
        if log_sampling:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=multires)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=multires)

        embed_fns = []
        out_dim = 0
        if cat_input:
            embed_fns.append(lambda x: x)
            out_dim += input_ch
        # Python floats keep the closures free of (device-bound) tensors.
        for freq in freq_bands.tolist():
            for p_fn in periodic_fns:
                # Bind p_fn/freq as defaults to avoid late-binding closures.
                embed_fns.append(
                    lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += input_ch
        return embed_fns, out_dim

    def get_embed_ch(self):
        """Return ``(point_embedding_width, direction_embedding_width)``."""
        return self.embed_ch, self.embed_ch_dirs

    def forward(self, data):
        """Embed ``data['pts']`` and ``data['viewdirs']`` and store the result.

        Args:
            data (dict): Must contain ``'pts'`` with shape ``(..., C)`` and
                ``'viewdirs'``. ``viewdirs`` may have fewer dimensions than
                ``pts`` (e.g. one direction per ray shared by all samples on
                that ray). Its leading dims are assumed to match the leading
                dims of ``pts``; singleton dims are inserted before its last
                dim and broadcast over the remaining ``pts`` dims.

        Returns:
            dict: The same ``data`` dict, updated in place with
            ``'unflatten_shape'`` (``pts.shape[:-1]``) and ``'embedded'``
            (shape ``(prod(pts.shape[:-1]), embed_ch + embed_ch_dirs)``).
        """
        inputs, viewdirs = data['pts'], data['viewdirs']
        # Leading shape of pts before flattening, for callers to restore it.
        data['unflatten_shape'] = inputs.shape[:-1]

        inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
        embedded = self.run_embed(inputs_flat, self.embed_fns)

        # If chunk is None, inputs is 2-D like viewdirs and needs no expand.
        extra_dims = inputs.dim() - viewdirs.dim()
        if extra_dims > 0:
            viewdirs = viewdirs.reshape(
                viewdirs.shape[:-1] + (1,) * extra_dims + viewdirs.shape[-1:])
            # Keep viewdirs' own channel size; only the leading dims follow pts.
            input_dirs = viewdirs.expand(
                inputs.shape[:-1] + viewdirs.shape[-1:])
        else:
            input_dirs = viewdirs
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = self.run_embed(input_dirs_flat, self.embed_fns_dirs)

        data['embedded'] = torch.cat([embedded, embedded_dirs], -1)
        return data

    def run_embed(self, x, embed_fns):
        """Apply every function in ``embed_fns`` to ``x`` and concatenate."""
        return torch.cat([fn(x) for fn in embed_fns], -1)