Skip to content

Commit

Permalink
Add count parameter to weight averaging
Browse files Browse the repository at this point in the history
  • Loading branch information
pbaylies committed Jun 16, 2019
1 parent 2248f1a commit e9e0666
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def apply_swa_to_checkpoints(models):
parser.add_argument('results_dir', help='Directory with network checkpoints for weight averaging')
parser.add_argument('--filespec', default='network*.pkl', help='The files to average')
parser.add_argument('--output_model', default='network_avg.pkl', help='The averaged model to output')
parser.add_argument('--count', default=6, help='Average the last n checkpoints', type=int)

args, other_args = parser.parse_known_args()
swa_epochs = 6
swa_epochs = args.count
filepath = args.output_model
files = glob.glob(os.path.join(args.results_dir,args.filespec))
if (len(files)>swa_epochs):
Expand Down

0 comments on commit e9e0666

Please sign in to comment.