-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathatrous_conv.py
29 lines (19 loc) · 1.03 KB
/
atrous_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# coding: utf-8
import chainer.links as L
from spectral_norms import SNConv, SNHookConv
class AtrousConv(L.DilatedConvolution2D):
    """Dilated 2-D convolution link configured from a single ``rate``.

    The stride is fixed to 1 and ``rate`` is used for both the padding and
    the dilation of the underlying ``L.DilatedConvolution2D``.

    NOTE(review): with ``pad == dilate`` and stride 1 the spatial size is
    preserved only for a 3x3 kernel -- confirm callers always pass
    ``ksize=3``.
    """

    def __init__(self, in_channels, out_channels, ksize=None, rate=1, initialW=None):
        """Forward the arguments to the parent link.

        Args:
            in_channels: Passed through as the first positional argument.
            out_channels: Passed through as the second positional argument.
            ksize: Kernel size, forwarded unchanged.
            rate: Value used for both ``pad`` and ``dilate``.
            initialW: Weight initializer, forwarded unchanged.
        """
        # One knob (`rate`) drives both padding and dilation; stride stays 1.
        atrous_kwargs = dict(
            ksize=ksize,
            stride=1,
            pad=rate,
            dilate=rate,
            initialW=initialW,
        )
        super().__init__(in_channels, out_channels, **atrous_kwargs)
class AtrousSNConv(SNConv):
    """``SNConv`` subclass configured from a single ``rate``.

    The stride is fixed to 1 and ``rate`` is used for both the ``pad`` and
    the ``dilate`` keyword arguments of ``SNConv``.

    NOTE(review): ``SNConv`` comes from ``spectral_norms``; presumably a
    spectrally normalised convolution -- confirm against that module.
    """

    def __init__(self, in_channels, out_channels, ksize=None, rate=1, initialW=None):
        """Forward the arguments to ``SNConv.__init__``.

        Args:
            in_channels: Passed through as the first positional argument.
            out_channels: Passed through as the second positional argument.
            ksize: Kernel size, forwarded unchanged.
            rate: Value used for both ``pad`` and ``dilate``.
            initialW: Weight initializer, forwarded unchanged.
        """
        # One knob (`rate`) drives both padding and dilation; stride stays 1.
        atrous_kwargs = dict(
            ksize=ksize,
            stride=1,
            pad=rate,
            dilate=rate,
            initialW=initialW,
        )
        super().__init__(in_channels, out_channels, **atrous_kwargs)
class AtrousSNHookConv(SNHookConv):
    """``SNHookConv`` subclass configured from a single ``rate``.

    The stride is fixed to 1 and ``rate`` is used for both the ``pad`` and
    the ``dilate`` keyword arguments of ``SNHookConv``.

    NOTE(review): ``SNHookConv`` comes from ``spectral_norms``; presumably a
    hook-based spectral normalisation variant -- confirm against that module.
    """

    def __init__(self, in_channels, out_channels, ksize=None, rate=1, initialW=None):
        """Forward the arguments to ``SNHookConv.__init__``.

        Args:
            in_channels: Passed through as the first positional argument.
            out_channels: Passed through as the second positional argument.
            ksize: Kernel size, forwarded unchanged.
            rate: Value used for both ``pad`` and ``dilate``.
            initialW: Weight initializer, forwarded unchanged.
        """
        # One knob (`rate`) drives both padding and dilation; stride stays 1.
        atrous_kwargs = dict(
            ksize=ksize,
            stride=1,
            pad=rate,
            dilate=rate,
            initialW=initialW,
        )
        super().__init__(in_channels, out_channels, **atrous_kwargs)
def define_atrous_conv(opt):
    """Select the atrous convolution class named by ``opt.conv_norm``.

    Args:
        opt: Options object exposing a ``conv_norm`` attribute.

    Returns:
        The class itself (not an instance): ``AtrousConv`` for
        ``'original'``, ``AtrousSNConv`` for ``'spectral_norm'`` and
        ``AtrousSNHookConv`` for ``'spectral_norm_hook'``.

    Raises:
        ValueError: If ``opt.conv_norm`` is not one of the supported names.
    """
    conv_norm = opt.conv_norm

    if conv_norm == 'original':
        return AtrousConv
    if conv_norm == 'spectral_norm':
        return AtrousSNConv
    if conv_norm == 'spectral_norm_hook':
        return AtrousSNHookConv

    # An unknown name used to fall through and return None, which only
    # surfaced later as "'NoneType' object is not callable" at the place the
    # layer was constructed. Fail here, where the bad option is visible.
    raise ValueError(
        "unsupported conv_norm {!r}; expected 'original', 'spectral_norm' "
        "or 'spectral_norm_hook'".format(conv_norm)
    )