updating with GAN-training notebooks
Panos Achlioptas committed Aug 14, 2018
1 parent 968d7c5 commit a65d74d
Showing 9 changed files with 695 additions and 957 deletions.
857 changes: 200 additions & 657 deletions notebooks/train_latent_gan.ipynb

Large diffs are not rendered by default.

343 changes: 343 additions & 0 deletions notebooks/train_raw_gan.ipynb

Large diffs are not rendered by default.

311 changes: 78 additions & 233 deletions notebooks/train_single_class_ae.ipynb

Large diffs are not rendered by default.

58 changes: 7 additions & 51 deletions src/autoencoder.py
@@ -13,9 +13,7 @@
 
 from . in_out import create_dir, pickle_data, unpickle_data
 from . general_utils import apply_augmentations, iterate_in_chunks
-from . neural_net import Neural_Net
-
-model_saver_id = 'models.ckpt'
+from . neural_net import Neural_Net, MODEL_SAVER_ID
 
 
 class Configuration():
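
The checkpoint-name constant moves out of `autoencoder.py` and into `neural_net.py` as the shared `MODEL_SAVER_ID`, so every model class names its checkpoints the same way. A minimal sketch of the resulting path convention (directory and epoch are hypothetical):

```python
import os.path as osp

MODEL_SAVER_ID = 'models.ckpt'   # shared constant, now defined in src/neural_net.py

train_dir = '/tmp/ae_chairs'     # hypothetical training directory
epoch = 500

# tf.train.Saver.save(..., global_step=...) appends '-<step>' to the prefix,
# so the files for epoch 500 are '/tmp/ae_chairs/models.ckpt-500.*'.
print(osp.join(train_dir, MODEL_SAVER_ID) + '-' + str(int(epoch)))
```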
@@ -103,18 +101,7 @@ def __init__(self, name, graph, configuration):
                 self.gt = tf.placeholder(tf.float32, out_shape)
             else:
                 self.gt = self.x
 
-    def restore_model(self, model_path, epoch, verbose=False):
-        '''Restore all the variables of a saved auto-encoder model.
-        '''
-        self.saver.restore(self.sess, osp.join(model_path, model_saver_id + '-' + str(int(epoch))))
-
-        if self.epoch.eval(session=self.sess) != epoch:
-            warnings.warn('Loaded model\'s epoch doesn\'t match the requested one.')
-        else:
-            if verbose:
-                print('Model restored in epoch {0}.'.format(epoch))
-
     def partial_fit(self, X, GT=None):
         '''Trains the model with mini-batches of input data.
         If GT is not None, then the reconstruction loss compares the output of the net that is fed X, with the GT.
@@ -143,7 +130,7 @@ def reconstruct(self, X, GT=None, compute_loss=True):
         if compute_loss:
             loss = self.loss
         else:
-            loss = tf.no_op()
+            loss = self.no_op
 
         if GT is None:
             return self.sess.run((self.x_reconstr, loss), feed_dict={self.x: X})
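
`tf.no_op()` used to be created anew on every `reconstruct` call; the method now reuses the single `self.no_op` built in `Neural_Net.__init__`, so the graph does not grow at inference time. A hedged usage sketch, assuming `ae` is a trained `AutoEncoder` and `pclouds` is an `(n, 2048, 3)` numpy array:

```python
# With the loss: runs self.loss alongside the reconstruction.
recon, loss = ae.reconstruct(pclouds)

# Without the loss: runs the pre-built self.no_op instead of computing anything extra.
recon, _ = ae.reconstruct(pclouds, compute_loss=False)
```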
@@ -181,7 +168,7 @@ def train(self, train_data, configuration, log_file=None, held_out_data=None):
 
         for _ in xrange(c.training_epochs):
             loss, duration = self._single_epoch_train(train_data, c)
-            epoch = int(self.sess.run(self.epoch.assign_add(tf.constant(1.0))))
+            epoch = int(self.sess.run(self.increment_epoch))
             stats.append((epoch, loss, duration))
 
             if epoch % c.loss_display_step == 0:
@@ -191,7 +178,7 @@ def train(self, train_data, configuration, log_file=None, held_out_data=None):
 
             # Save the models checkpoint periodically.
             if c.saver_step is not None and (epoch % c.saver_step == 0 or epoch - 1 == 0):
-                checkpoint_path = osp.join(c.train_dir, model_saver_id)
+                checkpoint_path = osp.join(c.train_dir, MODEL_SAVER_ID)
                 self.saver.save(self.sess, checkpoint_path, global_step=self.epoch)
 
             if c.exists_and_is_not_none('summary_step') and (epoch % c.summary_step == 0 or epoch - 1 == 0):
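
The epoch-counter change has the same motivation: `self.epoch.assign_add(tf.constant(1.0))` built a fresh op on every epoch, silently growing the graph over a training run, whereas `self.increment_epoch` is constructed once in `Neural_Net.__init__` and merely executed here. A standalone TF 1.x sketch of the pattern (not the repo's class itself):

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Non-trainable float counter, created once at graph-construction time.
    epoch = tf.get_variable('epoch', [], initializer=tf.constant_initializer(0),
                            trainable=False)
    # Pre-built op: each run bumps the counter and returns the new value.
    increment_epoch = epoch.assign_add(tf.constant(1.0))

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):
        # The graph stays fixed; we only execute an already-existing op.
        print(int(sess.run(increment_epoch)))   # 1, 2, 3
```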
@@ -236,37 +223,7 @@ def evaluate(self, in_data, configuration, ret_pre_augmentation=False):
             return reconstructions, data_loss, np.squeeze(feed_data), ids, np.squeeze(original_data), pre_aug
         else:
             return reconstructions, data_loss, np.squeeze(feed_data), ids, np.squeeze(original_data)
-
-    def evaluate_one_by_one(self, in_data, configuration):
-        '''Evaluates every data point separately to recover the loss on it. Thus, the batch_size = 1 making it
-        a slower than the 'evaluate' method.
-        '''
-
-        if self.is_denoising:
-            original_data, ids, feed_data = in_data.full_epoch_data(shuffle=False)
-            if feed_data is None:
-                feed_data = original_data
-            feed_data = apply_augmentations(feed_data, configuration)  # This is a new copy of the batch.
-        else:
-            original_data, ids, _ = in_data.full_epoch_data(shuffle=False)
-            feed_data = apply_augmentations(original_data, configuration)
-
-        n_examples = in_data.num_examples
-        assert(len(original_data) == n_examples)
-
-        feed_data = np.expand_dims(feed_data, 1)
-        original_data = np.expand_dims(original_data, 1)
-        reconstructions = np.zeros([n_examples] + self.n_output)
-        losses = np.zeros([n_examples])
-
-        for i in xrange(n_examples):
-            if self.is_denoising:
-                reconstructions[i], losses[i] = self.reconstruct(feed_data[i], original_data[i])
-            else:
-                reconstructions[i], losses[i] = self.reconstruct(feed_data[i])
-
-        return reconstructions, losses, np.squeeze(feed_data), ids, np.squeeze(original_data)
-
 
     def embedding_at_tensor(self, dataset, conf, feed_original=True, apply_augmentation=False, tensor_name='bottleneck'):
         '''
         Observation: the NN-neighborhoods seem more reasonable when we do not apply the augmentation.
@@ -299,8 +256,7 @@ def embedding_at_tensor(self, dataset, conf, feed_original=True, apply_augmentat
         embedding = np.vstack(embedding)
         return feed, embedding, ids
-
-
-
+
+
     def get_latent_codes(self, pclouds, batch_size=100):
         ''' Convenience wrapper of self.transform to get the latent (bottle-neck) codes for a set of input point
         clouds.
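
For context, hypothetical usage of the unchanged `get_latent_codes` wrapper shown above; shapes depend on your `Configuration`:

```python
# `ae` is a trained AutoEncoder; `pclouds` is an (n, 2048, 3) numpy array.
latent_codes = ae.get_latent_codes(pclouds, batch_size=100)
print(latent_codes.shape)   # (n, bottleneck_size), e.g. (n, 128)
```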
40 changes: 40 additions & 0 deletions src/general_utils.py
@@ -6,6 +6,8 @@
 
 import numpy as np
 from numpy.linalg import norm
+import matplotlib.pylab as plt
+from mpl_toolkits.mplot3d import Axes3D
 
 
 def rand_rotation_matrix(deflection=1.0, seed=None):
@@ -104,3 +106,41 @@ def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
         grid = grid[norm(grid, axis=1) <= 0.5]
 
     return grid, spacing
+
+def plot_3d_point_cloud(x, y, z, show=True, show_axis=True, in_u_sphere=False, marker='.', s=8, alpha=.8, figsize=(5, 5), elev=10, azim=240, axis=None, title=None, *args, **kwargs):
+
+    if axis is None:
+        fig = plt.figure(figsize=figsize)
+        ax = fig.add_subplot(111, projection='3d')
+    else:
+        ax = axis
+        fig = axis
+
+    if title is not None:
+        plt.title(title)
+
+    sc = ax.scatter(x, y, z, marker=marker, s=s, alpha=alpha, *args, **kwargs)
+    ax.view_init(elev=elev, azim=azim)
+
+    if in_u_sphere:
+        ax.set_xlim3d(-0.5, 0.5)
+        ax.set_ylim3d(-0.5, 0.5)
+        ax.set_zlim3d(-0.5, 0.5)
+    else:
+        miv = 0.7 * np.min([np.min(x), np.min(y), np.min(z)])  # Multiply with 0.7 to squeeze free-space.
+        mav = 0.7 * np.max([np.max(x), np.max(y), np.max(z)])
+        ax.set_xlim(miv, mav)
+        ax.set_ylim(miv, mav)
+        ax.set_zlim(miv, mav)
+        plt.tight_layout()
+
+    if not show_axis:
+        plt.axis('off')
+
+    if 'c' in kwargs:
+        plt.colorbar(sc)
+
+    if show:
+        plt.show()
+
+    return fig
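
A quick hedged example of the new helper, assuming the repository is importable as `latent_3d_points`; the random cloud stands in for a real reconstruction:

```python
import numpy as np
from latent_3d_points.src.general_utils import plot_3d_point_cloud

pc = np.random.uniform(-0.5, 0.5, size=(2048, 3))   # synthetic stand-in cloud
fig = plot_3d_point_cloud(pc[:, 0], pc[:, 1], pc[:, 2],
                          in_u_sphere=True, title='example cloud')
```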
21 changes: 11 additions & 10 deletions src/generators_discriminators.py
@@ -14,7 +14,7 @@
 from . tf_utils import expand_scope_by_name
 
 
-def mlp_discriminator(in_signal, non_linearity=tf.nn.relu, reuse=False, scope=None, b_norm=[True], dropout_prob=None):
+def mlp_discriminator(in_signal, non_linearity=tf.nn.relu, reuse=False, scope=None, b_norm=True, dropout_prob=None):
     ''' used in nips submission.
     '''
     encoder_args = {'n_filters': [64, 128, 256, 256, 512], 'filter_sizes': [1, 1, 1, 1, 1], 'strides': [1, 1, 1, 1, 1]}
@@ -32,7 +32,7 @@ def mlp_discriminator(in_signal, non_linearity=tf.nn.relu, reuse=False, scope=No
     return d_prob, d_logit
 
 
-def point_cloud_generator(z, pc_dims, layer_sizes=[64, 128, 512, 1024], non_linearity=tf.nn.relu, b_norm=[False], b_norm_last=False, dropout_prob=None):
+def point_cloud_generator(z, pc_dims, layer_sizes=[64, 128, 512, 1024], non_linearity=tf.nn.relu, b_norm=False, b_norm_last=False, dropout_prob=None):
     ''' used in nips submission.
     '''
 
@@ -42,6 +42,7 @@ def point_cloud_generator(z, pc_dims, layer_sizes=[64, 128, 512, 1024], non_line
 
     out_signal = decoder_with_fc_only(z, layer_sizes=layer_sizes, non_linearity=non_linearity, b_norm=b_norm)
     out_signal = non_linearity(out_signal)
+
     if dropout_prob is not None:
         out_signal = dropout(out_signal, dropout_prob)
 
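
For orientation, a hedged sketch of calling the raw-GAN generator above: it maps a batch of noise vectors to `n_points x 3` clouds. Sizes are illustrative and the import path assumes the repo layout:

```python
import tensorflow as tf
from latent_3d_points.src.generators_discriminators import point_cloud_generator

n_points, noise_dim = 2048, 128                    # illustrative sizes
z = tf.placeholder(tf.float32, [None, noise_dim])  # generator input noise
with tf.variable_scope('generator'):
    fake_pc = point_cloud_generator(z, pc_dims=(n_points, 3))
print(fake_pc.shape)                               # (?, 2048, 3)
```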
@@ -70,33 +71,33 @@ def convolutional_discriminator(in_signal, non_linearity=tf.nn.relu,
     return d_prob, d_logit
 
 
-def latent_code_generator(z, out_dim, layer_sizes=[64, 128], b_norm=[False]):
+def latent_code_generator(z, out_dim, layer_sizes=[64, 128], b_norm=False):
     layer_sizes = layer_sizes + out_dim
     out_signal = decoder_with_fc_only(z, layer_sizes=layer_sizes, b_norm=b_norm)
     out_signal = tf.nn.relu(out_signal)
     return out_signal
 
 
-def latent_code_discriminator(in_singnal, layer_sizes=[64, 128, 256, 256, 512], b_norm=[False], non_linearity=tf.nn.relu, reuse=False, scope=None):
+def latent_code_discriminator(in_singnal, layer_sizes=[64, 128, 256, 256, 512], b_norm=False, non_linearity=tf.nn.relu, reuse=False, scope=None):
     layer_sizes = layer_sizes + [1]
     d_logit = decoder_with_fc_only(in_singnal, layer_sizes=layer_sizes, non_linearity=non_linearity, b_norm=b_norm, reuse=reuse, scope=scope)
     d_prob = tf.nn.sigmoid(d_logit)
     return d_prob, d_logit
 
 
-def latent_code_discriminator_two_layers(in_singnal, layer_sizes=[256, 512], b_norm=[False], non_linearity=tf.nn.relu, reuse=False, scope=None):
-    ''' used in nips submission.
+def latent_code_discriminator_two_layers(in_signal, layer_sizes=[256, 512], b_norm=False, non_linearity=tf.nn.relu, reuse=False, scope=None):
+    ''' Used in ICML submission.
     '''
     layer_sizes = layer_sizes + [1]
-    d_logit = decoder_with_fc_only(in_singnal, layer_sizes=layer_sizes, non_linearity=non_linearity, b_norm=b_norm, reuse=reuse, scope=scope)
+    d_logit = decoder_with_fc_only(in_signal, layer_sizes=layer_sizes, non_linearity=non_linearity, b_norm=b_norm, reuse=reuse, scope=scope)
     d_prob = tf.nn.sigmoid(d_logit)
     return d_prob, d_logit
 
 
-def latent_code_generator_two_layers(z, out_dim, layer_sizes=[128], b_norm=[False]):
-    ''' used in nips submission.
+def latent_code_generator_two_layers(z, out_dim, layer_sizes=[128], b_norm=False):
+    ''' Used in ICML submission.
     '''
     layer_sizes = layer_sizes + out_dim
     out_signal = decoder_with_fc_only(z, layer_sizes=layer_sizes, b_norm=b_norm)
-    out_signal = tf.nn.relu(out_signal)  # I could have added batch-norm before relu here.
+    out_signal = tf.nn.relu(out_signal)
     return out_signal
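
Throughout this file the `b_norm` defaults change from one-element lists to plain booleans. A hedged sketch of wiring the two-layer latent-code pair into a latent-GAN graph, sharing discriminator variables between real and fake codes via `reuse`; sizes and import path are assumptions:

```python
import tensorflow as tf
from latent_3d_points.src.generators_discriminators import (
    latent_code_generator_two_layers, latent_code_discriminator_two_layers)

noise_dim, bneck_size = 32, 128                               # illustrative sizes
z = tf.placeholder(tf.float32, [None, noise_dim])             # GAN noise
real_codes = tf.placeholder(tf.float32, [None, bneck_size])   # AE bottleneck codes

with tf.variable_scope('generator'):
    fake_codes = latent_code_generator_two_layers(z, out_dim=[bneck_size])

with tf.variable_scope('discriminator') as scope:
    d_real, d_real_logit = latent_code_discriminator_two_layers(real_codes, scope=scope)
    d_fake, d_fake_logit = latent_code_discriminator_two_layers(fake_codes,
                                                                reuse=True, scope=scope)
```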
14 changes: 14 additions & 0 deletions src/neural_net.py
@@ -22,7 +22,21 @@ def __init__(self, name, graph):
         with tf.variable_scope(name):
             with tf.device('/cpu:0'):
                 self.epoch = tf.get_variable('epoch', [], initializer=tf.constant_initializer(0), trainable=False)
+            self.increment_epoch = self.epoch.assign_add(tf.constant(1.0))
+
+        self.no_op = tf.no_op()
 
     def is_training(self):
         is_training_op = self.graph.get_collection('is_training')
         return self.sess.run(is_training_op)[0]
+
+    def restore_model(self, model_path, epoch, verbose=False):
+        '''Restore all the variables of a saved model.
+        '''
+        self.saver.restore(self.sess, osp.join(model_path, MODEL_SAVER_ID + '-' + str(int(epoch))))
+
+        if self.epoch.eval(session=self.sess) != epoch:
+            warnings.warn('Loaded model\'s epoch doesn\'t match the requested one.')
+        else:
+            if verbose:
+                print('Model restored in epoch {0}.'.format(epoch))
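
`restore_model` is promoted from `AutoEncoder` to the `Neural_Net` base class, so auto-encoders and GANs restore checkpoints identically. Hypothetical usage, assuming a model saved periodically via `saver_step` as in `train` above:

```python
# Looks for '<train_dir>/models.ckpt-500.*' files written by tf.train.Saver.
ae.restore_model('/path/to/train_dir', epoch=500, verbose=True)
# Prints 'Model restored in epoch 500.' or warns if the stored epoch differs.
```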
6 changes: 1 addition & 5 deletions src/vanilla_gan.py
@@ -1,5 +1,5 @@
 '''
-Created on Apr 27, 2017
+Created on 2018
 Author: Achlioptas Panos (Github ID: optas)
 '''
@@ -57,10 +57,6 @@ def generator_noise_distribution(self, n_samples, ndims, mu, sigma):
         return np.random.normal(mu, sigma, (n_samples, ndims))
 
     def _single_epoch_train(self, train_data, batch_size, noise_params):
-        '''
-        see: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
-             http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/
-        '''
         n_examples = train_data.num_examples
         epoch_loss_d = 0.
         epoch_loss_g = 0.
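
A hedged sketch of the noise sampling used on each step of `_single_epoch_train`; `gan` stands for a trained or freshly built GAN instance from this file, and the `mu`/`sigma` values are illustrative:

```python
batch_size, noise_dim = 50, 128
z = gan.generator_noise_distribution(batch_size, noise_dim, mu=0.0, sigma=0.2)
print(z.shape)   # (50, 128)
```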
2 changes: 1 addition & 1 deletion src/w_gan_gp.py
@@ -1,5 +1,5 @@
 '''
-Created on May 22, 2017
+Created on May 22, 2018
 Author: Achlioptas Panos (Github ID: optas)
 '''
