diff --git a/models.py b/models.py index cd50d8e..342d51f 100644 --- a/models.py +++ b/models.py @@ -139,10 +139,11 @@ def buildin_models( nn = keras.layers.Dropout(dropout)(nn) nn = keras.layers.Flatten(name="E_flatten")(nn) nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="E_dense")(nn) + nn = keras.layers.Reshape([1, 1, -1])(nn) # expand_dims to 4D again for applying BatchNormalization elif output_layer == "GAP": """GlobalAveragePooling2D""" nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, name="GAP_batchnorm")(nn) - nn = keras.layers.GlobalAveragePooling2D(name="GAP_pool")(nn) + nn = keras.layers.GlobalAveragePooling2D(keepdims=True, name="GAP_pool")(nn) if dropout > 0 and dropout < 1: nn = keras.layers.Dropout(dropout)(nn) nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="GAP_dense")(nn) @@ -154,7 +155,6 @@ def buildin_models( if dropout > 0 and dropout < 1: nn = keras.layers.Dropout(dropout)(nn) nn = keras.layers.Conv2D(emb_shape, 1, use_bias=use_bias, kernel_initializer="glorot_normal", name="GDC_conv")(nn) - nn = keras.layers.Flatten(name="GDC_flatten")(nn) # nn = keras.layers.Dense(emb_shape, activation=None, use_bias=use_bias, kernel_initializer="glorot_normal", name="GDC_dense")(nn) elif output_layer == "F": """F, E without first BatchNormalization""" @@ -162,9 +162,11 @@ def buildin_models( nn = keras.layers.Dropout(dropout)(nn) nn = keras.layers.Flatten(name="F_flatten")(nn) nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="F_dense")(nn) + nn = keras.layers.Reshape([1, 1, -1])(nn) # expand_dims to 4D again for applying BatchNormalization # `fix_gamma=True` in MXNet means `scale=False` in Keras embedding = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, scale=scale, name="pre_embedding")(nn) + embedding = keras.layers.Flatten()(embedding) embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(embedding) basic_model = keras.models.Model(inputs, embedding_fp32, name=xx.name)