Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Sep 27, 2023
2 parents 12e8384 + e453961 commit e0fa293
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ def buildin_models(
nn = keras.layers.Dropout(dropout)(nn)
nn = keras.layers.Flatten(name="E_flatten")(nn)
nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="E_dense")(nn)
nn = keras.layers.Reshape([1, 1, -1])(nn) # expand_dims to 4D again for applying BatchNormalization
elif output_layer == "GAP":
"""GlobalAveragePooling2D"""
nn = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, name="GAP_batchnorm")(nn)
nn = keras.layers.GlobalAveragePooling2D(name="GAP_pool")(nn)
nn = keras.layers.GlobalAveragePooling2D(keepdims=True, name="GAP_pool")(nn)
if dropout > 0 and dropout < 1:
nn = keras.layers.Dropout(dropout)(nn)
nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="GAP_dense")(nn)
Expand All @@ -154,17 +155,18 @@ def buildin_models(
if dropout > 0 and dropout < 1:
nn = keras.layers.Dropout(dropout)(nn)
nn = keras.layers.Conv2D(emb_shape, 1, use_bias=use_bias, kernel_initializer="glorot_normal", name="GDC_conv")(nn)
nn = keras.layers.Flatten(name="GDC_flatten")(nn)
# nn = keras.layers.Dense(emb_shape, activation=None, use_bias=use_bias, kernel_initializer="glorot_normal", name="GDC_dense")(nn)
elif output_layer == "F":
"""F, E without first BatchNormalization"""
if dropout > 0 and dropout < 1:
nn = keras.layers.Dropout(dropout)(nn)
nn = keras.layers.Flatten(name="F_flatten")(nn)
nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="F_dense")(nn)
nn = keras.layers.Reshape([1, 1, -1])(nn) # expand_dims to 4D again for applying BatchNormalization

# `fix_gamma=True` in MXNet means `scale=False` in Keras
embedding = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, scale=scale, name="pre_embedding")(nn)
embedding = keras.layers.Flatten()(embedding)
embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(embedding)

basic_model = keras.models.Model(inputs, embedding_fp32, name=xx.name)
Expand Down

0 comments on commit e0fa293

Please sign in to comment.