From c6573314776767886a7c43937fc7dd26d10dee00 Mon Sep 17 00:00:00 2001
From: Peter Baylies
Date: Tue, 4 Jun 2019 14:36:22 -0400
Subject: [PATCH] Merge updated parameters for training

---
 train_effnet.py | 42 +++++++++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/train_effnet.py b/train_effnet.py
index a0706e4e9..c4182d89a 100644
--- a/train_effnet.py
+++ b/train_effnet.py
@@ -21,7 +21,7 @@
 from keras.layers import Input, LocallyConnected1D, Reshape, Permute, Conv2D, Add, Concatenate
 from keras.models import Model, load_model
 
-def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=32):
+def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=32, truncation=0.7):
     """
     Generates a dataset of 'n' images of shape ('size', 'size', 3) with random seed 'seed'
     along with their dlatent vectors W of shape ('n', 512)
@@ -54,7 +54,7 @@ def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, im
     Z = np.random.randn(n*mod_l, Gs.input_shape[1])
     W = Gs.components.mapping.run(Z, None, minibatch_size=minibatch_size) # Use mapping network to get unique dlatents for more variation.
     dlatent_avg = Gs.get_var('dlatent_avg') # [component]
-    W = (W[np.newaxis] - dlatent_avg) * np.reshape([1, -1], [-1, 1, 1, 1]) + dlatent_avg # truncation trick and add negative image pair
+    W = (W[np.newaxis] - dlatent_avg) * np.reshape([truncation, -truncation], [-1, 1, 1, 1]) + dlatent_avg # truncation trick and add negative image pair
     W = np.append(W[0], W[1], axis=0)
     W = W[:, :mod_r]
     W = W.reshape((n*2, model_scale, 512))
@@ -64,7 +64,7 @@ def generate_dataset_main(n=10000, save_path=None, seed=None, model_res=1024, im
     X = preprocess_input(X)
     return W, X
 
-def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16):
+def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_size=256, minibatch_size=16, truncation=0.7):
     """
     Use generate_dataset_main() as a helper function.
     Divides requests into batches to save memory.
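For context, the `truncation` parameter threaded through `generate_dataset_main()` implements StyleGAN's truncation trick: each dlatent W is pulled toward the mapping network's average dlatent `dlatent_avg` by a factor ψ (here ±`truncation`, so every sample also gets a negative pair). A minimal NumPy sketch of what the changed line computes; the 2-D dlatent shape and the helper name are illustrative assumptions, not part of the patch:

```python
import numpy as np

def truncate_with_negative_pair(W, dlatent_avg, truncation=0.7):
    """Truncation trick at +truncation and -truncation (negative pair).

    W           : (n, 512) dlatents from the mapping network (assumed shape)
    dlatent_avg : (512,) average dlatent, 'dlatent_avg' in StyleGAN
    returns     : (2*n, 512) truncated dlatents followed by their negatives
    """
    # Scale each dlatent's deviation from the average by +/- truncation;
    # the patch does the same with an extra per-layer axis, hence its
    # np.reshape([truncation, -truncation], [-1, 1, 1, 1]).
    psi = np.reshape([truncation, -truncation], [-1, 1, 1])     # (2, 1, 1)
    W_pair = (W[np.newaxis] - dlatent_avg) * psi + dlatent_avg  # (2, n, 512)
    return np.concatenate([W_pair[0], W_pair[1]], axis=0)

# Toy usage with random stand-ins for real mapping-network outputs:
W = np.random.randn(4, 512)
print(truncate_with_negative_pair(W, np.zeros(512)).shape)  # (8, 512)
```

Lower `truncation` values trade sample variety for image quality; the previous code hard-coded the factor pair to `[1, -1]`, i.e. no truncation at all.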
@@ -72,14 +72,14 @@ def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_s
     batch_size = 16
     inc = n//batch_size
     left = n-((batch_size-1)*inc)
-    W, X = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size)
+    W, X = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size, truncation)
     for i in range(batch_size-2):
-        aW, aX = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size)
+        aW, aX = generate_dataset_main(inc, save_path, seed, model_res, image_size, minibatch_size, truncation)
         W = np.append(W, aW, axis=0)
         aW = None
         X = np.append(X, aX, axis=0)
         aX = None
-    aW, aX = generate_dataset_main(left, save_path, seed, model_res, image_size, minibatch_size)
+    aW, aX = generate_dataset_main(left, save_path, seed, model_res, image_size, minibatch_size, truncation)
     W = np.append(W, aW, axis=0)
     aW = None
     X = np.append(X, aX, axis=0)
@@ -95,7 +95,7 @@ def generate_dataset(n=10000, save_path=None, seed=None, model_res=1024, image_s
 def is_square(n):
     return (n == int(math.sqrt(n) + 0.5)**2)
 
-def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu'):
+def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu', loss='logcosh', optimizer='adam'):
 
     if os.path.exists(save_path):
         print('Loading model')
@@ -178,10 +178,10 @@ def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3,
         x = Add()([x, x_init]) # add skip connection
     x = Reshape((model_scale, 512))(x) # train against all dlatent values
     model = Model(inputs=inp,outputs=x)
-    model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+    model.compile(loss=loss, metrics=[], optimizer=optimizer) # By default: adam optimizer, logcosh used for loss.
     return model
 
-def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32):
+def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32, truncation=0.7):
     """
     Finetunes an EfficientNet to predict W from X
     Generate batches (X, W) of size 'batch_size', iterates 'n_epochs', and repeat while 'max_patience' is reached
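The `get_effnet_model()` change above swaps the hard-coded `'logcosh'`/`'adam'` for parameters. That works without extra plumbing because Keras resolves built-in losses and optimizers from their string identifiers at `compile()` time. A minimal standalone sketch (toy model, not the patch's EfficientNet):

```python
from keras.layers import Dense, Input
from keras.models import Model

def build_toy_model(loss='logcosh', optimizer='adam'):
    # Tiny stand-in network; the point is only the compile() call.
    inp = Input(shape=(8,))
    out = Dense(4)(inp)
    model = Model(inputs=inp, outputs=out)
    # Keras looks up 'logcosh' and 'adam' by name, so any built-in
    # identifier ('mse', 'mae', 'sgd', 'rmsprop', ...) can be passed
    # straight through from a command-line flag.
    model.compile(loss=loss, metrics=[], optimizer=optimizer)
    return model

build_toy_model(loss='mse', optimizer='rmsprop').summary()
```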
@@ -192,7 +192,7 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
     # Create a test set
     print('Creating test set:')
     np.random.seed(seed)
-    W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size)
+    W_test, X_test = generate_dataset(n=test_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation)
 
     # Iterate on batches of size batch_size
     print('Generating training set:')
@@ -202,7 +202,7 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
     #print('Initial test loss : {:.5f}'.format(loss))
     while (patience <= max_patience):
         W_train = X_train = None
-        W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size)
+        W_train, X_train = generate_dataset(batch_size, model_res=model_res, image_size=image_size, seed=seed, minibatch_size=minibatch_size, truncation=truncation)
         model.fit(X_train, W_train, epochs=n_epochs, verbose=True, batch_size=minibatch_size)
         loss = model.evaluate(X_test, W_test, batch_size=minibatch_size)
         if loss < best_loss:
@@ -226,10 +226,13 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
 parser.add_argument('--model_depth', default=1, help='Number of TreeConnect layers to add after EfficientNet', type=int)
 parser.add_argument('--model_size', default=1, help='Model size - 0 - small, 1 - medium, 2 - large, or 3 - full size.', type=int)
 parser.add_argument('--activation', default='elu', help='Activation function to use after EfficientNet')
+parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
+parser.add_argument('--loss', default='logcosh', help='Loss function to use')
 parser.add_argument('--use_fp16', default=False, help='Use 16-bit floating point', type=bool)
 parser.add_argument('--image_size', default=256, help='Size of images for EfficientNet model', type=int)
 parser.add_argument('--batch_size', default=2048, help='Batch size for training the EfficientNet model', type=int)
 parser.add_argument('--test_size', default=512, help='Batch size for testing the EfficientNet model', type=int)
+parser.add_argument('--truncation', default=0.7, help='Generate images using truncation trick', type=float)
 parser.add_argument('--max_patience', default=2, help='Number of iterations to wait while test loss does not improve', type=int)
 parser.add_argument('--freeze_first', default=False, help='Start training with the pre-trained network frozen, then unfreeze', type=bool)
 parser.add_argument('--epochs', default=2, help='Number of training epochs to run for each batch', type=int)
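The hunk above exposes the three new knobs on the command line. A self-contained sketch of how they parse; the flag definitions match the patch, while the example values are hypothetical:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', default='adam', help='Optimizer to use')
parser.add_argument('--loss', default='logcosh', help='Loss function to use')
parser.add_argument('--truncation', default=0.7,
                    help='Generate images using truncation trick', type=float)

# e.g. python train_effnet.py --loss mse --truncation 0.5
args = parser.parse_args(['--loss', 'mse', '--truncation', '0.5'])
print(args.optimizer, args.loss, args.truncation)  # adam mse 0.5
```

Note that `--truncation` is declared with `type=float`, so the string from the shell is converted before use. (The pre-existing `type=bool` flags such as `--use_fp16` are a known argparse pitfall, since `bool('False')` is `True`, but that behavior predates this patch.)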
@@ -248,34 +251,35 @@ def finetune_effnet(model, save_path, model_res=1024, image_size=256, batch_size
     K.set_floatx('float16')
     K.set_epsilon(1e-4)
 
-model = get_effnet_model(args.model_path, model_res=args.model_res, depth=args.model_depth, size=args.model_size, activation=args.activation)
-
 tflib.init_tf()
+
+model = get_effnet_model(args.model_path, model_res=args.model_res, depth=args.model_depth, size=args.model_size, activation=args.activation, optimizer=args.optimizer, loss=args.loss)
+
 with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
     generator_network, discriminator_network, Gs_network = pickle.load(f)
 
 def load_Gs():
     return Gs_network
 
-K.get_session().run(tensorflow.global_variables_initializer())
+#K.get_session().run(tensorflow.global_variables_initializer())
 
 if args.freeze_first:
     model.layers[1].trainable = False
-    model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)
 
 model.summary()
 
 if args.freeze_first: # run a training iteration first while pretrained model is frozen, then unfreeze.
-    finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+    finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
     model.layers[1].trainable = True
-    model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
+    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)
     model.summary()
 
 if args.loop < 0:
     while True:
-        finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+        finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
 else:
     count = args.loop
     while count > 0:
-        finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size)
+        finetune_effnet(model, args.model_path, model_res=args.model_res, image_size=args.image_size, batch_size=args.batch_size, test_size=args.test_size, max_patience=args.max_patience, n_epochs=args.epochs, seed=args.seed, minibatch_size=args.minibatch_size, truncation=args.truncation)
         count -= 1
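Two things are worth calling out in the final hunk: model construction now happens after `tflib.init_tf()`, so the Keras model is built inside the TensorFlow session that StyleGAN initializes, and the `global_variables_initializer()` call is commented out, presumably so that freshly loaded pretrained weights are not re-initialized. For reference, the shape of the loop that each `finetune_effnet()` call runs is sketched below; the `make_batch` helper and the save-on-improvement detail are assumptions, not lifted from the patch.

```python
def training_loop(model, make_batch, X_test, W_test, save_path,
                  n_epochs=2, max_patience=2, minibatch_size=32):
    """Train on fresh synthetic batches until the test loss stalls."""
    best_loss = float('inf')
    patience = 0
    while patience <= max_patience:
        # make_batch() stands in for generate_dataset(): freshly generated
        # (image, dlatent) pairs from the StyleGAN generator.
        X_train, W_train = make_batch()
        model.fit(X_train, W_train, epochs=n_epochs, batch_size=minibatch_size)
        loss = model.evaluate(X_test, W_test, batch_size=minibatch_size)
        if loss < best_loss:   # improvement: reset patience, checkpoint
            best_loss = loss
            patience = 0
            model.save(save_path)
        else:                  # no improvement: burn one unit of patience
            patience += 1
    return best_loss
```

Because each round draws a brand-new training batch, patience here guards against the test loss plateauing rather than against overfitting a fixed training set.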