Skip to content

Commit

Permalink
Merge pull request mli0603#59 from mli0603/fix-resume
Browse files Browse the repository at this point in the history
fixed resuming with old model
  • Loading branch information
Max Zhaoshuo Li authored Mar 25, 2022
2 parents 6054799 + 3b13edc commit 1af3db7
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,26 @@ def main(args):
if len(missing) > 0:
print("Missing keys: ", ','.join(missing))
raise Exception("Missing keys.")
unexpected = [k for k in unexpected if 'running_mean' not in k and 'running_var' not in k] # skip bn params
if len(unexpected) > 0:
print("Unexpected keys: ", ','.join(unexpected))
unexpected_filtered = [k for k in unexpected if
'running_mean' not in k and 'running_var' not in k] # skip bn params
if len(unexpected_filtered) > 0:
print("Unexpected keys: ", ','.join(unexpected_filtered))
raise Exception("Unexpected keys.")
print("Pre-trained model successfully loaded.")

# if not ft/inference/eval, load states for optimizer, lr_scheduler, amp and prev best
if not (args.ft or args.inference or args.eval):
args.start_epoch = checkpoint['epoch'] + 1
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
prev_best = checkpoint['best_pred']
if args.apex:
amp.load_state_dict(checkpoint['amp'])
print("Pre-trained optimizer, lr scheduler and stats successfully loaded.")
if len(unexpected) > 0: # loaded checkpoint has bn parameters, legacy resume, skip loading
raise Exception("Resuming legacy model with BN parameters. Not possible due to BN param change. " +
"Do you want to finetune or inference? If so, check your arguments.")
else:
args.start_epoch = checkpoint['epoch'] + 1
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
prev_best = checkpoint['best_pred']
if args.apex:
amp.load_state_dict(checkpoint['amp'])
print("Pre-trained optimizer, lr scheduler and stats successfully loaded.")

# inference
if args.inference:
Expand Down

0 comments on commit 1af3db7

Please sign in to comment.