-
Notifications
You must be signed in to change notification settings - Fork 2
/
baseline_unet.py
100 lines (76 loc) · 3.29 KB
/
baseline_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# image size: 150 * 300 ====>>> 75 x 150
# image rad delta: 180 deg / 128 = 1.4 deg
# pool1 rad delta: 180 deg / 64 = 2.8 deg
# pool2 rad delta: 180 deg / 32 = 5.6 deg
# ---------------------- arch (as implemented below) ------------------------ #
# input: cat(image, last) along channels -> 4 channels total
# conv(4, 64, pi/32, (8, 16)) -> bn -> relu -> pool(2) -> conv(64, 128, pi/16, (4, 8)) -> bn -> relu -> pool(2) ->
#
# conv(128, 256, pi/4, (4, 8)) -> bn -> relu -> pool(2) -> conv(256, 256, pi, (8, 16)) -> bn -> relu ->
#
# up(2) -> conv(256 + 256, 128, pi/4, (4, 8)) -> bn -> relu -> up(2) -> conv(128 + 128, 64, pi/16, (4, 8)) -> bn -> relu ->
#
# up(2) -> conv(64 + 64, 1, pi/32, (4, 8))
from sconv.module import SphericalConv, SphereMSE, SphericalPooling
from torch import nn
import numpy as np
import torch as th
from torch.autograd import Variable
from opts import opts
opt = opts().parse()
class SphericalUNet(nn.Module):
    """U-Net assembled from spherical convolutions.

    Encoder: three conv -> bn -> relu -> pool stages followed by a
    bottleneck conv -> bn -> relu. Decoder: three upsample stages, each
    concatenating the matching encoder activation (skip connection) before
    a conv -> bn -> relu; a final spherical conv maps the fused features
    to a single output channel.

    ``forward`` concatenates its two inputs along the channel axis, which
    is why ``conv1`` expects 4 input channels.
    """

    def __init__(self):
        super(SphericalUNet, self).__init__()
        pi = np.pi
        # Encoder convolutions: channels grow as angular extent coarsens.
        self.conv1 = SphericalConv(4, 64, pi / 32, kernel_size=(8, 16), kernel_sr=None)
        self.conv2 = SphericalConv(64, 128, pi / 16, kernel_size=(4, 8), kernel_sr=None)
        self.conv3 = SphericalConv(128, 256, pi / 4, kernel_size=(4, 8), kernel_sr=None)
        self.conv4 = SphericalConv(256, 256, pi, kernel_size=(8, 16), kernel_sr=None)
        # Decoder convolutions: input channels double via skip concatenation.
        self.conv5 = SphericalConv(256 + 256, 128, pi / 4, kernel_size=(4, 8), kernel_sr=None)
        self.conv6 = SphericalConv(128 + 128, 64, pi / 16, kernel_size=(4, 8), kernel_sr=None)
        self.conv7 = SphericalConv(64 + 64, 1, pi / 32, kernel_size=(4, 8), kernel_sr=None)
        # Submodules below keep the original names and registration order so
        # existing checkpoints (state_dict keys) remain loadable.
        self.pool1 = SphericalPooling()
        self.pool2 = SphericalPooling()
        self.pool3 = SphericalPooling()
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.ReLU(inplace=True)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace=True)
        self.up1 = nn.Upsample(scale_factor=2)
        self.bn5 = nn.BatchNorm2d(128)
        self.relu5 = nn.ReLU(inplace=True)
        self.up2 = nn.Upsample(scale_factor=2)
        self.bn6 = nn.BatchNorm2d(64)
        self.relu6 = nn.ReLU(inplace=True)
        self.up3 = nn.Upsample(scale_factor=2)

    def forward(self, image, last):
        """Fuse ``image`` and ``last`` and return a one-channel feature map."""
        feats = th.cat([image, last], dim=1)
        # Encoder path; keep each pre-pool activation as a skip connection.
        skip1 = self.relu1(self.bn1(self.conv1(feats)))
        skip2 = self.relu2(self.bn2(self.conv2(self.pool1(skip1))))
        skip3 = self.relu3(self.bn3(self.conv3(self.pool2(skip2))))
        bottom = self.relu4(self.bn4(self.conv4(self.pool3(skip3))))
        # Decoder path: upsample, concatenate the matching skip, convolve.
        dec1 = self.relu5(self.bn5(self.conv5(th.cat([self.up1(bottom), skip3], dim=1))))
        dec2 = self.relu6(self.bn6(self.conv6(th.cat([self.up2(dec1), skip2], dim=1))))
        return self.conv7(th.cat([self.up3(dec2), skip1], dim=1))