Commit c324f56: remove unnecessary looping
stephanie-fu committed May 29, 2024 (1 parent: 4174cfc)

Showing 1 changed file with 10 additions and 21 deletions.
training/train.py (31 changes: 10 additions & 21 deletions)
@@ -99,7 +99,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
         pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
         pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
         print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} '
-              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params}')
+              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}')
 
         self.criterion = HingeLoss(margin=self.margin, device=device)
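The one-line fix in this hunk corrects a label mismatch: the printed value was a fraction in [0, 1] while its label reads '% Trainable'. A minimal, self-contained sketch of the corrected report; the report_params helper and its model argument are illustrative, not from the repository:

    import torch.nn as nn

    def report_params(model: nn.Module) -> None:
        # Count every parameter, then the subset that will receive gradients.
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Multiplying by 100 turns the fraction into an actual percentage,
        # matching the '% Trainable' label.
        print(f'Total params: {total} | Trainable params: {trainable} '
              f'| % Trainable: {trainable / total * 100:.2f}')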

@@ -145,14 +145,6 @@ def on_train_epoch_end(self):
         if self.use_lora:
             self.__save_lora_weights()
 
-    def on_train_start(self):
-        for extractor in self.perceptual_model.extractor_list:
-            extractor.model.train()
-
-    def on_validation_start(self):
-        for extractor in self.perceptual_model.extractor_list:
-            extractor.model.eval()
-
     def on_validation_epoch_start(self):
         self.__reset_val_metrics()
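Dropping on_train_start and on_validation_start is consistent with standard PyTorch and Lightning behavior: nn.Module.train() and .eval() recurse into every registered submodule, and Lightning already toggles the LightningModule's mode around the train and validation loops, so the manual loops were redundant. A toy sketch of that recursion; Wrapper and backbone are illustrative names, and the sketch assumes the extractor models are registered submodules:

    import torch.nn as nn

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(4, 4)  # stand-in for an extractor model

    wrapper = Wrapper()
    wrapper.eval()
    # eval() recursed into the child module, so no per-extractor loop is needed.
    assert wrapper.backbone.training is False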

@@ -163,7 +155,7 @@ def on_validation_epoch_end(self):
 
         self.log(f'val_acc_ckpt', score, logger=False)
         self.log(f'val_loss_ckpt', loss, logger=False)
-        # log for tensorboard
+
         self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch)
         self.logger.experiment.add_scalar(f'val_loss/', loss, epoch)
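For context on the logging split visible above (standard Lightning behavior, not part of this commit): self.log(..., logger=False) keeps a metric out of the logger while still exposing it to callbacks, and self.logger.experiment is the raw TensorBoard SummaryWriter, so add_scalar writes straight to TensorBoard. A hedged sketch of a checkpoint callback that could consume the val_acc_ckpt metric; the callback itself is an assumption, not shown in this diff:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Monitors the metric logged via self.log('val_acc_ckpt', ..., logger=False).
    checkpoint_cb = ModelCheckpoint(monitor='val_acc_ckpt', mode='max')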

@@ -190,17 +182,14 @@ def __reset_val_metrics(self):
             v.reset()
 
     def __prep_lora_model(self):
-        for extractor in self.perceptual_model.extractor_list:
-            config = LoraConfig(
-                r=self.lora_r,
-                lora_alpha=self.lora_alpha,
-                lora_dropout=self.lora_dropout,
-                bias='none',
-                target_modules=['qkv']
-            )
-            extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()),
-                                             config).to(extractor.device)
-            extractor.model = extractor_model
+        config = LoraConfig(
+            r=self.lora_r,
+            lora_alpha=self.lora_alpha,
+            lora_dropout=self.lora_dropout,
+            bias='none',
+            target_modules=['qkv']
+        )
+        self.perceptual_model = get_peft_model(self.perceptual_model, config)
 
     def __prep_linear_model(self):
         for extractor in self.perceptual_model.extractor_list:
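This simplification works because peft matches target_modules by module name anywhere inside the wrapped model: a single get_peft_model call injects LoRA adapters into every 'qkv' projection across all extractors, making the per-extractor ViTModel re-wrapping unnecessary. A toy, self-contained sketch of that name matching; Block and TwoExtractors are illustrative stand-ins, not the repository's classes:

    import torch.nn as nn
    from peft import LoraConfig, get_peft_model

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.qkv = nn.Linear(8, 24)  # attention projection named 'qkv'

        def forward(self, x):
            return self.qkv(x)

    class TwoExtractors(nn.Module):
        def __init__(self):
            super().__init__()
            self.extractor_a = Block()
            self.extractor_b = Block()

        def forward(self, x):
            return self.extractor_a(x) + self.extractor_b(x)

    config = LoraConfig(r=4, lora_alpha=16, lora_dropout=0.1,
                        bias='none', target_modules=['qkv'])
    # One call adapts the 'qkv' layers of both extractors at once.
    peft_model = get_peft_model(TwoExtractors(), config)
    peft_model.print_trainable_parameters()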
