From ff8a6e4c31a8648cee4f2158d239bf92787cc25e Mon Sep 17 00:00:00 2001 From: Joel Shor Date: Wed, 23 Oct 2019 10:58:49 +0100 Subject: [PATCH] Update to use the new TF-GAN API, which: a) Has a simpler API b) Uses SavedModels c) Is faster d) Works on TPU PiperOrigin-RevId: 276240022 --- cs_gan/image_metrics.py | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/cs_gan/image_metrics.py b/cs_gan/image_metrics.py index f174869b..34ba3e10 100644 --- a/cs_gan/image_metrics.py +++ b/cs_gan/image_metrics.py @@ -20,10 +20,6 @@ import tensorflow_gan as tfgan -def inception_preprocess_fn(images): - return tfgan.eval.preprocess_image(images * 255) - - def compute_inception_score( images, max_classifier_batch_size=16, assert_data_ranges=True): """Computes the classifier score, using the given model. @@ -62,14 +58,9 @@ def _choose_batch_size(num_images, max_batch_size): else: control_deps = [] - # Do the preprocessing in the fn function to avoid having to keep all the - # resized data in memory. - def classifier_fn(images): - return tfgan.eval.run_inception(inception_preprocess_fn(images)) - + # Inception module does resizing, if necessary. with tf.control_dependencies(control_deps): - return tfgan.eval.classifier_score( - images, classifier_fn=classifier_fn, num_batches=num_batches) + return tfgan.eval.inception_score(images, num_batches=num_batches) def compute_fid( @@ -113,17 +104,10 @@ def _choose_batch_size(num_images, max_batch_size): else: control_deps = [] - # Do the preprocessing in the fn function to avoid having to keep all the - # resized data in memory. - def classifier_fn(images): - return tfgan.eval.run_inception( - inception_preprocess_fn(images), - output_tensor=tfgan.eval.INCEPTION_FINAL_POOL) - + # Inception module does resizing, if necessary. 
with tf.control_dependencies(control_deps): - return tfgan.eval.frechet_classifier_distance( - real_images, other, - classifier_fn=classifier_fn, num_batches=num_batches) + return tfgan.eval.frechet_inception_distance( + real_images, other, num_batches=num_batches) def generate_big_batch(generator, generator_inputs, max_num_samples=100):