Skip to content

Commit

Permalink
Use bibtex format.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 281754451
  • Loading branch information
Cyprien de Masson d'Autume authored and diegolascasas committed Nov 21, 2019
1 parent db2393d commit d705acd
Show file tree
Hide file tree
Showing 10 changed files with 1,442 additions and 0 deletions.
61 changes: 61 additions & 0 deletions scratchgan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ScratchGAN

This is the example code for the following NeurIPS 2019 paper. If you use the
code here please cite this paper:

@article{DBLP:journals/corr/abs-1905-09922,
author = {Cyprien de Masson d'Autume and
Mihaela Rosca and
Jack W. Rae and
Shakir Mohamed},
title = {Training language GANs from Scratch},
journal = {CoRR},
volume = {abs/1905.09922},
year = {2019},
url = {http://arxiv.org/abs/1905.09922},
archivePrefix = {arXiv},
eprint = {1905.09922},
timestamp = {Wed, 29 May 2019 11:27:50 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1905-09922},
bibsource = {dblp computer science bibliography, https://dblp.org}
}


## Contents

The code contains:

* `generators.py`: implementation of the generator.
* `discriminator_nets.py`: implementation of the discriminator.
* `eval_metrics.py`: implementation of the FED metric.
* `losses.py`: implementation of the RL loss for the generator.
* `reader.py`: data reader / tokenizer.
* `experiment.py`: main training script.

The data contains:

* `{train,valid,test}.json`: the EMNLP2017 News dataset.
* `glove_emnlp2017.txt`: the relevant subset of GloVe embeddings.

## Running

Place the data files in the directory specified by the `data_dir` flag.

Create and activate a virtual environment if needed:

virtualenv scratchgan-venv
source scratchgan-venv/bin/activate

Install requirements:

pip install -r scratchgan/requirements.txt

Run training and evaluation jobs:

python2 -m scratchgan.experiment --mode="train" &
python2 -m scratchgan.experiment --mode="evaluate_pair" &

The evaluation code is designed to run in parallel with the training.

The training code saves checkpoints periodically; the evaluation code
looks for new checkpoints and evaluates them.
13 changes: 13 additions & 0 deletions scratchgan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
121 changes: 121 additions & 0 deletions scratchgan/discriminator_nets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminator networks for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sonnet as snt
import tensorflow as tf
from scratchgan import utils


class LSTMEmbedDiscNet(snt.AbstractModule):
  """An LSTM discriminator that operates on word indexes.

  Word indexes are embedded, extended with a position signal, linearly
  projected to the width of the first LSTM layer and run through a stack of
  LSTMs. A final linear layer maps every hidden state to a single logit, so
  the module emits one real/generated score per timestep.
  """

  def __init__(self,
               feature_sizes,
               vocab_size,
               use_layer_norm,
               trainable_embedding_size,
               dropout,
               pad_token,
               embedding_source=None,
               vocab_file=None,
               name='LSTMEmbedDiscNet'):
    """Stores the configuration; no graph ops are created here.

    Args:
      feature_sizes: Sequence of ints. One LSTM layer is created per entry;
        the first entry is also the width of the input projection.
      vocab_size: Int, number of rows of the embedding table.
      use_layer_norm: Forwarded to every `snt.LSTM` as `use_layer_norm`.
      trainable_embedding_size: Int, width of the trainable embedding. When
        `embedding_source` is set it is forwarded to
        `utils.make_partially_trainable_embeddings` instead.
      dropout: Float, drop probability applied to the embedding table while
        training (`keep_prob` is `1.0 - dropout`).
      pad_token: Symbol handed to `utils.get_mask_past_symbol` to build the
        output mask.
      embedding_source: Optional. When truthy, embeddings are built with
        `utils.make_partially_trainable_embeddings`; `vocab_file` is then
        required.
      vocab_file: Optional, only used together with `embedding_source`.
      name: Name of the Sonnet module.
    """
    super(LSTMEmbedDiscNet, self).__init__(name=name)
    self._feature_sizes = feature_sizes
    self._vocab_size = vocab_size
    self._use_layer_norm = use_layer_norm
    self._trainable_embedding_size = trainable_embedding_size
    self._embedding_source = embedding_source
    self._vocab_file = vocab_file
    self._dropout = dropout
    self._pad_token = pad_token
    if self._embedding_source:
      assert vocab_file, 'vocab_file is required when embedding_source is set.'

  def _build(self, sequence, sequence_length, is_training=True):
    """Connect to the graph.

    The static shape of `sequence` must be fully defined: both dimensions are
    used as Python ints for the reshapes and for the RNN initial state.

    Args:
      sequence: A [batch_size, max_sequence_length] tensor of int. For example
        the indices of words as sampled by the generator.
      sequence_length: A [batch_size] tensor of int. Length of the sequence.
      is_training: Boolean, False to disable dropout.

    Returns:
      A [batch_size, max_sequence_length] tensor of floats holding one logit
      per timestep, multiplied by the mask obtained from
      `utils.get_mask_past_symbol(sequence, pad_token)`.
    """
    batch_size, max_sequence_length = sequence.shape.as_list()
    keep_prob = (1.0 - self._dropout) if is_training else 1.0

    if self._embedding_source:
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embedding',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()

    # Dropout is applied to the whole table before the lookup, so every
    # occurrence of a given word index sees the same dropped units.
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
    embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
    embeddings.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._embedding_size])

    # The position signal widens the last dimension by `position_dim`.
    position_dim = 8
    embeddings_pos = utils.append_position_signal(embeddings, position_dim)

    # Fold time into the batch dimension so that a plain Linear can project
    # every timestep, then restore the [batch, time, features] layout.
    embeddings_pos = tf.reshape(
        embeddings_pos,
        [batch_size * max_sequence_length, self._embedding_size + position_dim])
    lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
    lstm_inputs = tf.reshape(
        lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
    lstm_inputs.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._feature_sizes[0]])

    encoder_cells = [
        snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
        for feature_size in self._feature_sizes
    ]
    encoder_cell = snt.DeepRNN(encoder_cells)
    initial_state = encoder_cell.initial_state(batch_size)

    hidden_states, _ = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=lstm_inputs,
        sequence_length=sequence_length,
        initial_state=initial_state,
        swap_memory=True)

    # The output width is expected to be the sum of all layer sizes
    # (presumably because DeepRNN concatenates the per-layer outputs through
    # its skip connections -- confirm against the Sonnet version in use).
    hidden_states.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length,
         sum(self._feature_sizes)])
    logits = snt.BatchApply(snt.Linear(1))(hidden_states)
    logits.shape.assert_is_compatible_with([batch_size, max_sequence_length, 1])
    logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])

    # Mask past first PAD symbol.
    #
    # The RNN above is the unidirectional tf.nn.dynamic_rnn, so the logit at a
    # timestep only depends on that timestep and the ones before it. The mask
    # only zeroes the logits themselves; `sequence_length` is still passed to
    # tf.nn.dynamic_rnn so that the recurrence itself stops at the end of each
    # sequence.
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
    return masked_logits_flat
48 changes: 48 additions & 0 deletions scratchgan/eval_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub


def fid(generated_sentences, real_sentences):
  """Compute FID on sentences using pretrained universal sentence encoder.

  Both sets of sentences are embedded with the Universal Sentence Encoder
  loaded from TF-Hub, and the Frechet distance between the two sets of
  activations is evaluated in a dedicated session.

  The ops are created in a private `tf.Graph`, so repeated calls do not keep
  adding nodes to the caller's default graph, and the initializers run below
  only cover the variables and tables created by this function.

  Args:
    generated_sentences: list of N strings.
    real_sentences: list of N strings.

  Returns:
    Frechet distance between activations, as returned by `session.run`.
  """
  with tf.Graph().as_default():
    # Canonical TF-Hub handle of the encoder.
    embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
    real_embed = embed(real_sentences)
    generated_embed = embed(generated_sentences)
    distance = tfgan.eval.frechet_classifier_distance_from_activations(
        real_embed, generated_embed)

    # Restrict the thread pool size to prevent excessive CPU usage.
    config = tf.ConfigProto()
    config.intra_op_parallelism_threads = 16
    config.inter_op_parallelism_threads = 16
    # The session picks up the private graph because it is the current
    # default graph inside this `with` block.
    with tf.Session(config=config) as session:
      session.run(tf.global_variables_initializer())
      session.run(tf.tables_initializer())
      distance_np = session.run(distance)
  return distance_np
Loading

0 comments on commit d705acd

Please sign in to comment.