base_mlp.py
import torch
import numpy as np


class PosEncoder(torch.nn.Module):
    """Positional encoding: maps each input coordinate x to [x, sin(f_k * x), cos(f_k * x)]."""

    def __init__(self, freq_num, freq_factor=np.pi):
        super().__init__()
        self.freq_num = freq_num
        self.freq_factor = freq_factor

    def forward(self, x):
        # Frequencies: freq_factor * 2^k for k = 0 .. freq_num - 1.
        freq_multiplier = (
            self.freq_factor * 2 ** torch.arange(self.freq_num, device=x.device)
        )
        x_expand = x.unsqueeze(-1)
        sin_val = torch.sin(x_expand * freq_multiplier)
        cos_val = torch.cos(x_expand * freq_multiplier)
        # Concatenate the raw value with its encodings and flatten the last two dims;
        # each input feature expands to (2 * freq_num + 1) output features.
        return torch.cat(
            [x_expand, sin_val, cos_val], -1
        ).view(*x.shape[:-1], -1)


class MLP(torch.nn.Module):
    """Plain MLP: Linear layers interleaved with PReLU activations, sized by feature_nums."""

    def __init__(self, feature_nums):
        super().__init__()
        self.input_size = feature_nums[0]
        self.output_size = feature_nums[-1]
        self.layers = torch.nn.ModuleList()
        self.layers.append(torch.nn.Linear(feature_nums[0], feature_nums[1]))
        for i in range(1, len(feature_nums) - 1):
            self.layers.append(torch.nn.PReLU(feature_nums[i]))
            self.layers.append(torch.nn.Linear(feature_nums[i], feature_nums[i + 1]))

    def forward(self, x):
        # Flatten all leading dims, run the layers, then restore the original batch shape.
        input_shape = x.shape
        x = x.view(-1, self.input_size)
        for layer in self.layers:
            x = layer(x)
        return x.view(*input_shape[:-1], self.output_size)


class ResBlock(torch.nn.Module):
    """Pre-activation residual block: PReLU -> Linear -> PReLU -> Linear plus a skip connection."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.prelu_0 = torch.nn.PReLU(input_size)
        self.fc_0 = torch.nn.Linear(input_size, hidden_size)
        self.prelu_1 = torch.nn.PReLU(hidden_size)
        self.fc_1 = torch.nn.Linear(hidden_size, output_size)
        # Project the skip connection only when the input and output widths differ.
        self.shortcut = (
            torch.nn.Linear(input_size, output_size, bias=False)
            if input_size != output_size else None
        )

    def forward(self, x):
        residual = self.fc_1(self.prelu_1(self.fc_0(self.prelu_0(x))))
        shortcut = x if self.shortcut is None else self.shortcut(x)
        return residual + shortcut


class ResBlocks(torch.nn.Module):
    """Stack of ResBlock modules behind an input Linear layer; the last block maps to output_size."""

    def __init__(self, input_size, hidden_size, block_num, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.input_layer = torch.nn.Linear(input_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
            ResBlock(hidden_size, hidden_size,
                     hidden_size if i < block_num - 1 else output_size)
            for i in range(block_num)
        ])

    def forward(self, x):
        input_shape = x.shape
        x = self.input_layer(x.view(-1, self.input_size))
        for block in self.blocks:
            x = block(x)
        return x.view(*input_shape[:-1], self.output_size)


class PosEncodeResnet(torch.nn.Module):
    """Residual MLP whose first posencode_size inputs are positionally encoded before the network."""

    def __init__(
        self,
        posencode_size, nonencode_size, hidden_size, output_size,
        posencode_freq_num, block_num
    ):
        super().__init__()
        # Each encoded coordinate expands to (2 * posencode_freq_num + 1) features.
        self.input_size = (
            posencode_size * (2 * posencode_freq_num + 1)
            + nonencode_size
        )
        self.output_size = output_size
        if posencode_size > 0:
            self.pos_encoder = PosEncoder(posencode_freq_num)
        self.input_layer = torch.nn.Linear(self.input_size, hidden_size)
        self.blocks = torch.nn.ModuleList(
            [ResBlock(hidden_size, hidden_size, hidden_size)
             for i in range(block_num)]
        )
        self.output_prelu = torch.nn.PReLU(hidden_size)
        self.output_layer = torch.nn.Linear(hidden_size, output_size)

    def forward(self, posencode_x, nonencode_x):
        # Concatenate the positionally encoded and raw feature groups along the last dim.
        x = []
        if posencode_x is not None:
            x.append(self.pos_encoder(posencode_x))
        if nonencode_x is not None:
            x.append(nonencode_x)
        x = torch.cat(x, dim=-1)
        input_shape = x.shape
        x = self.input_layer(x.view(-1, self.input_size))
        for block in self.blocks:
            x = block(x)
        return self.output_layer(self.output_prelu(x)).view(*input_shape[:-1], self.output_size)
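

# Minimal usage sketch (not part of the original module). The sizes below are
# illustrative assumptions only: 3 position coordinates, 4 extra features,
# a 64-wide hidden layer, 4 residual blocks, and 10 encoding frequencies.
if __name__ == "__main__":
    model = PosEncodeResnet(
        posencode_size=3, nonencode_size=4, hidden_size=64, output_size=2,
        posencode_freq_num=10, block_num=4
    )
    pos = torch.rand(8, 3)    # batch of 8 position vectors (encoded)
    feat = torch.rand(8, 4)   # batch of 8 raw feature vectors (not encoded)
    out = model(pos, feat)
    print(out.shape)          # expected: torch.Size([8, 2])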