Skip to content

Commit

Permalink
Update initialization for resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
pbaylies committed Jun 4, 2019
1 parent e37faef commit 342ce04
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions train_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def get_resnet_model(save_path, model_res=1024, image_size=256, depth=2, size=0,
x = Reshape((model_scale, 512))(x) # train against all dlatent values
model = Model(inputs=inp,outputs=x)
model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.

return model

def finetune_resnet(model, save_path, model_res=1024, image_size=256, batch_size=10000, test_size=1000, n_epochs=10, max_patience=5, seed=0, minibatch_size=32):
Expand Down Expand Up @@ -221,17 +220,16 @@ def finetune_resnet(model, save_path, model_res=1024, image_size=256, batch_size
K.set_floatx('float16')
K.set_epsilon(1e-4)

tflib.init_tf()

model = get_resnet_model(args.model_path, model_res=args.model_res, depth=args.model_depth, size=args.model_size, activation=args.activation)

tflib.init_tf()
with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
generator_network, discriminator_network, Gs_network = pickle.load(f)

def load_Gs():
return Gs_network

K.get_session().run(tensorflow.global_variables_initializer())

if args.freeze_first:
model.layers[1].trainable = False
model.compile(loss='logcosh', metrics=[], optimizer='adam') # Adam optimizer, logcosh used for loss.
Expand Down

0 comments on commit 342ce04

Please sign in to comment.