Skip to content

Commit

Permalink
mdb
Browse files Browse the repository at this point in the history
  • Loading branch information
Panos committed Nov 27, 2017
1 parent 8d63c82 commit 9429a97
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 33 deletions.
9 changes: 5 additions & 4 deletions src/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@

from tflearn import is_training

from general_tools.in_out.basics import create_dir, pickle_data, unpickle_data
from . in_out import create_dir, pickle_data, unpickle_data

from general_tools.simpletons import iterate_in_chunks

from . in_out import apply_augmentations
from .. neural_net import Neural_Net
from . neural_net import Neural_Net

model_saver_id = 'models.ckpt'


class Configuration():
def __init__(self, n_input, encoder, decoder, encoder_args={}, decoder_args={},
training_epochs=200, batch_size=10, learning_rate=0.001, denoising=False,
saver_step=None, train_dir=None, z_rotate=False, loss='l2', gauss_augment=None,
saver_max_to_keep=None, loss_display_step=1, spatial_trans=False, debug=False,
saver_step=None, train_dir=None, z_rotate=False, loss='chamfer', gauss_augment=None,
saver_max_to_keep=None, loss_display_step=1, debug=False,
n_z=None, n_output=None, latent_vs_recon=1.0, consistent_io=None):

# Parameters for any AE
Expand Down
24 changes: 23 additions & 1 deletion src/in_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import os.path as osp
import re

from six.moves import cPickle

def create_dir(dir_path):
''' Creates a directory (or nested directories) if they don't exist.
Expand All @@ -19,6 +19,28 @@ def create_dir(dir_path):
return dir_path



def pickle_data(file_name, *args):
'''Using (c)Pickle to save multiple python objects in a single file.
'''
myFile = open(file_name, 'wb')
cPickle.dump(len(args), myFile, protocol=2)
for item in args:
cPickle.dump(item, myFile, protocol=2)
myFile.close()


def unpickle_data(file_name):
'''Restore data previously saved with pickle_data().
'''
inFile = open(file_name, 'rb')
size = cPickle.load(inFile)
for _ in xrange(size):
yield cPickle.load(inFile)
inFile.close()



def files_in_subdirs(top_dir, search_pattern):
regex = re.compile(search_pattern)
for path, _, files in os.walk(top_dir):
Expand Down
29 changes: 1 addition & 28 deletions src/neural_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ class Neural_Net(object):
def __init__(self, name, graph):
if graph is None:
graph = tf.get_default_graph()
# g = tf.Graph()
# with g.as_default():

self.graph = graph
self.name = name

Expand All @@ -27,29 +26,3 @@ def __init__(self, name, graph):
def is_training(self):
is_training_op = self.graph.get_collection('is_training')
return self.sess.run(is_training_op)[0]

# def __init__(self, name, model, trainer, sess):
# '''
# Constructor
# '''
# self.model = model
# self.trainer = trainer
# self.sess = sess
# self.train_step = trainer.train_step
# self.saver = tf.train.Saver(tf.global_variables(), scope=name, max_to_keep=None)
#
# def total_loss(self):
# return self.trainer.total_loss
#
# def forward(self, input_tensor):
# return self.model.forward(input_tensor)
#
# def save_model(self, tick):
# self.saver.save(self.sess, MODEL_SAVER_ID, global_step=tick)
#
# def restore_model(self, model_path, tick, verbose=False):
# ''' restore_model.
#
# Restore all the variables of the saved model.
# '''
# self.saver.restore(self.sess, osp.join(model_path, MODEL_SAVER_ID + '-' + str(int(tick))))

0 comments on commit 9429a97

Please sign in to comment.