
Commit

Move to PY3.
PiperOrigin-RevId: 313217111
Cyprien de Masson d'Autume authored and diegolascasas committed May 27, 2020
1 parent 77886cb commit b105a36
Showing 5 changed files with 33 additions and 22 deletions.
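
For orientation, a minimal sketch (not part of the commit; generic names standing in for the identifiers in the diff) of the Python 2 idioms removed below, written as the Python 3 code that replaces them:

    d = {"a": 1, "b": 2}
    it = iter([{"x": 1}, {"x": 2}])

    pairs = d.items()                                 # Python 2: d.iteritems()
    ids = [i for i in range(len(d))]                  # Python 2: xrange(len(d))
    batch = next(it)                                  # Python 2: it.next()
    key = list(d.keys())[list(d.values()).index(2)]   # Python 2: d.keys()[d.values().index(2)]
    nums = list(map(int, "1 2".split()))              # Python 2: map() already returned a list
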
27 changes: 14 additions & 13 deletions scratchgan/experiment.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -109,7 +110,7 @@ def train(config):
   raw_data = reader.get_raw_data(
       data_path=config.data_dir, dataset=config.dataset)
   train_data, valid_data, word_to_id = raw_data
-  id_to_word = {v: k for k, v in word_to_id.iteritems()}
+  id_to_word = {v: k for k, v in word_to_id.items()}
   vocab_size = len(word_to_id)
   max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
   logging.info("Vocabulary size: %d", vocab_size)
@@ -124,8 +125,8 @@ def train(config):
       name="real_sequence")
   real_sequence_length = tf.placeholder(
       dtype=tf.int32, shape=[config.batch_size], name="real_sequence_length")
-  first_batch_np = iterator.next()
-  valid_batch_np = iterator_valid.next()
+  first_batch_np = next(iterator)
+  valid_batch_np = next(iterator_valid)
 
   test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
   test_fake_batch = {
@@ -146,7 +147,7 @@ def train(config):
     embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
     vocab_file = "/tmp/vocab.txt"
     with gfile.GFile(vocab_file, "w") as f:
-      for i in xrange(len(id_to_word)):
+      for i in range(len(id_to_word)):
         f.write(id_to_word[i] + "\n")
     logging.info("Temporary vocab file: %s", vocab_file)
   else:
@@ -263,17 +264,17 @@ def train(config):
     if latest_ckpt:
       saver.restore(sess, latest_ckpt)
 
-    for step in xrange(config.num_steps):
-      real_data_np = iterator.next()
+    for step in range(config.num_steps):
+      real_data_np = next(iterator)
       train_feed = {
           real_sequence: real_data_np["sequence"],
           real_sequence_length: real_data_np["sequence_length"],
       }
 
       # Update generator and discriminator.
-      for _ in xrange(config.num_disc_updates):
+      for _ in range(config.num_disc_updates):
         sess.run(disc_update, feed_dict=train_feed)
-      for _ in xrange(config.num_gen_updates):
+      for _ in range(config.num_gen_updates):
         sess.run(gen_update, feed_dict=train_feed)
 
       # Reporting
@@ -325,7 +326,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
   # Build graph.
   train_data, valid_data, word_to_id = reader.get_raw_data(
       data_dir, dataset=dataset)
-  id_to_word = {v: k for k, v in word_to_id.iteritems()}
+  id_to_word = {v: k for k, v in word_to_id.items()}
   vocab_size = len(word_to_id)
   train_iterator = reader.iterator(raw_data=train_data, batch_size=batch_size)
   valid_iterator = reader.iterator(raw_data=valid_data, batch_size=batch_size)
@@ -353,7 +354,7 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
     embedding_source = utils.get_embedding_path(config.data_dir, config.dataset)
     vocab_file = "/tmp/vocab.txt"
     with gfile.GFile(vocab_file, "w") as f:
-      for i in xrange(len(id_to_word)):
+      for i in range(len(id_to_word)):
         f.write(id_to_word[i] + "\n")
     logging.info("Temporary vocab file: %s", vocab_file)
   else:
@@ -431,10 +432,10 @@ def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset,
     logging.info("Restoring variables.")
     saver.restore(sess, checkpoint_path)
 
-    for i in xrange(num_batches):
+    for i in range(num_batches):
       logging.info("Batch %d / %d", i, num_batches)
-      train_data_np = train_iterator.next()
-      valid_data_np = valid_iterator.next()
+      train_data_np = next(train_iterator)
+      valid_data_np = next(valid_iterator)
       feed_dict = {
           train_sequence: train_data_np["sequence"],
           train_sequence_length: train_data_np["sequence_length"],
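A minimal sketch (a hypothetical generator standing in for reader.iterator) of why the iterator.next() calls in experiment.py become next(iterator): in Python 3 the iterator protocol method was renamed to __next__, and the built-in next() works under both versions.

    def batches():
      step = 0
      while True:
        step += 1
        yield {"sequence": [step], "sequence_length": 1}

    it = batches()
    batch = next(it)      # works on Python 2 and Python 3
    # batch = it.next()   # Python 2 only; AttributeError on Python 3
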
3 changes: 2 additions & 1 deletion scratchgan/generators.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -104,7 +105,7 @@ def _build(self, is_training=True, temperature=1.0):
     sample = tf.tile(
         tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
     logging.info('Unrolling over %d steps.', max_sequence_length)
-    for _ in xrange(max_sequence_length):
+    for _ in range(max_sequence_length):
       # Input is sampled word at t-1.
       embedding = tf.nn.embedding_lookup(input_embeddings, sample)
       embedding.shape.assert_is_compatible_with(
5 changes: 3 additions & 2 deletions scratchgan/losses.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -67,9 +68,9 @@ def reinforce_loss(disc_logits, gen_logprobs, gamma, decay):
   # Compute cumulative rewards.
   rewards_list = tf.unstack(rewards, axis=1)
   cumulative_rewards = []
-  for t in xrange(sequence_length):
+  for t in range(sequence_length):
     cum_value = tf.zeros(shape=[batch_size])
-    for s in xrange(t, sequence_length):
+    for s in range(t, sequence_length):
       cum_value += np.power(gamma, (s - t)) * rewards_list[s]
     cumulative_rewards.append(cum_value)
   cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
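For reference, a plain-Python sketch (scalar rewards and made-up values instead of tensors) of the discounted cumulative reward that the nested loops in reinforce_loss compute, R_t = sum over s >= t of gamma**(s - t) * r_s:

    gamma = 0.9
    rewards = [1.0, 0.0, 2.0]
    cumulative = [
        sum(gamma ** (s - t) * rewards[s] for s in range(t, len(rewards)))
        for t in range(len(rewards))
    ]
    # cumulative is approximately [2.62, 1.8, 2.0]
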
11 changes: 6 additions & 5 deletions scratchgan/reader.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -56,17 +57,17 @@ def _build_vocab(json_data):
     title_tokens = tokenize(title)
     vocab.update(title_tokens)
   # Most common words first.
-  count_pairs = sorted(vocab.items(), key=lambda x: (-x[1], x[0]))
-  words, _ = zip(*count_pairs)
+  count_pairs = sorted(list(vocab.items()), key=lambda x: (-x[1], x[0]))
+  words, _ = list(zip(*count_pairs))
   words = list(words)
   if UNK not in words:
     words = [UNK] + words
-  word_to_id = dict(zip(words, range(len(words))))
+  word_to_id = dict(list(zip(words, list(range(len(words))))))
 
   # Tokens are now sorted by frequency. There's no guarantee that `PAD` will
   # end up at `PAD_INT` index. Enforce it by swapping whatever token is
   # currently at the `PAD_INT` index with the `PAD` token.
-  word = word_to_id.keys()[word_to_id.values().index(PAD_INT)]
+  word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
   word_to_id[PAD], word_to_id[word] = word_to_id[word], word_to_id[PAD]
   assert word_to_id[PAD] == PAD_INT
 
@@ -120,7 +121,7 @@ def get_raw_data(data_path, dataset, truncate_vocab=20000):
   """
   if dataset not in FILENAMES:
     raise ValueError("Invalid dataset {}. Valid datasets: {}".format(
-        dataset, FILENAMES.keys()))
+        dataset, list(FILENAMES.keys())))
   train_file, valid_file, _ = FILENAMES[dataset]
 
   train_path = os.path.join(data_path, train_file)
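A minimal illustration (placeholder tokens, not the real vocabulary) of why the PAD lookup in _build_vocab now wraps the dict views in list(): in Python 3, dict.keys() and dict.values() return view objects that cannot be indexed.

    PAD_INT = 0
    word_to_id = {"the": 0, "<pad>": 1, "<unk>": 2}
    # Python 2: word_to_id.keys()[word_to_id.values().index(PAD_INT)]
    # Python 3: the views must be materialized as lists before indexing.
    word = list(word_to_id.keys())[list(word_to_id.values()).index(PAD_INT)]
    assert word == "the"
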
9 changes: 8 additions & 1 deletion scratchgan/utils.py
@@ -1,3 +1,4 @@
+# Lint as: python3
 # Copyright 2019 DeepMind Technologies Limited and Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -34,7 +35,7 @@ def _get_embedding_initializer(vocab_file, embedding_source, vocab_size):
     embedding_lines = f.readlines()
 
   # First line contains embedding dim.
-  _, embedding_dim = map(int, embedding_lines[0].split())
+  _, embedding_dim = list(map(int, embedding_lines[0].split()))
   # Get the tokens as strings.
   tokens = [line.split()[0] for line in embedding_lines[1:]]
   # Get the actual embedding matrix.
@@ -275,6 +276,12 @@ def make_partially_trainable_embeddings(vocab_file, embedding_source,
     vectors for word representation. In Proceedings of the 2014 conference on
     empirical methods in natural language processing (EMNLP) (pp. 1532-1543).
+  Args:
+    vocab_file: vocabulary file.
+    embedding_source: path to the actual embeddings.
+    vocab_size: number of words in vocabulary.
+    trainable_embedding_size: size of the trainable part of the embeddings.
+  Returns:
+    A matrix of partially pretrained embeddings.
   """
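A small sketch (a made-up header line mimicking an embedding file) of the map() change in _get_embedding_initializer: Python 3's map() returns a lazy iterator rather than a list, so wrapping it in list() preserves the eager Python 2 behavior, although plain tuple unpacking would also consume the iterator.

    header = "400000 300"  # "<num embeddings> <embedding dim>", made-up values
    _, embedding_dim = list(map(int, header.split()))  # eager, matches Python 2
    _, embedding_dim = map(int, header.split())        # also valid: unpacking drains the iterator
    assert embedding_dim == 300
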
