Skip to content

Commit

Permalink
Add symmetric log1p to tfr.utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 392758538
  • Loading branch information
HongleiZhuang authored and ramakumar1729 committed Oct 18, 2021
1 parent e6ed9a9 commit 65fd6e7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow_ranking/python/keras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ def log2_inverse(rank):
@tf.keras.utils.register_keras_serializable(package="tensorflow_ranking")
def is_greater_equal_1(label):
return tf.greater_equal(label, 1.0)


@tf.keras.utils.register_keras_serializable(package="tensorflow_ranking")
def symmetric_log1p(t):
return tf.math.log1p(t * tf.sign(t)) * tf.sign(t)

2 changes: 2 additions & 0 deletions tensorflow_ranking/python/keras/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_functions_are_serializable(self):
utils.pow_minus_1,
utils.log2_inverse,
utils.is_greater_equal_1,
utils.symmetric_log1p,
]:
self.assertIsNotNone(tf.keras.utils.serialize_keras_object(fn))

Expand All @@ -36,6 +37,7 @@ def test_functions_are_callable(self):
self.assertEqual(utils.pow_minus_1(1.0), 1.0)
self.assertEqual(utils.log2_inverse(1.0), 1.0)
self.assertEqual(utils.is_greater_equal_1(1.0), True)
self.assertAllClose(utils.symmetric_log1p(-1.0), -0.69314718056)


if __name__ == '__main__':
Expand Down

0 comments on commit 65fd6e7

Please sign in to comment.