Skip to content

Commit

Permalink
adjust learning rate
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Dec 24, 2021
1 parent 7e3ee00 commit 6060810
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 5 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def worker(config, absolute_save_dir, seed):
model = UNet(input_dim=data_opt.input_dim, num_classes=data_opt.num_classes, **config["Arch"])
if model_checkpoint:
logger.info(f"loading checkpoint from {model_checkpoint}")
model.load_state_dict(extract_model_state_dict(model_checkpoint), strict=True)
try:
model.load_state_dict(extract_model_state_dict(model_checkpoint), strict=True)
except RuntimeError as e:
# shape mismatch for network.
logger.warning(e)

trainer_name = config["Trainer"]["name"]
is_pretrain = ("pretrain" in trainer_name)
Expand Down
6 changes: 3 additions & 3 deletions opt/acdc.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
labeled_ratios: [ 1, 2, 4, 174, 174 ]
labeled_ratios: [ 1, 2, 4, 174, 174]
pre_max_epoch: 80
ft_max_epoch: 50
num_batches: 200
num_classes: 2
num_classes: 4
input_dim: 1
pre_lr: 0.0000005
ft_lr: 0.0000003
ft_lr: 0.0000001

0 comments on commit 6060810

Please sign in to comment.