diff --git a/training/train.py b/training/train.py
index 70fd774..51de18d 100644
--- a/training/train.py
+++ b/training/train.py
@@ -99,7 +99,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
         pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
         pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
         print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} '
-              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params}')
+              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}')
 
         self.criterion = HingeLoss(margin=self.margin, device=device)
 
@@ -145,14 +145,6 @@ def on_train_epoch_end(self):
         if self.use_lora:
             self.__save_lora_weights()
 
-    def on_train_start(self):
-        for extractor in self.perceptual_model.extractor_list:
-            extractor.model.train()
-
-    def on_validation_start(self):
-        for extractor in self.perceptual_model.extractor_list:
-            extractor.model.eval()
-
     def on_validation_epoch_start(self):
         self.__reset_val_metrics()
 
@@ -163,7 +155,7 @@ def on_validation_epoch_end(self):
         self.log(f'val_acc_ckpt', score, logger=False)
         self.log(f'val_loss_ckpt', loss, logger=False)
 
-        # log for tensorboard
+
         self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch)
         self.logger.experiment.add_scalar(f'val_loss/', loss, epoch)
 
@@ -190,17 +182,14 @@ def __reset_val_metrics(self):
             v.reset()
 
     def __prep_lora_model(self):
-        for extractor in self.perceptual_model.extractor_list:
-            config = LoraConfig(
-                r=self.lora_r,
-                lora_alpha=self.lora_alpha,
-                lora_dropout=self.lora_dropout,
-                bias='none',
-                target_modules=['qkv']
-            )
-            extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()),
-                                             config).to(extractor.device)
-            extractor.model = extractor_model
+        config = LoraConfig(
+            r=self.lora_r,
+            lora_alpha=self.lora_alpha,
+            lora_dropout=self.lora_dropout,
+            bias='none',
+            target_modules=['qkv']
+        )
+        self.perceptual_model = get_peft_model(self.perceptual_model, config)
 
     def __prep_linear_model(self):
         for extractor in self.perceptual_model.extractor_list: