From ec3367515f4db0cb923ad13091f4926ae1f3c557 Mon Sep 17 00:00:00 2001
From: Panos
Date: Sun, 26 Nov 2017 18:35:33 -0800
Subject: [PATCH] pushing basic code

---
 .gitignore               |   4 +-
 src/ae_templates.py      |  77 ++++++++++
 src/autoencoder.py       | 304 +++++++++++++++++++++++++++++++++++++++
 src/encoders_decoders.py | 242 ++++++++++++++++++++++++++
 src/gan.py               |  44 ++++++
 src/latent_gan.py        |  99 +++++++++++++
 src/neural_net.py        |  55 +++++++
 src/point_net_ae.py      | 166 +++++++++++++++++++++
 src/raw_gan.py           | 120 ++++++++++++++++
 9 files changed, 1110 insertions(+), 1 deletion(-)
 create mode 100755 src/ae_templates.py
 create mode 100755 src/autoencoder.py
 create mode 100755 src/encoders_decoders.py
 create mode 100755 src/gan.py
 create mode 100755 src/latent_gan.py
 create mode 100755 src/neural_net.py
 create mode 100755 src/point_net_ae.py
 create mode 100755 src/raw_gan.py

diff --git a/.gitignore b/.gitignore
index 653cd02..c681c34 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,6 @@
+.project
 .ipynb_checkpoints
 .DS_Store
+.pydevproject
 *.pyc
-*.nfs*
\ No newline at end of file
+*.nfs*
diff --git a/src/ae_templates.py b/src/ae_templates.py
new file mode 100755
index 0000000..65e0f57
--- /dev/null
+++ b/src/ae_templates.py
@@ -0,0 +1,77 @@
+'''
+Created on September 2, 2017
+
+@author: optas
+'''
+import numpy as np
+
+from . encoders_decoders import encoder_with_convs_and_symmetry, decoder_with_fc_only, encoder_with_convs_and_symmetry_new
+
+
+def mlp_architecture_ala_iclr_18(n_pc_points, bneck_size, bneck_post_mlp=False):
+    ''' Single class experiments.
+    '''
+    if n_pc_points != 2048:
+        raise ValueError('Only 2048-point point-clouds are currently supported.')
+
+    encoder = encoder_with_convs_and_symmetry_new
+    decoder = decoder_with_fc_only
+
+    n_input = [n_pc_points, 3]
+
+    encoder_args = {'n_filters': [64, 128, 128, 256, bneck_size],
+                    'filter_sizes': [1],
+                    'strides': [1],
+                    'b_norm': True,
+                    'verbose': True
+                    }
+
+    decoder_args = {'layer_sizes': [256, 256, np.prod(n_input)],
+                    'b_norm': False,
+                    'b_norm_finish': False,
+                    'verbose': True
+                    }
+
+    if bneck_post_mlp:
+        encoder_args['n_filters'].pop()
+        decoder_args['layer_sizes'][0] = bneck_size
+
+    return encoder, decoder, encoder_args, decoder_args
+
+
+def conv_architecture_ala_nips_17(n_pc_points):
+    if n_pc_points == 2048:
+        encoder_args = {'n_filters': [128, 128, 256, 512],
+                        'filter_sizes': [40, 20, 10, 10],
+                        'strides': [1, 2, 2, 1]
+                        }
+    else:
+        raise ValueError('Only 2048-point point-clouds are currently supported.')
+
+    n_input = [n_pc_points, 3]
+
+    decoder_args = {'layer_sizes': [1024, 2048, np.prod(n_input)]}
+
+    res = {'encoder': encoder_with_convs_and_symmetry,
+           'decoder': decoder_with_fc_only,
+           'encoder_args': encoder_args,
+           'decoder_args': decoder_args
+           }
+    return res
+
+
+def default_train_params(single_class=True):
+    params = {'batch_size': 50,
+              'training_epochs': 500,
+              'denoising': False,
+              'learning_rate': 0.0005,
+              'z_rotate': False,
+              'saver_step': 10,
+              'loss_display_step': 1
+              }
+
+    if not single_class:
+        params['z_rotate'] = True
+        params['training_epochs'] = 1000
+
+    return params
diff --git a/src/autoencoder.py b/src/autoencoder.py
new file mode 100755
index 0000000..e449c2a
--- /dev/null
+++ b/src/autoencoder.py
@@ -0,0 +1,304 @@
+'''
+Created on February 2, 2017
+
+@author: optas
+'''
+
+import warnings
+import os.path as osp
+import tensorflow as tf
+import numpy as np
+
+from tflearn import is_training
+
+from general_tools.in_out.basics import create_dir, pickle_data, unpickle_data
+from general_tools.simpletons import iterate_in_chunks
+
+from . in_out import apply_augmentations
+from . neural_net import Neural_Net
+
+model_saver_id = 'models.ckpt'
+
+
+class Configuration():
+    def __init__(self, n_input, encoder, decoder, encoder_args={}, decoder_args={},
+                 training_epochs=200, batch_size=10, learning_rate=0.001, denoising=False,
+                 saver_step=None, train_dir=None, z_rotate=False, loss='l2', gauss_augment=None,
+                 saver_max_to_keep=None, loss_display_step=1, spatial_trans=False, debug=False,
+                 n_z=None, n_output=None, latent_vs_recon=1.0, consistent_io=None):
+
+        # Parameters for any AE
+        self.n_input = n_input
+        self.is_denoising = denoising
+        self.loss = loss.lower()
+        self.decoder = decoder
+        self.encoder = encoder
+        self.encoder_args = encoder_args
+        self.decoder_args = decoder_args
+
+        # Training related parameters
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.loss_display_step = loss_display_step
+        self.saver_step = saver_step
+        self.train_dir = train_dir
+        self.gauss_augment = gauss_augment
+        self.z_rotate = z_rotate
+        self.saver_max_to_keep = saver_max_to_keep
+        self.training_epochs = training_epochs
+        self.debug = debug
+
+        # Used in VAE
+        self.latent_vs_recon = np.array([latent_vs_recon], dtype=np.float32)[0]
+        self.n_z = n_z
+
+        # Used in AP
+        if n_output is None:
+            self.n_output = n_input
+        else:
+            self.n_output = n_output
+
+        # Fancy - TODO factor separately.
+        self.consistent_io = consistent_io
+
+    def exists_and_is_not_none(self, attribute):
+        return hasattr(self, attribute) and getattr(self, attribute) is not None
+
+    def __str__(self):
+        keys = self.__dict__.keys()
+        vals = self.__dict__.values()
+        index = np.argsort(keys)
+        res = ''
+        for i in index:
+            if callable(vals[i]):
+                v = vals[i].__name__
+            else:
+                v = str(vals[i])
+            res += '%30s: %s\n' % (str(keys[i]), v)
+        return res
+
+    def save(self, file_name):
+        pickle_data(file_name + '.pickle', self)
+        with open(file_name + '.txt', 'w') as fout:
+            fout.write(self.__str__())
+
+    @staticmethod
+    def load(file_name):
+        return unpickle_data(file_name + '.pickle').next()
+
+
+class AutoEncoder(Neural_Net):
+    '''Base class for a Neural Network that implements an Auto-Encoder in TensorFlow.
+    '''
+
+    def __init__(self, name, graph, configuration):
+        Neural_Net.__init__(self, name, graph)
+        self.is_denoising = configuration.is_denoising
+        self.n_input = configuration.n_input
+        self.n_output = configuration.n_output    # TODO Re-factor for AP
+
+        in_shape = [None] + self.n_input
+        out_shape = [None] + self.n_output
+
+        with tf.variable_scope(name):
+            self.x = tf.placeholder(tf.float32, in_shape)
+            if self.is_denoising:
+                self.gt = tf.placeholder(tf.float32, out_shape)
+            else:
+                self.gt = self.x
+
+    def restore_model(self, model_path, epoch, verbose=False):
+        '''Restore all the variables of a saved auto-encoder model.
+        '''
+        self.saver.restore(self.sess, osp.join(model_path, model_saver_id + '-' + str(int(epoch))))
+
+        if self.epoch.eval(session=self.sess) != epoch:
+            warnings.warn('Loaded model\'s epoch doesn\'t match the requested one.')
+        else:
+            if verbose:
+                print('Model restored in epoch {0}.'.format(epoch))
+
+    def partial_fit(self, X, GT=None):
+        '''Trains the model with mini-batches of input data.
+        If GT is not None, then the reconstruction loss compares the output of the net that is fed X, with the GT.
+        This can be useful when training, for instance, a denoising auto-encoder.
+        Returns:
+            The reconstructed (output) point-clouds.
+            The loss of the mini-batch.
+        '''
+        is_training(True, session=self.sess)
+        try:
+            if GT is not None:
+                _, loss, recon = self.sess.run((self.train_step, self.loss, self.x_reconstr), feed_dict={self.x: X, self.gt: GT})
+            else:
+                _, loss, recon = self.sess.run((self.train_step, self.loss, self.x_reconstr), feed_dict={self.x: X})
+
+            is_training(False, session=self.sess)
+        except Exception:
+            raise
+        finally:
+            is_training(False, session=self.sess)
+        return recon, loss
+
+    def reconstruct(self, X, GT=None, compute_loss=True):
+        '''Use AE to reconstruct given data.
+        GT will be used to measure the loss (e.g., if X is a noisy version of the GT)'''
+        if compute_loss:
+            loss = self.loss
+        else:
+            loss = tf.no_op()
+
+        if GT is None:
+            return self.sess.run((self.x_reconstr, loss), feed_dict={self.x: X})
+        else:
+            return self.sess.run((self.x_reconstr, loss), feed_dict={self.x: X, self.gt: GT})
+
+    def transform(self, X):
+        '''Transform data by mapping it into the latent space.'''
+        return self.sess.run(self.z, feed_dict={self.x: X})
+
+    def interpolate(self, x, y, steps):
+        ''' Interpolate between the x and y input vectors in latent space.
+        x, y np.arrays of size (n_points, dim_embedding).
+        '''
+        in_feed = np.vstack((x, y))
+        z1, z2 = self.transform(in_feed.reshape([2] + self.n_input))
+        all_z = np.zeros((steps + 2, len(z1)))
+
+        for i, alpha in enumerate(np.linspace(0, 1, steps + 2)):
+            all_z[i, :] = (alpha * z2) + ((1.0 - alpha) * z1)
+
+        return self.sess.run((self.x_reconstr), {self.z: all_z})
+
+    def decode(self, z):
+        if np.ndim(z) == 1:    # single example
+            z = np.expand_dims(z, 0)
+        return self.sess.run((self.x_reconstr), {self.z: z})
+
+    def train(self, train_data, configuration, log_file=None, held_out_data=None):
+        c = configuration
+        stats = []
+
+        if c.saver_step is not None:
+            create_dir(c.train_dir)
+
+        for _ in xrange(c.training_epochs):
+            loss, duration = self._single_epoch_train(train_data, c)
+            epoch = int(self.sess.run(self.epoch.assign_add(tf.constant(1.0))))
+            stats.append((epoch, loss, duration))
+
+            if epoch % c.loss_display_step == 0:
+                print("Epoch:", '%04d' % (epoch), 'training time (minutes)=', "{:.4f}".format(duration / 60.0), "loss=", "{:.9f}".format(loss))
+                if log_file is not None:
+                    log_file.write('%04d\t%.9f\t%.4f\n' % (epoch, loss, duration / 60.0))
+
+            # Save the model's checkpoint periodically.
+            if c.saver_step is not None and (epoch % c.saver_step == 0 or epoch - 1 == 0):
+                checkpoint_path = osp.join(c.train_dir, model_saver_id)
+                self.saver.save(self.sess, checkpoint_path, global_step=self.epoch)
+
+            if c.exists_and_is_not_none('summary_step') and (epoch % c.summary_step == 0 or epoch - 1 == 0):
+                summary = self.sess.run(self.merged_summaries)
+                self.train_writer.add_summary(summary, epoch)
+
+            if held_out_data is not None and c.exists_and_is_not_none('held_out_step') and (epoch % c.held_out_step == 0):
+                loss, duration = self._single_epoch_train(held_out_data, c, only_fw=True)
+                print("Held Out Data :", 'forward time (minutes)=', "{:.4f}".format(duration / 60.0), "loss=", "{:.9f}".format(loss))
+                if log_file is not None:
+                    log_file.write('On Held_Out: %04d\t%.9f\t%.4f\n' % (epoch, loss, duration / 60.0))
+        return stats
+
+    def evaluate(self, in_data, configuration, ret_pre_augmentation=False):
+        n_examples = in_data.num_examples
+        data_loss = 0.
+        pre_aug = None
+        if self.is_denoising:
+            original_data, ids, feed_data = in_data.full_epoch_data(shuffle=False)
+            if ret_pre_augmentation:
+                pre_aug = feed_data.copy()
+            if feed_data is None:
+                feed_data = original_data
+            feed_data = apply_augmentations(feed_data, configuration)    # This is a new copy of the batch.
+        else:
+            original_data, ids, _ = in_data.full_epoch_data(shuffle=False)
+            feed_data = apply_augmentations(original_data, configuration)
+
+        b = configuration.batch_size
+        reconstructions = np.zeros([n_examples] + self.n_output)
+        for i in xrange(0, n_examples, b):
+            if self.is_denoising:
+                reconstructions[i:i + b], loss = self.reconstruct(feed_data[i:i + b], original_data[i:i + b])
+            else:
+                reconstructions[i:i + b], loss = self.reconstruct(feed_data[i:i + b])
+
+            # Compute average loss
+            data_loss += (loss * len(reconstructions[i:i + b]))
+        data_loss /= float(n_examples)
+
+        if pre_aug is not None:
+            return reconstructions, data_loss, np.squeeze(feed_data), ids, np.squeeze(original_data), pre_aug
+        else:
+            return reconstructions, data_loss, np.squeeze(feed_data), ids, np.squeeze(original_data)
+
+    def evaluate_one_by_one(self, in_data, configuration):
+        '''Evaluates every data point separately to recover the loss on it. Thus, the batch_size = 1, making it
+        slower than the 'evaluate' method.
+        '''
+
+        if self.is_denoising:
+            original_data, ids, feed_data = in_data.full_epoch_data(shuffle=False)
+            if feed_data is None:
+                feed_data = original_data
+            feed_data = apply_augmentations(feed_data, configuration)    # This is a new copy of the batch.
+        else:
+            original_data, ids, _ = in_data.full_epoch_data(shuffle=False)
+            feed_data = apply_augmentations(original_data, configuration)
+
+        n_examples = in_data.num_examples
+        assert(len(original_data) == n_examples)
+
+        feed_data = np.expand_dims(feed_data, 1)
+        original_data = np.expand_dims(original_data, 1)
+        reconstructions = np.zeros([n_examples] + self.n_output)
+        losses = np.zeros([n_examples])
+
+        for i in xrange(n_examples):
+            if self.is_denoising:
+                reconstructions[i], losses[i] = self.reconstruct(feed_data[i], original_data[i])
+            else:
+                reconstructions[i], losses[i] = self.reconstruct(feed_data[i])
+
+        return reconstructions, losses, np.squeeze(feed_data), ids, np.squeeze(original_data)
+
+    def embedding_at_tensor(self, dataset, conf, feed_original=True, apply_augmentation=False, tensor_name='bottleneck'):
+        '''
+        Observation: the NN-neighborhoods seem more reasonable when we do not apply the augmentation.
+        Observation: the next layer after latent (z) might be something interesting.
+        tensor_name: e.g. model.name + '_1/decoder_fc_0/BiasAdd:0'
+        '''
+        batch_size = conf.batch_size
+        original, ids, noise = dataset.full_epoch_data(shuffle=False)
+
+        if feed_original:
+            feed = original
+        else:
+            feed = noise
+            if feed is None:
+                feed = original
+
+        feed_data = feed
+        if apply_augmentation:
+            feed_data = apply_augmentations(feed, conf)
+
+        embedding = []
+        if tensor_name == 'bottleneck':
+            for b in iterate_in_chunks(feed_data, batch_size):
+                embedding.append(self.transform(b.reshape([len(b)] + conf.n_input)))
+        else:
+            embedding_tensor = self.graph.get_tensor_by_name(tensor_name)
+            for b in iterate_in_chunks(feed_data, batch_size):
+                codes = self.sess.run(embedding_tensor, feed_dict={self.x: b.reshape([len(b)] + conf.n_input)})
+                embedding.append(codes)
+
+        embedding = np.vstack(embedding)
+        return feed, embedding, ids
diff --git a/src/encoders_decoders.py b/src/encoders_decoders.py
new file mode 100755
index 0000000..b99d625
--- /dev/null
+++ b/src/encoders_decoders.py
@@ -0,0 +1,242 @@
+'''
+Created on February 4, 2017
+
+@author: optas
+
+'''
+
+import tensorflow as tf
+import numpy as np
+import warnings
+
+from tflearn.layers.core import fully_connected, dropout
+from tflearn.layers.conv import conv_1d, avg_pool_1d, highway_conv_1d
+from tflearn.layers.normalization import batch_normalization
+
+from .. fundamentals.utils import expand_scope_by_name, replicate_parameter_for_all_layers
+
+
+def encoder_with_convs_and_symmetry_new(in_signal, n_filters=[64, 128, 256, 1024], filter_sizes=[1], strides=[1],
+                                        b_norm=True, non_linearity=tf.nn.relu, regularizer=None, weight_decay=0.001,
+                                        symmetry=tf.reduce_max, dropout_prob=None, pool=avg_pool_1d, pool_sizes=None, scope=None,
+                                        reuse=False, padding='same', verbose=False, closing=None, conv_op=conv_1d):
+    '''An Encoder (recognition network), which maps inputs onto a latent space.
+    '''
+
+    if verbose:
+        print 'Building Encoder'
+
+    n_layers = len(n_filters)
+    filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers)
+    strides = replicate_parameter_for_all_layers(strides, n_layers)
+    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)
+
+    if n_layers < 2:
+        raise ValueError('More than one layer is expected.')
+
+    for i in xrange(n_layers):
+        if i == 0:
+            layer = in_signal
+
+        name = 'encoder_conv_layer_' + str(i)
+        scope_i = expand_scope_by_name(scope, name)
+        layer = conv_op(layer, nb_filter=n_filters[i], filter_size=filter_sizes[i], strides=strides[i], regularizer=regularizer,
+                        weight_decay=weight_decay, name=name, reuse=reuse, scope=scope_i, padding=padding)
+
+        if verbose:
+            print name, 'conv params = ', np.prod(layer.W.get_shape().as_list()) + np.prod(layer.b.get_shape().as_list()),
+
+        if b_norm:
+            name += '_bnorm'
+            scope_i = expand_scope_by_name(scope, name)
+            layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+            if verbose:
+                print 'bnorm params = ', np.prod(layer.beta.get_shape().as_list()) + np.prod(layer.gamma.get_shape().as_list())
+
+        if non_linearity is not None:
+            layer = non_linearity(layer)
+
+        if pool is not None and pool_sizes is not None:
+            if pool_sizes[i] is not None:
+                layer = pool(layer, kernel_size=pool_sizes[i])
+
+        if dropout_prob is not None and dropout_prob[i] > 0:
+            layer = dropout(layer, 1.0 - dropout_prob[i])
+
+        if verbose:
+            print layer
+            print 'output size:', np.prod(layer.get_shape().as_list()[1:]), '\n'
+
+    if symmetry is not None:
+        layer = symmetry(layer, axis=1)
+        if verbose:
+            print layer
+
+    if closing is not None:
+        layer = closing(layer)
+        print layer
+
+    return layer
+
+
+def encoder_with_convs_and_symmetry(in_signal, n_filters=[64, 128, 256, 1024], filter_sizes=[1], strides=[1],
+                                    b_norm=True, spn=False, non_linearity=tf.nn.relu, regularizer=None, weight_decay=0.001,
+                                    symmetry=tf.reduce_max, dropout_prob=None, scope=None, reuse=False):
+
+    '''An Encoder (recognition network), which maps inputs onto a latent space.
+    '''
+    warnings.warn('Using old architecture.')
+    n_layers = len(n_filters)
+    filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers)
+    strides = replicate_parameter_for_all_layers(strides, n_layers)
+    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)
+
+    if n_layers < 2:
+        raise ValueError('More than one layer is expected.')
+
+    name = 'encoder_conv_layer_0'
+    scope_i = expand_scope_by_name(scope, name)
+    layer = conv_1d(in_signal, nb_filter=n_filters[0], filter_size=filter_sizes[0], strides=strides[0], regularizer=regularizer, weight_decay=weight_decay, name=name, reuse=reuse, scope=scope_i)
+
+    if b_norm:
+        name += '_bnorm'
+        scope_i = expand_scope_by_name(scope, name)
+        layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+
+    layer = non_linearity(layer)
+
+    if dropout_prob is not None and dropout_prob[0] > 0:
+        layer = dropout(layer, 1.0 - dropout_prob[0])
+
+    for i in xrange(1, n_layers):
+        name = 'encoder_conv_layer_' + str(i)
+        scope_i = expand_scope_by_name(scope, name)
+        layer = conv_1d(layer, nb_filter=n_filters[i], filter_size=filter_sizes[i], strides=strides[i], regularizer=regularizer, weight_decay=weight_decay, name=name, reuse=reuse, scope=scope_i)
+
+        if b_norm:
+            name += '_bnorm'
+            # scope_i = expand_scope_by_name(scope, name)    # FORGOT TO PUT IT BEFORE ICLR
+            layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+
+        layer = non_linearity(layer)
+
+        if dropout_prob is not None and dropout_prob[i] > 0:
+            layer = dropout(layer, 1.0 - dropout_prob[i])
+
+    if symmetry is not None:
+        layer = symmetry(layer, axis=1)
+
+    return layer
+
+
+def decoder_with_fc_only(latent_signal, layer_sizes=[], b_norm=True, non_linearity=tf.nn.relu,
+                         regularizer=None, weight_decay=0.001, reuse=False, scope=None, dropout_prob=None,
+                         b_norm_finish=False, verbose=False):
+    '''A decoding network which maps points from the latent space back onto the data space.
+    '''
+    if verbose:
+        print 'Building Decoder'
+
+    n_layers = len(layer_sizes)
+    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)
+
+    if n_layers < 2:
+        raise ValueError('For an FC decoder with a single layer use simpler code.')
+
+    for i in xrange(0, n_layers - 1):
+        name = 'decoder_fc_' + str(i)
+        scope_i = expand_scope_by_name(scope, name)
+
+        if i == 0:
+            layer = latent_signal
+
+        layer = fully_connected(layer, layer_sizes[i], activation='linear', weights_init='xavier', name=name, regularizer=regularizer, weight_decay=weight_decay, reuse=reuse, scope=scope_i)
+
+        if verbose:
+            print name, 'FC params = ', np.prod(layer.W.get_shape().as_list()) + np.prod(layer.b.get_shape().as_list()),
+
+        if b_norm:
+            name += '_bnorm'
+            scope_i = expand_scope_by_name(scope, name)
+            layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+            if verbose:
+                print 'bnorm params = ', np.prod(layer.beta.get_shape().as_list()) + np.prod(layer.gamma.get_shape().as_list())
+
+        if non_linearity is not None:
+            layer = non_linearity(layer)
+
+        if dropout_prob is not None and dropout_prob[i] > 0:
+            layer = dropout(layer, 1.0 - dropout_prob[i])
+
+        if verbose:
+            print layer
+            print 'output size:', np.prod(layer.get_shape().as_list()[1:]), '\n'
+
+    # Last decoding layer never has a non-linearity.
+    name = 'decoder_fc_' + str(n_layers - 1)
+    scope_i = expand_scope_by_name(scope, name)
+    layer = fully_connected(layer, layer_sizes[n_layers - 1], activation='linear', weights_init='xavier', name=name, regularizer=regularizer, weight_decay=weight_decay, reuse=reuse, scope=scope_i)
+    if verbose:
+        print name, 'FC params = ', np.prod(layer.W.get_shape().as_list()) + np.prod(layer.b.get_shape().as_list()),
+
+    if b_norm_finish:
+        name += '_bnorm'
+        scope_i = expand_scope_by_name(scope, name)
+        layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+        if verbose:
+            print 'bnorm params = ', np.prod(layer.beta.get_shape().as_list()) + np.prod(layer.gamma.get_shape().as_list())
+
+    if verbose:
+        print layer
+        print 'output size:', np.prod(layer.get_shape().as_list()[1:]), '\n'
+
+    return layer
+
+
+def decoder_with_convs_only(in_signal, n_filters, filter_sizes, strides, padding='same', b_norm=True, non_linearity=tf.nn.relu,
+                            conv_op=conv_1d, regularizer=None, weight_decay=0.001, dropout_prob=None, upsample_sizes=None,
+                            b_norm_finish=False, scope=None, reuse=False, verbose=False):
+
+    if verbose:
+        print 'Building Decoder'
+
+    n_layers = len(n_filters)
+    filter_sizes = replicate_parameter_for_all_layers(filter_sizes, n_layers)
+    strides = replicate_parameter_for_all_layers(strides, n_layers)
+    dropout_prob = replicate_parameter_for_all_layers(dropout_prob, n_layers)
+
+    for i in xrange(n_layers):
+        if i == 0:
+            layer = in_signal
+
+        name = 'decoder_conv_layer_' + str(i)
+        scope_i = expand_scope_by_name(scope, name)
+
+        layer = conv_op(layer, nb_filter=n_filters[i], filter_size=filter_sizes[i],
+                        strides=strides[i], padding=padding, regularizer=regularizer, weight_decay=weight_decay,
+                        name=name, reuse=reuse, scope=scope_i)
+
+        if verbose:
+            print name, 'conv params = ', np.prod(layer.W.get_shape().as_list()) + np.prod(layer.b.get_shape().as_list()),
+
+        if (b_norm and i < n_layers - 1) or (i == n_layers - 1 and b_norm_finish):
+            name += '_bnorm'
+            scope_i = expand_scope_by_name(scope, name)
+            layer = batch_normalization(layer, name=name, reuse=reuse, scope=scope_i)
+            if verbose:
+                print 'bnorm params = ', np.prod(layer.beta.get_shape().as_list()) + np.prod(layer.gamma.get_shape().as_list())
+
+        if non_linearity is not None and i < n_layers - 1:    # Last layer doesn't have a non-linearity.
+            layer = non_linearity(layer)
+
+        if dropout_prob is not None and dropout_prob[i] > 0:
+            layer = dropout(layer, 1.0 - dropout_prob[i])
+
+        if upsample_sizes is not None and upsample_sizes[i] is not None:
+            layer = tf.tile(layer, multiples=[1, upsample_sizes[i], 1])
+
+        if verbose:
+            print layer
+            print 'output size:', np.prod(layer.get_shape().as_list()[1:]), '\n'
+
+    return layer
diff --git a/src/gan.py b/src/gan.py
new file mode 100755
index 0000000..8e6ceb8
--- /dev/null
+++ b/src/gan.py
@@ -0,0 +1,44 @@
+'''
+Created on May 3, 2017
+
+@author: optas
+'''
+
+import os.path as osp
+import warnings
+import tensorflow as tf
+
+
+from general_tools.in_out.basics import create_dir
+
+from . neural_net import Neural_Net, MODEL_SAVER_ID
+
+
+class GAN(Neural_Net):
+
+    def __init__(self, name, graph):
+        Neural_Net.__init__(self, name, graph)
+
+    def save_model(self, tick):
+        self.saver.save(self.sess, MODEL_SAVER_ID, global_step=tick)
+
+    def restore_model(self, model_path, epoch, verbose=False):
+        '''Restore all the variables of a saved model.
+        '''
+        self.saver.restore(self.sess, osp.join(model_path, MODEL_SAVER_ID + '-' + str(int(epoch))))
+
+        if self.epoch.eval(session=self.sess) != epoch:
+            warnings.warn('Loaded model\'s epoch doesn\'t match the requested one.')
+        else:
+            if verbose:
+                print('Model restored in epoch {0}.'.format(epoch))
+
+    def optimizer(self, learning_rate, beta, loss, var_list):
+        initial_learning_rate = learning_rate
+        optimizer = tf.train.AdamOptimizer(initial_learning_rate, beta1=beta).minimize(loss, var_list=var_list)
+        return optimizer
+
+    def generate(self, n_samples, noise_params):
+        noise = self.generator_noise_distribution(n_samples, self.noise_dim, **noise_params)
+        feed_dict = {self.noise: noise}
+        return self.sess.run([self.generator_out], feed_dict=feed_dict)[0]
\ No newline at end of file
diff --git a/src/latent_gan.py b/src/latent_gan.py
new file mode 100755
index 0000000..e74347f
--- /dev/null
+++ b/src/latent_gan.py
@@ -0,0 +1,99 @@
+'''
+Created on April 27, 2017
+
+@author: optas
+'''
+import numpy as np
+import time
+import tensorflow as tf
+
+from . gan import GAN
+
+from .. fundamentals.layers import safe_log
+from tflearn import is_training
+
+
+class LatentGAN(GAN):
+    def __init__(self, name, learning_rate, n_output, noise_dim, discriminator, generator, beta=0.9, gen_kwargs={}, disc_kwargs={}, graph=None):
+
+        self.noise_dim = noise_dim
+        self.n_output = n_output
+        self.discriminator = discriminator
+        self.generator = generator
+
+        GAN.__init__(self, name, graph)
+
+        with tf.variable_scope(name):
+
+            self.noise = tf.placeholder(tf.float32, shape=[None, noise_dim])           # Noise vector.
+            self.gt_data = tf.placeholder(tf.float32, shape=[None] + self.n_output)    # Ground-truth.
+
+            with tf.variable_scope('generator'):
+                self.generator_out = self.generator(self.noise, self.n_output, **gen_kwargs)
+
+            with tf.variable_scope('discriminator') as scope:
+                self.real_prob, self.real_logit = self.discriminator(self.gt_data, scope=scope, **disc_kwargs)
+                self.synthetic_prob, self.synthetic_logit = self.discriminator(self.generator_out, reuse=True, scope=scope, **disc_kwargs)
+
+            self.loss_d = tf.reduce_mean(-tf.log(self.real_prob) - tf.log(1 - self.synthetic_prob))
+            self.loss_g = tf.reduce_mean(-tf.log(self.synthetic_prob))
+
+            # Post-ICLR TRY: safe_log
+
+            train_vars = tf.trainable_variables()
+
+            d_params = [v for v in train_vars if v.name.startswith(name + '/discriminator/')]
+            g_params = [v for v in train_vars if v.name.startswith(name + '/generator/')]
+
+            self.opt_d = self.optimizer(learning_rate, beta, self.loss_d, d_params)
+            self.opt_g = self.optimizer(learning_rate, beta, self.loss_g, g_params)
+            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
+            self.init = tf.global_variables_initializer()
+
+            # Launch the session
+            config = tf.ConfigProto()
+            config.gpu_options.allow_growth = True
+            self.sess = tf.Session(config=config)
+            self.sess.run(self.init)
+
+    def generator_noise_distribution(self, n_samples, ndims, mu, sigma):
+        return np.random.normal(mu, sigma, (n_samples, ndims))
+
+    def _single_epoch_train(self, train_data, batch_size, noise_params):
+        '''
+        see: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
+             http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/
+        '''
+        n_examples = train_data.num_examples
+        epoch_loss_d = 0.
+        epoch_loss_g = 0.
+        batch_size = batch_size
+        n_batches = int(n_examples / batch_size)
+        start_time = time.time()
+
+        is_training(True, session=self.sess)
+        try:
+            # Loop over all batches
+            for _ in xrange(n_batches):
+                feed, _, _ = train_data.next_batch(batch_size)
+
+                # Update discriminator.
+                z = self.generator_noise_distribution(batch_size, self.noise_dim, **noise_params)
+                feed_dict = {self.gt_data: feed, self.noise: z}
+                loss_d, _ = self.sess.run([self.loss_d, self.opt_d], feed_dict=feed_dict)
+                loss_g, _ = self.sess.run([self.loss_g, self.opt_g], feed_dict=feed_dict)
+
+                # Compute average loss
+                epoch_loss_d += loss_d
+                epoch_loss_g += loss_g
+
+            is_training(False, session=self.sess)
+        except Exception:
+            raise
+        finally:
+            is_training(False, session=self.sess)
+
+        epoch_loss_d /= n_batches
+        epoch_loss_g /= n_batches
+        duration = time.time() - start_time
+        return (epoch_loss_d, epoch_loss_g), duration
diff --git a/src/neural_net.py b/src/neural_net.py
new file mode 100755
index 0000000..ea78d01
--- /dev/null
+++ b/src/neural_net.py
@@ -0,0 +1,55 @@
+'''
+Created on August 28, 2017
+
+@author: optas
+'''
+
+import os.path as osp
+import tensorflow as tf
+
+MODEL_SAVER_ID = 'models.ckpt'
+
+
+class Neural_Net(object):
+
+    def __init__(self, name, graph):
+        if graph is None:
+            graph = tf.get_default_graph()
+            # g = tf.Graph()
+            # with g.as_default():
+        self.graph = graph
+        self.name = name
+
+        with tf.variable_scope(name):
+            with tf.device('/cpu:0'):
+                self.epoch = tf.get_variable('epoch', [], initializer=tf.constant_initializer(0), trainable=False)
+
+    def is_training(self):
+        is_training_op = self.graph.get_collection('is_training')
+        return self.sess.run(is_training_op)[0]
+
+#     def __init__(self, name, model, trainer, sess):
+#         '''
+#         Constructor
+#         '''
+#         self.model = model
+#         self.trainer = trainer
+#         self.sess = sess
+#         self.train_step = trainer.train_step
+#         self.saver = tf.train.Saver(tf.global_variables(), scope=name, max_to_keep=None)
+#
+#     def total_loss(self):
+#         return self.trainer.total_loss
+#
+#     def forward(self, input_tensor):
+#         return self.model.forward(input_tensor)
+#
+#     def save_model(self, tick):
+#         self.saver.save(self.sess, MODEL_SAVER_ID, global_step=tick)
+#
+#     def restore_model(self, model_path, tick, verbose=False):
+#         ''' restore_model.
+#
+#         Restore all the variables of the saved model.
+#         '''
+#         self.saver.restore(self.sess, osp.join(model_path, MODEL_SAVER_ID + '-' + str(int(tick))))
\ No newline at end of file
diff --git a/src/point_net_ae.py b/src/point_net_ae.py
new file mode 100755
index 0000000..0811702
--- /dev/null
+++ b/src/point_net_ae.py
@@ -0,0 +1,166 @@
+'''
+Created on January 26, 2017
+
+@author: optas
+'''
+
+import time
+import tensorflow as tf
+import socket
+import os.path as osp
+
+from tflearn.layers.conv import conv_1d
+from tflearn.layers.core import fully_connected
+
+from general_tools.in_out.basics import create_dir
+
+from . autoencoder import AutoEncoder
+from . in_out import apply_augmentations
+from .. fundamentals.loss import Loss
+from .. fundamentals.inspect import count_trainable_parameters
+
+from external.structural_losses import nn_distance, approx_match, match_cost
+
+
+class PointNetAutoEncoder(AutoEncoder):
+    '''
+    An Auto-Encoder for point-clouds.
+    '''
+
+    def __init__(self, name, configuration, graph=None):
+        c = configuration
+        self.configuration = c
+
+        AutoEncoder.__init__(self, name, graph, configuration)
+
+        with tf.variable_scope(name):
+            self.z = c.encoder(self.x, **c.encoder_args)
+            self.bottleneck_size = int(self.z.get_shape()[1])
+            layer = c.decoder(self.z, **c.decoder_args)
+            if c.exists_and_is_not_none('close_with_tanh'):
+                layer = tf.nn.tanh(layer)
+            if c.exists_and_is_not_none('do_completion'):    # TODO Re-factor for AP
+                self.completion = tf.reshape(layer, [-1, c.n_completion[0], c.n_completion[1]])
+                self.x_reconstr = tf.concat([self.x, self.completion], 1)    # output is input + `completion`
+            else:
+                self.x_reconstr = tf.reshape(layer, [-1, self.n_output[0], self.n_output[1]])
+
+            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=c.saver_max_to_keep)
+
+            self._create_loss()
+            self._setup_optimizer()
+
+            # GPU configuration
+            if hasattr(c, 'allow_gpu_growth'):
+                growth = c.allow_gpu_growth
+            else:
+                growth = True
+
+            config = tf.ConfigProto()
+            config.gpu_options.allow_growth = growth
+
+            # Summaries
+            self.merged_summaries = tf.summary.merge_all()
+            self.train_writer = tf.summary.FileWriter(osp.join(configuration.train_dir, 'summaries'), self.graph)
+
+            # Initializing the TensorFlow variables
+            self.init = tf.global_variables_initializer()
+
+            # Launch the session
+            self.sess = tf.Session(config=config)
+            self.sess.run(self.init)
+
+    def trainable_parameters(self):
+        return count_trainable_parameters(self.graph, name_space=self.name)
+
+    def _create_loss(self):
+        c = self.configuration
+
+        if c.loss == 'l2':
+            self.loss = Loss.l2_loss(self.x_reconstr, self.gt)
+        elif c.loss == 'chamfer':
+            cost_p1_p2, _, cost_p2_p1, _ = nn_distance(self.x_reconstr, self.gt)
+            self.loss = tf.reduce_mean(cost_p1_p2) + tf.reduce_mean(cost_p2_p1)
+        elif c.loss == 'emd':
+            match = approx_match(self.x_reconstr, self.gt)
+            self.loss = tf.reduce_mean(match_cost(self.x_reconstr, self.gt, match))
+
+        reg_losses = self.graph.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+        if c.exists_and_is_not_none('w_reg_alpha'):
+            w_reg_alpha = c.w_reg_alpha
+        else:
+            w_reg_alpha = 1.0
+
+        for rl in reg_losses:
+            self.loss += (w_reg_alpha * rl)
+
+    def _setup_optimizer(self):
+        c = self.configuration
+        self.lr = c.learning_rate
+        if hasattr(c, 'exponential_decay'):
+            self.lr = tf.train.exponential_decay(c.learning_rate, self.epoch, c.decay_steps, decay_rate=0.5, staircase=True, name="learning_rate_decay")
+            self.lr = tf.maximum(self.lr, 1e-5)
+            tf.summary.scalar('learning_rate', self.lr)
+
+        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
+        self.train_step = self.optimizer.minimize(self.loss)
+
+    def _single_epoch_train(self, train_data, configuration, only_fw=False):
+        n_examples = train_data.num_examples
+        epoch_loss = 0.
+        batch_size = configuration.batch_size
+        n_batches = int(n_examples / batch_size)
+        start_time = time.time()
+
+        if only_fw:
+            fit = self.reconstruct
+        else:
+            fit = self.partial_fit
+
+        # Loop over all batches
+        for _ in xrange(n_batches):
+
+            if self.is_denoising:
+                original_data, _, batch_i = train_data.next_batch(batch_size)
+                if batch_i is None:    # In this case the denoising concerns only the augmentation.
+                    batch_i = original_data
+            else:
+                batch_i, _, _ = train_data.next_batch(batch_size)
+
+            batch_i = apply_augmentations(batch_i, configuration)    # This is a new copy of the batch.
+
+            if self.is_denoising:
+                _, loss = fit(batch_i, original_data)
+            else:
+                _, loss = fit(batch_i)
+
+            # Compute average loss
+            epoch_loss += loss
+        epoch_loss /= n_batches
+        duration = time.time() - start_time
+        return epoch_loss, duration
+
+    def gradient_of_input_wrt_loss(self, in_points, gt_points=None):
+        if gt_points is None:
+            gt_points = in_points
+        return self.sess.run(tf.gradients(self.loss, self.x), feed_dict={self.x: in_points, self.gt: gt_points})
+
+    def gradient_of_input_wrt_latent_code(self, in_points, code_dims=None):
+        ''' Batching this is OK. But if you add a list of code_dims, the problem is the way tf.gradients will
+        gather the gradients from each dimension, i.e., by default it just adds them. This is problematic since for my
+        research I would need at least the abs sum of them.
+        '''
+        b_size = len(in_points)
+        n_dims = len(code_dims)
+
+        row_idx = tf.range(b_size, dtype=tf.int32)
+        row_idx = tf.reshape(tf.tile(row_idx, [n_dims]), [n_dims, -1])
+        row_idx = tf.transpose(row_idx)
+        col_idx = tf.constant(code_dims, dtype=tf.int32)
+        col_idx = tf.reshape(tf.tile(col_idx, [b_size]), [b_size, -1])
+        coords = tf.transpose(tf.stack([row_idx, col_idx]))
+
+        if b_size == 1:
+            coords = coords[0]
+        ys = tf.gather_nd(self.z, coords)
+        return self.sess.run(tf.gradients(ys, self.x), feed_dict={self.x: in_points})[0]
diff --git a/src/raw_gan.py b/src/raw_gan.py
new file mode 100755
index 0000000..8c86fc6
--- /dev/null
+++ b/src/raw_gan.py
@@ -0,0 +1,120 @@
+'''
+Created on Apr 27, 2017
+
+@author: optas
+'''
+
+import numpy as np
+import time
+import tensorflow as tf
+from tflearn import is_training
+
+from . gan import GAN
+from .. fundamentals.layers import safe_log
+
+
+class RawGAN(GAN):
+
+    def __init__(self, name, learning_rate, n_output, noise_dim, discriminator, generator, beta=0.9, gen_kwargs={}, disc_kwargs={}, graph=None):
+
+        self.noise_dim = noise_dim
+        self.n_output = n_output
+        out_shape = [None] + self.n_output
+        self.discriminator = discriminator
+        self.generator = generator
+
+        GAN.__init__(self, name, graph)
+
+        with tf.variable_scope(name):
+
+            self.noise = tf.placeholder(tf.float32, shape=[None, noise_dim])    # Noise vector.
+            self.real_pc = tf.placeholder(tf.float32, shape=out_shape)          # Ground-truth.
+
+            with tf.variable_scope('generator'):
+                self.generator_out = self.generator(self.noise, self.n_output[0], **gen_kwargs)
+
+            with tf.variable_scope('discriminator') as scope:
+                self.real_prob, self.real_logit = self.discriminator(self.real_pc, scope=scope, **disc_kwargs)
+                self.synthetic_prob, self.synthetic_logit = self.discriminator(self.generator_out, reuse=True, scope=scope, **disc_kwargs)
+
+            self.loss_d = tf.reduce_mean(-safe_log(self.real_prob) - safe_log(1 - self.synthetic_prob))
+            self.loss_g = tf.reduce_mean(-safe_log(self.synthetic_prob))
+
+            train_vars = tf.trainable_variables()
+
+            d_params = [v for v in train_vars if v.name.startswith(name + '/discriminator/')]
+            g_params = [v for v in train_vars if v.name.startswith(name + '/generator/')]
+
+            self.opt_d = self.optimizer(learning_rate, beta, self.loss_d, d_params)
+            self.opt_g = self.optimizer(learning_rate, beta, self.loss_g, g_params)
+            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
+            self.init = tf.global_variables_initializer()
+
+            # Launch the session
+            config = tf.ConfigProto()
+            config.gpu_options.allow_growth = True
+            self.sess = tf.Session(config=config)
+            self.sess.run(self.init)
+
+    def generator_noise_distribution(self, n_samples, ndims, mu, sigma):
+        return np.random.normal(mu, sigma, (n_samples, ndims))
+
+    def _single_epoch_train(self, train_data, batch_size, noise_params={}, adaptive=None):
+        '''
+        see: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
+             http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/
+        '''
+        n_examples = train_data.num_examples
+        epoch_loss_d = 0.
+        epoch_loss_g = 0.
+        batch_size = batch_size
+        n_batches = int(n_examples / batch_size)
+        start_time = time.time()
+        updated_d = 0
+        # Loop over all batches
+        _real_s = []
+        _fake_s = []
+        is_training(True, session=self.sess)
+        try:
+            for _ in xrange(n_batches):
+                feed, _, _ = train_data.next_batch(batch_size)
+                # Update discriminator.
+                z = self.generator_noise_distribution(batch_size, self.noise_dim, **noise_params)
+                feed_dict = {self.real_pc: feed, self.noise: z}
+                if adaptive is not None:
+                    s1 = tf.reduce_mean(self.real_prob)
+                    s2 = tf.reduce_mean(1 - self.synthetic_prob)
+                    sr, sf = self.sess.run([s1, s2], feed_dict=feed_dict)
+                    _real_s.append(sr)
+                    _fake_s.append(sf)
+                    if np.mean([sr, sf]) < adaptive:
+                        loss_d, _ = self.sess.run([self.loss_d, self.opt_d], feed_dict=feed_dict)
+                        updated_d += 1
+                        epoch_loss_d += loss_d
+                else:
+                    loss_d, _ = self.sess.run([self.loss_d, self.opt_d], feed_dict=feed_dict)
+                    updated_d += 1
+                    epoch_loss_d += loss_d
+                # Update generator.
+                loss_g, _ = self.sess.run([self.loss_g, self.opt_g], feed_dict=feed_dict)
+                # Compute average loss
+                # epoch_loss_d += loss_d
+                epoch_loss_g += loss_g
+            is_training(False, session=self.sess)
+        except Exception:
+            raise
+        finally:
+            is_training(False, session=self.sess)
+
+#         epoch_loss_d /= n_batches
+        if updated_d > 0:
+            epoch_loss_d /= updated_d
+        else:
+            print 'Discriminator was not updated in this epoch.'
+
+        if adaptive is not None:
+            print np.mean(_real_s), np.mean(_fake_s)
+
+        epoch_loss_g /= n_batches
+        duration = time.time() - start_time
+        return (epoch_loss_d, epoch_loss_g), duration
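
The modules above compose into one pipeline. A minimal end-to-end sketch, not part of the patch, assuming a dataset object named train_data that exposes the num_examples / next_batch / full_epoch_data interface used throughout this code, a writable 'train_dir_out' directory, and the compiled external.structural_losses ops for the 'chamfer' loss:

    from src.ae_templates import mlp_architecture_ala_iclr_18, default_train_params
    from src.autoencoder import Configuration
    from src.point_net_ae import PointNetAutoEncoder

    n_pc_points, bneck_size = 2048, 128
    encoder, decoder, enc_args, dec_args = mlp_architecture_ala_iclr_18(n_pc_points, bneck_size)
    train_params = default_train_params()

    conf = Configuration(n_input=[n_pc_points, 3],
                         encoder=encoder, decoder=decoder,
                         encoder_args=enc_args, decoder_args=dec_args,
                         loss='chamfer',    # 'l2' and 'emd' are the other branches of _create_loss()
                         training_epochs=train_params['training_epochs'],
                         batch_size=train_params['batch_size'],
                         learning_rate=train_params['learning_rate'],
                         denoising=train_params['denoising'],
                         saver_step=train_params['saver_step'],
                         loss_display_step=train_params['loss_display_step'],
                         train_dir='train_dir_out')    # hypothetical output directory

    ae = PointNetAutoEncoder('single_class_ae', conf)
    stats = ae.train(train_data, conf)    # list of (epoch, loss, duration) tuples

To resume a run later, ae.restore_model('train_dir_out', epoch=500) loads models.ckpt-500 and warns if the stored epoch counter disagrees with the requested one.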
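
Once trained, the latent-space helpers of AutoEncoder cover the common operations. Continuing the sketch above, with feed_pc an assumed (k, 2048, 3) numpy array:

    codes = ae.transform(feed_pc)                  # (k, bneck_size) bottleneck codes
    recons, loss = ae.reconstruct(feed_pc)         # reconstructions plus the configured loss
    walk = ae.interpolate(feed_pc[0], feed_pc[1], steps=8)    # 10 clouds on the latent segment
    one_cloud = ae.decode(codes[0])                # decode a single latent vector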
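
The two encoder_with_convs_and_symmetry variants implement the PointNet idea: per-point 1-D convolutions followed by a symmetric pooling (symmetry=tf.reduce_max over the point axis), which makes the bottleneck code invariant to the ordering of the input points. A self-contained TF 1.x check of that property:

    import numpy as np
    import tensorflow as tf

    feats = tf.placeholder(tf.float32, [None, 2048, 64])    # per-point features
    global_code = tf.reduce_max(feats, axis=1)              # symmetric pooling, as in the encoders

    with tf.Session() as sess:
        x = np.random.rand(1, 2048, 64).astype(np.float32)
        x_shuffled = x[:, np.random.permutation(2048), :]
        a = sess.run(global_code, feed_dict={feats: x})
        b = sess.run(global_code, feed_dict={feats: x_shuffled})
        assert np.allclose(a, b)    # point order does not change the pooled code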
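
The 'chamfer' branch of _create_loss() relies on the custom nn_distance op from external.structural_losses. For intuition only, a numpy reference of the quantity it assembles (squared distance to the nearest neighbor, averaged in both directions); the CUDA op computes the same thing, batched and far faster:

    import numpy as np

    def chamfer_reference(pc_a, pc_b):
        '''pc_a: (n, 3), pc_b: (m, 3). O(n*m) memory, for sanity checks only.'''
        d = np.sum((pc_a[:, None, :] - pc_b[None, :, :]) ** 2, axis=-1)
        return d.min(axis=1).mean() + d.min(axis=0).mean()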
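
replicate_parameter_for_all_layers comes from the fundamentals package, which is not part of this patch. Judging from its call sites (a one-element filter_sizes=[1] is expanded to one entry per layer, and an unset dropout_prob passes through as None), its assumed behavior is roughly:

    def replicate_parameter_for_all_layers(parameter, n_layers):
        # Assumed semantics: broadcast a single-element list to n_layers copies;
        # leave None untouched.
        if parameter is not None and len(parameter) == 1 and n_layers > 1:
            parameter = parameter * n_layers
        return parameter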
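
safe_log (also from fundamentals, used by RawGAN and noted as a post-ICLR idea in LatentGAN) is presumably a clamped logarithm that keeps the GAN losses finite when the discriminator saturates near 0 or 1. A sketch of the assumed definition:

    import tensorflow as tf

    def safe_log(x, eps=1e-10):
        # Assumed implementation: clamp before log to avoid log(0) -> -inf.
        return tf.log(tf.maximum(x, eps))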
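
LatentGAN is intended to be trained on the AE's bottleneck codes rather than on raw clouds. A sketch of that loop, assuming latent_data wraps ae.transform(...) outputs in the same dataset interface, and my_generator / my_discriminator are user-supplied callables matching the signatures used in __init__ (generator(noise, n_output) returns a tensor; discriminator(x, scope=..., reuse=...) returns a (prob, logit) pair):

    from src.latent_gan import LatentGAN

    noise_params = {'mu': 0.0, 'sigma': 0.2}    # assumed values, fed to generator_noise_distribution
    lgan = LatentGAN('lgan', learning_rate=1e-4, n_output=[bneck_size], noise_dim=32,
                     discriminator=my_discriminator, generator=my_generator)

    for _ in range(100):
        (loss_d, loss_g), _ = lgan._single_epoch_train(latent_data, 50, noise_params)

    codes = lgan.generate(400, noise_params)    # sample synthetic bottleneck codes
    clouds = ae.decode(codes)                   # map them back to point-clouds with the AE decoder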
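
Finally, RawGAN._single_epoch_train's adaptive argument implements a heuristic discriminator schedule: the discriminator is updated only while the mean of its real/fake accuracies stays below the threshold, which keeps it from overpowering the generator on raw clouds. A hedged usage sketch, assuming rgan was built like the LatentGAN above but with n_output=[2048, 3] and pc_data holding raw point-clouds:

    (loss_d, loss_g), dt = rgan._single_epoch_train(pc_data, batch_size=50,
                                                    noise_params={'mu': 0.0, 'sigma': 0.2},
                                                    adaptive=0.9)    # skip D updates once it is ~90% confident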