diff --git a/src/autoencoder.py b/src/autoencoder.py index e449c2a..8aab10a 100755 --- a/src/autoencoder.py +++ b/src/autoencoder.py @@ -11,11 +11,12 @@ from tflearn import is_training -from general_tools.in_out.basics import create_dir, pickle_data, unpickle_data +from . in_out import create_dir, pickle_data, unpickle_data + from general_tools.simpletons import iterate_in_chunks from . in_out import apply_augmentations -from .. neural_net import Neural_Net +from . neural_net import Neural_Net model_saver_id = 'models.ckpt' @@ -23,8 +24,8 @@ class Configuration(): def __init__(self, n_input, encoder, decoder, encoder_args={}, decoder_args={}, training_epochs=200, batch_size=10, learning_rate=0.001, denoising=False, - saver_step=None, train_dir=None, z_rotate=False, loss='l2', gauss_augment=None, - saver_max_to_keep=None, loss_display_step=1, spatial_trans=False, debug=False, + saver_step=None, train_dir=None, z_rotate=False, loss='chamfer', gauss_augment=None, + saver_max_to_keep=None, loss_display_step=1, debug=False, n_z=None, n_output=None, latent_vs_recon=1.0, consistent_io=None): # Parameters for any AE diff --git a/src/in_out.py b/src/in_out.py index 91b0642..ac1a72b 100755 --- a/src/in_out.py +++ b/src/in_out.py @@ -8,7 +8,7 @@ import os import os.path as osp import re - +from six.moves import cPickle def create_dir(dir_path): ''' Creates a directory (or nested directories) if they don't exist. @@ -19,6 +19,28 @@ def create_dir(dir_path): return dir_path + +def pickle_data(file_name, *args): + '''Using (c)Pickle to save multiple python objects in a single file. + ''' + myFile = open(file_name, 'wb') + cPickle.dump(len(args), myFile, protocol=2) + for item in args: + cPickle.dump(item, myFile, protocol=2) + myFile.close() + + +def unpickle_data(file_name): + '''Restore data previously saved with pickle_data(). + ''' + inFile = open(file_name, 'rb') + size = cPickle.load(inFile) + for _ in xrange(size): + yield cPickle.load(inFile) + inFile.close() + + + def files_in_subdirs(top_dir, search_pattern): regex = re.compile(search_pattern) for path, _, files in os.walk(top_dir): diff --git a/src/neural_net.py b/src/neural_net.py index ea78d01..f328cd4 100755 --- a/src/neural_net.py +++ b/src/neural_net.py @@ -15,8 +15,7 @@ class Neural_Net(object): def __init__(self, name, graph): if graph is None: graph = tf.get_default_graph() - # g = tf.Graph() - # with g.as_default(): + self.graph = graph self.name = name @@ -27,29 +26,3 @@ def __init__(self, name, graph): def is_training(self): is_training_op = self.graph.get_collection('is_training') return self.sess.run(is_training_op)[0] - -# def __init__(self, name, model, trainer, sess): -# ''' -# Constructor -# ''' -# self.model = model -# self.trainer = trainer -# self.sess = sess -# self.train_step = trainer.train_step -# self.saver = tf.train.Saver(tf.global_variables(), scope=name, max_to_keep=None) -# -# def total_loss(self): -# return self.trainer.total_loss -# -# def forward(self, input_tensor): -# return self.model.forward(input_tensor) -# -# def save_model(self, tick): -# self.saver.save(self.sess, MODEL_SAVER_ID, global_step=tick) -# -# def restore_model(self, model_path, tick, verbose=False): -# ''' restore_model. -# -# Restore all the variables of the saved model. -# ''' -# self.saver.restore(self.sess, osp.join(model_path, MODEL_SAVER_ID + '-' + str(int(tick)))) \ No newline at end of file