Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 405428984
  • Loading branch information
TensorFlow Ranking authored and HongleiZhuang committed Oct 25, 2021
1 parent f4af081 commit 672841f
Show file tree
Hide file tree
Showing 4 changed files with 500 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tensorflow_ranking/python/keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def __init__(self,
num_layers: int = 1,
dropout: float = 0.5,
name: Optional[str] = None,
input_noise_stddev: Optional[float] = None,
**kwargs: Dict[Any, Any]):
"""Initializes the layer.
Expand All @@ -434,13 +435,15 @@ def __init__(self,
num_layers: Number of cross-document attention layers.
dropout: Dropout probability.
name: Name of the layer.
input_noise_stddev: Input Gaussian noise standard deviation.
**kwargs: keyword arguments.
"""
super().__init__(name=name, **kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._num_layers = num_layers
self._dropout = dropout
self._input_noise_stddev = input_noise_stddev

def build(self, input_shape: tf.TensorShape):
"""Build method to create weights and sub-layers.
Expand All @@ -463,6 +466,10 @@ def build(self, input_shape: tf.TensorShape):
# recursively for `num_layers` times.
# Shape: [batch_size, list_size, feature_dims] ->
# [batch_size, list_size, head_size].
self._input_noise = None
if self._input_noise_stddev:
self._input_noise = tf.keras.layers.GaussianNoise(
self._input_noise_stddev)
self._input_projection = tf.keras.layers.Dense(
units=self._head_size, activation='relu')
self._input_projection.build(example_inputs_shape)
Expand Down Expand Up @@ -513,6 +520,8 @@ def call(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
if list_mask is None:
list_mask = tf.ones(shape=(batch_size, list_size), dtype=tf.bool)
x = self._input_projection(example_inputs, training=training)
if self._input_noise:
x = self._input_noise(x, training=training)

list_mask = tf.cast(list_mask, dtype=tf.int32)
attention_mask = nlp_modeling_layers.SelfAttentionMask()(
Expand All @@ -534,6 +543,7 @@ def get_config(self):
'head_size': self._head_size,
'num_layers': self._num_layers,
'dropout': self._dropout,
'input_noise_stddev': self._input_noise_stddev,
})
return config

Expand Down
52 changes: 52 additions & 0 deletions tensorflow_ranking/research/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Description:
# TensorFlow Ranking research code for published papers.

package(
default_visibility = [
"//tensorflow_ranking:__subpackages__",
],
)

licenses(["notice"])

py_library(
name = "dasalc_lib",
srcs = ["dasalc.py"],
srcs_version = "PY3",
deps = [
# py/absl/flags dep,
# py/tensorflow dep,
"//tensorflow_ranking",
],
)

py_binary(
name = "dasalc_py_binary",
srcs = ["dasalc.py"],
main = "dasalc.py",
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dasalc_lib",
],
)

py_test(
name = "dasalc_test",
size = "large",
srcs = ["dasalc_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = [
"no_pip",
"notsan",
],
deps = [
":dasalc_lib",
# py/absl/flags dep,
# py/absl/testing:flagsaver dep,
# py/absl/testing:parameterized dep,
# py/tensorflow dep,
# tensorflow_serving/apis:input_proto_py_pb2 dep,
],
)
Loading

0 comments on commit 672841f

Please sign in to comment.