forked from google-deepmind/deepmind-research
Commit
Showing 10 changed files with 1,442 additions and 0 deletions.
@@ -0,0 +1,61 @@
# ScratchGAN

This is the example code for the following NeurIPS 2019 paper. If you use the
code here please cite this paper:

    @article{DBLP:journals/corr/abs-1905-09922,
      author        = {Cyprien de Masson d'Autume and
                       Mihaela Rosca and
                       Jack W. Rae and
                       Shakir Mohamed},
      title         = {Training language GANs from Scratch},
      journal       = {CoRR},
      volume        = {abs/1905.09922},
      year          = {2019},
      url           = {http://arxiv.org/abs/1905.09922},
      archivePrefix = {arXiv},
      eprint        = {1905.09922},
      timestamp     = {Wed, 29 May 2019 11:27:50 +0200},
      biburl        = {https://dblp.org/rec/bib/journals/corr/abs-1905-09922},
      bibsource     = {dblp computer science bibliography, https://dblp.org}
    }
## Contents

The code contains:

* `generators.py`: implementation of the generator.
* `discriminator_nets.py`: implementation of the discriminator.
* `eval_metrics.py`: implementation of the FED metric.
* `losses.py`: implementation of the RL loss for the generator.
* `reader.py`: data reader / tokenizer.
* `experiment.py`: main training script.

The data contains:

* `{train,valid,test}.json`: the EMNLP2017 News dataset.
* `glove_emnlp2017.txt`: the relevant subset of GloVe embeddings.
||
## Running | ||
|
||
Place the data files in the directory specified by `data_dir` flag. | ||
|
||
Create and activate a virtual environment if needed: | ||
|
||
virtualenv scratchgan-venv | ||
source scratchgan-venv/bin/activate | ||
|
||
Install requirements: | ||
|
||
pip install -r scratchgan/requirements.txt | ||
|
||
Run training and evaluation jobs: | ||
|
||
python2 scratchgan.experiment.py --mode="train" & | ||
python2 scratchgan.experiment.py --mode="evaluate_pair" & | ||
|
||
The evaluation code is designed to run in parallel with the training. | ||
|
||
The training code saves checkpoints periodically, the evaluation code | ||
looks for new checkpoints and evaluate them. |
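
The checkpoint hand-off can be pictured with the minimal sketch below. This is
an illustration only, not code from this commit: `CHECKPOINT_DIR` and
`evaluate_checkpoint` are hypothetical placeholders, and
`tf.train.checkpoints_iterator` is one standard TF1 way to poll a directory
for new checkpoints.

    import tensorflow as tf

    CHECKPOINT_DIR = "/tmp/scratchgan"  # Hypothetical; wherever training writes.

    def evaluate_checkpoint(path):
      # Hypothetical placeholder: restore the generator from `path`, sample
      # sentences, and score them with the FED metric.
      print("evaluating", path)

    # Yields each new checkpoint path as training produces it, blocking in
    # between, so evaluation naturally runs in parallel with training.
    for checkpoint_path in tf.train.checkpoints_iterator(CHECKPOINT_DIR):
      evaluate_checkpoint(checkpoint_path)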
@@ -0,0 +1,13 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,121 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminator networks for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sonnet as snt
import tensorflow as tf
from scratchgan import utils


class LSTMEmbedDiscNet(snt.AbstractModule):
  """An LSTM discriminator that operates on word indexes."""

  def __init__(self,
               feature_sizes,
               vocab_size,
               use_layer_norm,
               trainable_embedding_size,
               dropout,
               pad_token,
               embedding_source=None,
               vocab_file=None,
               name='LSTMEmbedDiscNet'):
    super(LSTMEmbedDiscNet, self).__init__(name=name)
    self._feature_sizes = feature_sizes
    self._vocab_size = vocab_size
    self._use_layer_norm = use_layer_norm
    self._trainable_embedding_size = trainable_embedding_size
    self._embedding_source = embedding_source
    self._vocab_file = vocab_file
    self._dropout = dropout
    self._pad_token = pad_token
    if self._embedding_source:
      assert vocab_file

  def _build(self, sequence, sequence_length, is_training=True):
    """Connect to the graph.

    Args:
      sequence: A [batch_size, max_sequence_length] tensor of int. For example
        the indices of words as sampled by the generator.
      sequence_length: A [batch_size] tensor of int. Length of the sequence.
      is_training: Boolean, False to disable dropout.

    Returns:
      A [batch_size, max_sequence_length] tensor of floats: one logit per
      timestep of each sequence in the batch, which should (hopefully) make it
      possible to distinguish whether the value at that timestep is real or
      generated.
    """
    batch_size, max_sequence_length = sequence.shape.as_list()
    keep_prob = (1.0 - self._dropout) if is_training else 1.0

    if self._embedding_source:
      # Pretrained embeddings (e.g. the GloVe subset shipped with the data),
      # augmented with a trainable part.
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embedding',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
    embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
    embeddings.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._embedding_size])
    position_dim = 8
    embeddings_pos = utils.append_position_signal(embeddings, position_dim)
    embeddings_pos = tf.reshape(
        embeddings_pos,
        [batch_size * max_sequence_length, self._embedding_size + position_dim])
    lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
    lstm_inputs = tf.reshape(
        lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
    lstm_inputs.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._feature_sizes[0]])

    encoder_cells = []
    for feature_size in self._feature_sizes:
      encoder_cells += [
          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
      ]
    encoder_cell = snt.DeepRNN(encoder_cells)
    initial_state = encoder_cell.initial_state(batch_size)

    hidden_states, _ = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=lstm_inputs,
        sequence_length=sequence_length,
        initial_state=initial_state,
        swap_memory=True)

    hidden_states.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length,
         sum(self._feature_sizes)])
    logits = snt.BatchApply(snt.Linear(1))(hidden_states)
    logits.shape.assert_is_compatible_with([batch_size, max_sequence_length, 1])
    logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])

    # Mask past first PAD symbol.
    #
    # Note that we still rely on tf.nn.dynamic_rnn taking sequence_length
    # into account properly, because otherwise the logits at a given
    # timestep will depend on the inputs for all other timesteps, including
    # the ones that should be masked.
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
    return masked_logits_flat
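
For orientation, here is an illustrative way to connect the discriminator.
The shapes and hyperparameter values below are assumptions made for the
sketch, not the paper's settings, and the random input stands in for
generator samples:

    import tensorflow as tf
    from scratchgan import discriminator_nets

    BATCH_SIZE, MAX_LEN, VOCAB_SIZE, PAD = 4, 12, 1000, 0  # Assumed values.

    disc = discriminator_nets.LSTMEmbedDiscNet(
        feature_sizes=[128, 128],  # Two stacked LSTM layers (assumed sizes).
        vocab_size=VOCAB_SIZE,
        use_layer_norm=True,
        trainable_embedding_size=64,
        dropout=0.1,
        pad_token=PAD)  # embedding_source=None: embeddings trained from scratch.

    sequence = tf.random_uniform(
        [BATCH_SIZE, MAX_LEN], maxval=VOCAB_SIZE, dtype=tf.int32)
    sequence_length = tf.fill([BATCH_SIZE], MAX_LEN)

    # One logit per timestep, shape [BATCH_SIZE, MAX_LEN]; logits after the
    # first PAD token in each sequence are zeroed out by the mask.
    logits = disc(sequence, sequence_length, is_training=True)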
@@ -0,0 +1,48 @@
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub


def fid(generated_sentences, real_sentences):
  """Compute FID on sentences using a pretrained universal sentence encoder.

  Args:
    generated_sentences: list of N strings.
    real_sentences: list of N strings.

  Returns:
    Frechet distance between activations.
  """
  embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
  real_embed = embed(real_sentences)
  generated_embed = embed(generated_sentences)
  distance = tfgan.eval.frechet_classifier_distance_from_activations(
      real_embed, generated_embed)

  # Restrict the thread pool size to prevent excessive CPU usage.
  config = tf.ConfigProto()
  config.intra_op_parallelism_threads = 16
  config.inter_op_parallelism_threads = 16
  with tf.Session(config=config) as session:
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    distance_np = session.run(distance)
  return distance_np
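
A hypothetical call, for illustration only; the sentences below are made up,
and in the real pipeline they would come from the generator and the EMNLP2017
News data:

    from scratchgan import eval_metrics

    real_sentences = ["the markets rallied on monday .",
                      "officials said the talks had stalled ."]
    generated_sentences = ["the the market said on .",
                           "officials rallied monday talks ."]

    # Builds the sentence-encoder graph, runs its own session, and returns
    # the Frechet distance between embedding activations as a Python float.
    distance = eval_metrics.fid(generated_sentences, real_sentences)
    print("Frechet embedding distance:", distance)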