|
1 | 1 | import torch.nn as nn |
2 | | -from models.model_utils import conv2d_bn_block, dense_layer_bn |
| 2 | +from models.model_utils import conv2d_bn_block, dense_layer_bn, Identity |
3 | 3 |
|
4 | 4 |
|
5 | 5 | class NormalNet2D(nn.Module): |
@@ -27,7 +27,7 @@ def __init__(self, n_channels=1, nlabels=1, init_filters=32): |
27 | 27 | ) |
28 | 28 | self.classifier = nn.Sequential( |
29 | 29 | dense_layer_bn(16*nf, 16*nf), |
30 | | - dense_layer_bn(16*nf, nlabels, activation=(lambda x: x)) |
| 30 | + dense_layer_bn(16*nf, nlabels, activation=Identity) |
31 | 31 | ) |
32 | 32 |
|
33 | 33 | def forward(self, x): |
@@ -61,7 +61,7 @@ def __init__(self, n_channels=1, nlabels=1, init_filters=32): |
61 | 61 | ) |
62 | 62 | self.classifier = nn.Sequential( |
63 | 63 | nn.AvgPool2d(2), |
64 | | - dense_layer_bn(16*nf, nlabels, activation=lambda x: x) |
| 64 | + dense_layer_bn(16*nf, nlabels, activation=Identity) |
65 | 65 | ) |
66 | 66 |
|
67 | 67 | def forward(self, x): |
@@ -98,7 +98,7 @@ def __init__(self, n_channels=1, nlabels=1, init_filters=32): |
98 | 98 | ) |
99 | 99 | self.classifier = nn.Sequential( |
100 | 100 | dense_layer_bn(16*nf, 16*nf), |
101 | | - dense_layer_bn(16*nf, nlabels, activation=lambda x: x) |
| 101 | + dense_layer_bn(16*nf, nlabels, activation=Identity) |
102 | 102 | ) |
103 | 103 |
|
104 | 104 | def forward(self, x): |
@@ -130,7 +130,7 @@ def __init__(self, n_channels=1, nlabels=1, init_filters=32): |
130 | 130 | ) |
131 | 131 | self.classifier = nn.Sequential( |
132 | 132 | nn.AvgPool2d(2), |
133 | | - dense_layer_bn(16*nf, nlabels, activation=lambda x: x) |
| 133 | + dense_layer_bn(16*nf, nlabels, activation=Identity) |
134 | 134 | ) |
135 | 135 |
|
136 | 136 | def forward(self, x): |
|
0 commit comments