Merge updated parameters for training
Peter Baylies committed Jun 4, 2019
1 parent ad6b065 commit c657331
Showing 1 changed file with 23 additions and 19 deletions.
train_effnet.py: 42 changes (23 additions, 19 deletions)
@@ -21,7 +21,7 @@
from keras.layers import Input, LocallyConnected1D, Reshape, Permute, Conv2D, Add, Concatenate
from keras.models import Model, load_model

- def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=32):
+ def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=32, truncation=0.7):
"""
Generates a dataset of 'n' images of shape ('size', 'size', 3) with random seed 'seed'
along with their dlatent vectors W of shape ('n', 512)
@@ -54,7 +54,7 @@ def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, im
Z = np.random.randn(n*mod_l, Gs.input_shape[1])
W = Gs.components.mapping.run(Z, None, minibatch_size=minibatch_size) # Use mapping network to get unique dlatents for more variation.
dlatent_avg = Gs.get_var('dlatent_avg') # [component]
- W = (W[np.newaxis] - dlatent_avg) * np.reshape([1, -1], [-1, 1, 1, 1]) + dlatent_avg # truncation trick and add negative image pair
+ W = (W[np.newaxis] - dlatent_avg) * np.reshape([truncation, -truncation], [-1, 1, 1, 1]) + dlatent_avg # truncation trick and add negative image pair
W = np.append(W[0], W[1], axis=0)
W = W[:, :mod_r]
W = W.reshape((n*2, model_scale, 512))
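Editorial note (not part of the commit): the one-line change in this hunk threads the new truncation argument into the existing truncation-trick expression. The snippet below is a minimal, self-contained numpy sketch of what that expression does, using toy array sizes and random stand-ins for the mapping output and dlatent_avg: each dlatent is pulled toward the average by a factor of truncation, and a mirrored copy is produced with -truncation, giving the "negative image pair" mentioned in the comment.

import numpy as np

truncation = 0.7
n, layers, dim = 4, 18, 512                    # toy sizes; the real script uses model_scale x 512
W = np.random.randn(n, layers, dim)            # stand-in for the mapping network's dlatents
dlatent_avg = np.random.randn(dim)             # stand-in for Gs.get_var('dlatent_avg')

pair = (W[np.newaxis] - dlatent_avg) * np.reshape([truncation, -truncation], [-1, 1, 1, 1]) + dlatent_avg
W_both = np.append(pair[0], pair[1], axis=0)   # truncated dlatents plus their mirrored counterparts

assert W_both.shape == (2 * n, layers, dim)
# pair[0] interpolates toward the average: dlatent_avg + truncation * (W - dlatent_avg)
np.testing.assert_allclose(pair[0], dlatent_avg + truncation * (W - dlatent_avg), rtol=1e-6)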
@@ -64,22 +64,22 @@ def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, im
X = preprocess_input(X)
return W, X

- def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16):
+ def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16, truncation=0.7):
"""
Use generate_dataset_main() as a helper function.
Divides requests into batches to save memory.
"""
batch_size = 16
inc = n//batch_size
left = n-((batch_size-1)*inc)
- W, X = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size)
+ W, X = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size, truncation)
for i in range(batch_size-2):
- aW, aX = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size)
+ aW, aX = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size, truncation)
W = np.append(W, aW, axis=0)
aW = None
X = np.append(X, aX, axis=0)
aX = None
- aW, aX = generate_dataset_main(left, save_path, seed, model_res, image_size, minibatch_size)
+ aW, aX = generate_dataset_main(left, save_path, seed, model_res, image_size, minibatch_size, truncation)
W = np.append(W, aW, axis=0)
aW = None
X = np.append(X, aX, axis=0)
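For context (editorial, not part of the commit): generate_dataset() always splits the request into 16 chunks, fifteen of inc = n // 16 images plus a final chunk of left = n - 15*inc that absorbs the remainder. A tiny self-contained check of that arithmetic:

def split_counts(n, batch_size=16):
    # Mirrors the arithmetic above: 15 chunks of `inc`, then one chunk of `left`.
    inc = n // batch_size
    left = n - ((batch_size - 1) * inc)
    return [inc] * (batch_size - 1) + [left]

chunks = split_counts(10001)
assert sum(chunks) == 10001 and len(chunks) == 16
print(chunks)   # fifteen chunks of 625 followed by a final chunk of 626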
Expand All @@ -95,7 +95,7 @@ def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_s
def is_square(n):
return (n == int(math.sqrt(n) + 0.5)**2)

- def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu'):
+ def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu', loss='logcosh', optimizer='adam'):

if os.path.exists(save_path):
print('Loading model')
@@ -178,10 +178,10 @@ def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3,
x = Add()([x, x_init]) # add skip connection
x = Reshape((model_scale, 512))(x) # train against all dlatent values
model = Model(inputs=inp,outputs=x)
- model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+ model.compile(loss=loss, metrics=[], optimizer=optimizer) # By default: adam optimizer, logcosh used for loss.
return model

- def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32):
+ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32, truncation=0.7):
"""
Finetunes an EfficientNet to predict W from X
Generate batches (X, W) of size 'batch_size', iterates 'n_epochs', and repeat while 'max_patience' is reached
@@ -192,7 +192,7 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
# Create a test set
print('Creating test set:')
np.random.seed(seed)
- W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size)
+ W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation)

# Iterate on batches of size batch_size
print('Generating training set:')
@@ -202,7 +202,7 @@
#print('Initial test loss : {:.5f}'.format(loss))
while (patience <= max_patience):
W_train = X_train = None
- W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size)
+ W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation)
model.fit(X_train, W_train, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
loss = model.evaluate(X_test, W_test, batch_size=minibatch_size)
if loss < best_loss:
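The training loop in this hunk keeps drawing fresh synthetic batches and fitting until the test loss stops improving. The bookkeeping that updates best_loss and patience sits partly outside the lines shown here, so the following is only a generic sketch of the same patience pattern, not the script's exact code:

import random

best_loss, patience, max_patience = float('inf'), 0, 2

def train_and_evaluate():
    # placeholder for generate_dataset(...), model.fit(...), model.evaluate(...)
    return random.uniform(0.0, 1.0)

while patience <= max_patience:
    loss = train_and_evaluate()
    if loss < best_loss:
        best_loss = loss    # improvement: keep it (the real loop also saves the model at this point)
        patience = 0
    else:
        patience += 1       # no improvement: one step closer to stopping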
@@ -226,10 +226,13 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
parser.add_argument('--model_depth', default=1, help='Number of TreeConnect layers to add after EfficientNet', type=int)
parser.add_argument('--model_size', default=1, help='Model size - 0 - small, 1 - medium, 2 - large, or 3 - full size.', type=int)
parser.add_argument('--activation', default='elu', help='Activation function to use after EfficientNet')
+ parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
+ parser.add_argument('--loss', default='logcosh', help='Loss function to use')
parser.add_argument('--use_fp16', default=False, help='Use 16-bit floating point', type=bool)
parser.add_argument('--image_size', default=256, help='Size of images for EfficientNet model', type=int)
parser.add_argument('--batch_size', default=2048, help='Batch size for training the EfficientNet model', type=int)
parser.add_argument('--test_size', default=512, help='Batch size for testing the EfficientNet model', type=int)
+ parser.add_argument('--truncation', default=0.7, help='Generate images using truncation trick', type=float)
parser.add_argument('--max_patience', default=2, help='Number of iterations to wait while test loss does not improve', type=int)
parser.add_argument('--freeze_first', default=False, help='Start training with the pre-trained network frozen, then unfreeze', type=bool)
parser.add_argument('--epochs', default=2, help='Number of training epochs to run for each batch', type=int)
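The three parser.add_argument calls added in this hunk are the only new user-facing options. A self-contained sketch (separate from the script) of how they parse, using the same names and defaults as above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
parser.add_argument('--loss', default='logcosh', help='Loss function to use')
parser.add_argument('--truncation', default=0.7, help='Generate images using truncation trick', type=float)

# Options not given on the command line fall back to their defaults.
args = parser.parse_args(['--truncation', '0.5', '--loss', 'mse'])
print(args.optimizer, args.loss, args.truncation)   # -> adam mse 0.5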
@@ -248,34 +248,35 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
K.set_floatx('float16')
K.set_epsilon(1e-4)

- model = get_effnet_model(args.model_path, model_res=args.model_res, depth=args.model_depth, size=args.model_size, activation=args.activation)

tflib.init_tf()

+ model = get_effnet_model(args.model_path, model_res=args.model_res, depth=args.model_depth, size=args.model_size, activation=args.activation, optimizer=args.optimizer, loss=args.loss)

with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
generator_network, discriminator_network, Gs_network = pickle.load(f)

def load_Gs():
return Gs_network

- K.get_session().run(tensorflow.global_variables_initializer())
+ #K.get_session().run(tensorflow.global_variables_initializer())

if args.freeze_first:
model.layers[1].trainable = False
- model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+ model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)

model.summary()

if args.freeze_first: # run a training iteration first while pretrained model is frozen, then unfreeze.
- finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+ finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
model.layers[1].trainable = True
- model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+ model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)
model.summary()

if args.loop < 0:
while True:
- finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+ finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
else:
count = args.loop
while count > 0:
- finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+ finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
count -= 1
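One last editorial note on the --freeze_first path shown above: the script trains once with the pre-trained network (model.layers[1]) frozen, then unfreezes it and recompiles, because a change to a layer's trainable flag only takes effect at the next compile(). A toy Keras illustration of that pattern, using small Dense layers in place of EfficientNet (assumes Keras 2.x, as in the rest of the script):

from keras.layers import Dense, Input
from keras.models import Model, Sequential

base = Sequential([Dense(16, activation='relu', input_shape=(8,))])   # stand-in for the pre-trained EfficientNet

inp = Input(shape=(8,))
out = Dense(4)(base(inp))          # the sub-model becomes layers[1] of the outer model
model = Model(inputs=inp, outputs=out)

model.layers[1].trainable = False  # phase 1: only the new head trains
model.compile(loss='logcosh', metrics=[], optimizer='adam')
# model.fit(...)                   # first training pass with the backbone frozen

model.layers[1].trainable = True   # phase 2: unfreeze the backbone
model.compile(loss='logcosh', metrics=[], optimizer='adam')   # recompile so the change takes effect
# model.fit(...)                   # continue training end-to-end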
