Commit 5c0810d: minimal updates
leondgarse committed Jun 24, 2022 (1 parent: ce1aef9)
Showing 2 changed files with 4 additions and 4 deletions.
backbones/resnet.py: 2 changes (1 addition, 1 deletion)
@@ -21,7 +21,7 @@ def batchnorm_with_activation(inputs, activation="relu", zero_gamma=False, name=
     )(inputs)
     if activation:
         act_name = name + activation
-        if activation == "PReLU":
+        if activation.lower() == "prelu":
             nn = keras.layers.PReLU(shared_axes=[1, 2], alpha_initializer=tf.initializers.Constant(0.25), name=act_name)(nn)
         else:
             nn = keras.layers.Activation(activation=activation, name=act_name)(nn)
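The resnet.py change makes the PReLU check case-insensitive, so both activation="prelu" and activation="PReLU" reach the PReLU branch. A minimal sketch of the whole function for context; the BatchNormalization arguments are assumptions, since the hunk only shows the closing )(inputs):

```python
import tensorflow as tf
from tensorflow import keras

def batchnorm_with_activation(inputs, activation="relu", zero_gamma=False, name=""):
    # The BatchNormalization arguments here are illustrative; the diff only shows `)(inputs)`.
    nn = keras.layers.BatchNormalization(
        gamma_initializer="zeros" if zero_gamma else "ones", name=name + "bn"
    )(inputs)
    if activation:
        act_name = name + activation
        if activation.lower() == "prelu":  # case-insensitive match after this commit
            # Channel-wise PReLU: alpha is shared across the two spatial axes.
            nn = keras.layers.PReLU(shared_axes=[1, 2], alpha_initializer=tf.initializers.Constant(0.25), name=act_name)(nn)
        else:
            nn = keras.layers.Activation(activation=activation, name=act_name)(nn)
    return nn
```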
models.py: 6 changes (3 additions, 3 deletions)
@@ -164,7 +164,7 @@ def buildin_models(
     nn = keras.layers.Dense(emb_shape, use_bias=use_bias, kernel_initializer="glorot_normal", name="F_dense")(nn)

     # `fix_gamma=True` in MXNet means `scale=False` in Keras
-    embedding = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, scale=scale, dtype="float32", name="pre_embedding")(nn)
+    embedding = keras.layers.BatchNormalization(momentum=bn_momentum, epsilon=bn_epsilon, scale=scale, name="pre_embedding")(nn)
     embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(embedding)

     basic_model = keras.models.Model(inputs, embedding_fp32, name=xx.name)
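In models.py, dropping dtype="float32" from pre_embedding lets that BatchNormalization follow the global dtype policy (e.g. float16 compute under mixed precision), while the trailing linear Activation still casts the final embedding to float32. A minimal sketch of the casting pattern, assuming a mixed_float16 policy and an illustrative input shape:

```python
import tensorflow as tf
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

inputs = keras.layers.Input([512])
# Without an explicit dtype, this layer computes in the policy dtype (float16 here).
nn = keras.layers.BatchNormalization(name="pre_embedding")(inputs)
# The explicit dtype casts the final embedding output back to float32.
embedding_fp32 = keras.layers.Activation("linear", dtype="float32", name="embedding")(nn)
print(nn.dtype, embedding_fp32.dtype)  # expected: float16 float32
```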
@@ -210,8 +210,8 @@ def call(self, inputs, **kwargs):
             self.w = tf.gather(self.sub_weights, self.cur_id)
             self.cur_id.assign((self.cur_id + 1) % self.partial_fc_split)

-        norm_w = K.l2_normalize(self.w, axis=0)
-        norm_inputs = K.l2_normalize(inputs, axis=1)
+        norm_w = tf.nn.l2_normalize(self.w, axis=0, epsilon=1e-5)
+        norm_inputs = tf.nn.l2_normalize(inputs, axis=1, epsilon=1e-5)
         output = K.dot(norm_inputs, norm_w)
         if self.loss_top_k > 1:
             output = K.reshape(output, (-1, self.units, self.loss_top_k))
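K.l2_normalize forwards to tf.nn.l2_normalize with its default epsilon of 1e-12, so the substance of the last hunk is the explicit epsilon=1e-5. The op computes x * rsqrt(max(sum(x**2), epsilon)), so a larger epsilon bounds the divisor away from zero for near-zero vectors; in float16, 1e-12 (and the squared norm of a tiny vector) underflows to 0, which can turn normalization into inf/NaN. A quick illustration with made-up values:

```python
import tensorflow as tf

x = tf.constant([[1e-4, 0.0], [3.0, 4.0]], dtype=tf.float16)
# Default epsilon (1e-12) underflows to 0 in float16, so the tiny first row can become inf/NaN.
print(tf.nn.l2_normalize(x, axis=1).numpy())
# epsilon=1e-5 keeps the divisor at least sqrt(1e-5), damping the tiny row instead of exploding it.
print(tf.nn.l2_normalize(x, axis=1, epsilon=1e-5).numpy())
```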
