Skip to content

Commit

Permalink
Update to use the new TF-GAN API, which:
Browse files Browse the repository at this point in the history
a) Has a simpler API
b) Uses SavedModels
c) Is faster
d) Works on TPU

PiperOrigin-RevId: 276240022
  • Loading branch information
joel-shor authored and diegolascasas committed Oct 24, 2019
1 parent e4ac909 commit ff8a6e4
Showing 1 changed file with 5 additions and 21 deletions.
26 changes: 5 additions & 21 deletions cs_gan/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
import tensorflow_gan as tfgan


def inception_preprocess_fn(images):
return tfgan.eval.preprocess_image(images * 255)


def compute_inception_score(
images, max_classifier_batch_size=16, assert_data_ranges=True):
"""Computes the classifier score, using the given model.
Expand Down Expand Up @@ -62,14 +58,9 @@ def _choose_batch_size(num_images, max_batch_size):
else:
control_deps = []

# Do the preprocessing in the fn function to avoid having to keep all the
# resized data in memory.
def classifier_fn(images):
return tfgan.eval.run_inception(inception_preprocess_fn(images))

# Inception module does resizing, if necessary.
with tf.control_dependencies(control_deps):
return tfgan.eval.classifier_score(
images, classifier_fn=classifier_fn, num_batches=num_batches)
return tfgan.eval.run_inception(images, num_batches=num_batches)


def compute_fid(
Expand Down Expand Up @@ -113,17 +104,10 @@ def _choose_batch_size(num_images, max_batch_size):
else:
control_deps = []

# Do the preprocessing in the fn function to avoid having to keep all the
# resized data in memory.
def classifier_fn(images):
return tfgan.eval.run_inception(
inception_preprocess_fn(images),
output_tensor=tfgan.eval.INCEPTION_FINAL_POOL)

# Inception module does resizing, if necessary.
with tf.control_dependencies(control_deps):
return tfgan.eval.frechet_classifier_distance(
real_images, other,
classifier_fn=classifier_fn, num_batches=num_batches)
return tfgan.eval.frechet_inception_distance(
real_images, other, num_batches=num_batches)


def generate_big_batch(generator, generator_inputs, max_num_samples=100):
Expand Down

0 comments on commit ff8a6e4

Please sign in to comment.