From 1ecaf6e423c73e7204615421ca0a17dbcb2f46e8 Mon Sep 17 00:00:00 2001
From: Stephanie Fu
Date: Thu, 30 May 2024 21:44:30 +0000
Subject: [PATCH] fix model loading

---
 dreamsim/feature_extraction/vit_wrapper.py | 19 -------
 dreamsim/model.py                          | 16 ++----
 requirements.txt                           |  2 +-
 training/train.py                          | 66 ++++++++++++----------
 4 files changed, 42 insertions(+), 61 deletions(-)
 delete mode 100644 dreamsim/feature_extraction/vit_wrapper.py

diff --git a/dreamsim/feature_extraction/vit_wrapper.py b/dreamsim/feature_extraction/vit_wrapper.py
deleted file mode 100644
index 793ebe6..0000000
--- a/dreamsim/feature_extraction/vit_wrapper.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from transformers import PretrainedConfig
-from transformers import PreTrainedModel
-
-
-class ViTConfig(PretrainedConfig):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-
-class ViTModel(PreTrainedModel):
-    config_class = ViTConfig
-
-    def __init__(self, model, config):
-        super().__init__(config)
-        self.model = model
-        self.blocks = model.blocks
-
-    def forward(self, x):
-        return self.model(x)
diff --git a/dreamsim/model.py b/dreamsim/model.py
index c0f2e8b..e7d3364 100644
--- a/dreamsim/model.py
+++ b/dreamsim/model.py
@@ -8,7 +8,6 @@
 from .feature_extraction.extractor import ViTExtractor
 import yaml
 from peft import PeftModel, LoraConfig, get_peft_model
-from .feature_extraction.vit_wrapper import ViTConfig, ViTModel
 from .config import dreamsim_args, dreamsim_weights
 import os
 import zipfile
@@ -216,17 +215,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma
     model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",")
     ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir,
                                  normalize_embeds=normalize_embeds)
-    for extractor in ours_model.extractor_list:
-        lora_config = LoraConfig(**dreamsim_args['lora_config'])
-        model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config)
-        extractor.model = model
-    tag = "" if dreamsim_type == "ensemble" else "single_"
+    lora_config = LoraConfig(**dreamsim_args['lora_config'])
+    ours_model = get_peft_model(ours_model, lora_config)
+
+    tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}"
 
     if pretrained:
-        for extractor, name in zip(ours_model.extractor_list, model_list):
-            load_dir = os.path.join(cache_dir, f"{name}_{tag}lora")
-            extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
-            extractor.model.eval().requires_grad_(False)
+        load_dir = os.path.join(cache_dir, f"{tag}lora")
+        ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device)
 
     ours_model.eval().requires_grad_(False)
 
diff --git a/requirements.txt b/requirements.txt
index b3946bc..01e366b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ lpips
 numpy
 open-clip-torch
 pandas
-peft==0.1.0
+peft>=0.4.0
 Pillow
 pytorch-lightning
 PyYAML
diff --git a/training/train.py b/training/train.py
index 51de18d..2d6f747 100644
--- a/training/train.py
+++ b/training/train.py
@@ -11,7 +11,6 @@
 import torch
 from peft import get_peft_model, LoraConfig, PeftModel
 from dreamsim import PerceptualModel
-from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig
 import os
 import configargparse
 
@@ -25,6 +24,9 @@ def parse_args():
     parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)')
     parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs')
     parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints')
+    parser.add_argument('--save_mode', type=str, default="all", help='whether to save only LoRA adapter weights, '
+                                                                     'entire model, or both. Accepted '
+                                                                     'options: [adapter_only, entire_model, all]')
 
     ## Model options
     parser.add_argument('--model_type', type=str, default='dino_vitb16',
@@ -63,10 +65,11 @@ def parse_args():
 
 
 class LightningPerceptualModel(pl.LightningModule):
-    def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1,
+    def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16",
+                 hidden_size: int = 1,
                  lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16,
                  lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1,
-                 load_dir: str = "./models", device: str = "cuda",
+                 load_dir: str = "./models", device: str = "cuda", save_mode: str = "all",
                  **kwargs):
         super().__init__()
         self.save_hyperparameters()
@@ -83,12 +86,16 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
         self.lora_alpha = lora_alpha
         self.lora_dropout = lora_dropout
         self.train_data_len = train_data_len
+        self.save_mode = save_mode
+
+        self.__validate_save_mode()
 
         self.started = False
         self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)}
         self.__reset_val_metrics()
 
-        self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride,
+        self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type,
+                                                stride=self.stride,
                                                 hidden_size=self.hidden_size, lora=self.use_lora,
                                                 load_dir=load_dir, device=device)
         if self.use_lora:
@@ -99,7 +106,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
         pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
         pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
         print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} '
-              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}')
+              f'| % Trainable: {pytorch_total_trainable_params / pytorch_total_params * 100}')
 
         self.criterion = HingeLoss(margin=self.margin, device=device)
 
@@ -140,7 +147,8 @@ def on_train_epoch_start(self):
 
     def on_train_epoch_end(self):
         epoch = self.current_epoch + 1 if self.started else 0
-        self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch)
+        self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches,
+                                          epoch)
         self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch)
         if self.use_lora:
             self.__save_lora_weights()
@@ -172,10 +180,14 @@ def configure_optimizers(self):
         return [optimizer]
 
     def load_lora_weights(self, checkpoint_root, epoch_load):
-        for extractor in self.perceptual_model.extractor_list:
-            load_dir = os.path.join(checkpoint_root,
-                                    f'epoch_{epoch_load}_{extractor.model_type}')
-            extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device)
+        if self.save_mode in {'adapter_only', 'all'}:
+            load_dir = os.path.join(checkpoint_root, f'epoch_{epoch_load}')
+            logging.info(f'Loading adapter weights from {load_dir}')
+            self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, load_dir).to(self.device)
+        else:
+            logging.info(f'Loading entire model from {checkpoint_root}')
+            sd = torch.load(os.path.join(checkpoint_root, f'epoch={epoch_load:02d}.ckpt'))['state_dict']
+            self.load_state_dict(sd, strict=True)
 
     def __reset_val_metrics(self):
         for k, v in self.val_metrics.items():
@@ -199,17 +211,14 @@ def __prep_linear_model(self):
         self.perceptual_model.mlp.requires_grad_(True)
 
     def __save_lora_weights(self):
-        for extractor in self.perceptual_model.extractor_list:
-            save_dir = os.path.join(self.trainer.callbacks[-1].dirpath,
-                                    f'epoch_{self.trainer.current_epoch}_{extractor.model_type}')
-            extractor.model.save_pretrained(save_dir, safe_serialization=False)
-            adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin'))
-            new_adapters_weights = dict()
+        if self.save_mode != 'entire_model':
+            save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, f'epoch_{self.trainer.current_epoch}')
+            self.perceptual_model.save_pretrained(save_dir)
 
-            for k, v in adapters_weights.items():
-                new_k = 'base_model.model.' + k
-                new_adapters_weights[new_k] = v
-            torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin'))
+    def __validate_save_mode(self):
+        save_options = {'adapter_only', 'entire_model', 'all'}
+        assert self.save_mode in save_options, f'save_mode must be one of {save_options}, got {self.save_mode}'
+        logging.info(f'Using save mode: {self.save_mode}')
 
 
 def run(args, device):
@@ -234,17 +243,18 @@ def run(args, device):
     val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
 
     logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False)
+    checkpointer = ModelCheckpoint(monitor='val_loss_ckpt',
+                                   save_top_k=-1,
+                                   save_last=True,
+                                   filename='{epoch:02d}',
+                                   mode='min') if args.save_mode != 'adapter_only' else None
     trainer = Trainer(devices=1,
                       accelerator='gpu',
                       log_every_n_steps=10,
                       logger=logger,
                       max_epochs=args.epochs,
                      default_root_dir=exp_dir,
-                      callbacks=ModelCheckpoint(monitor='val_loss_ckpt',
-                                                save_top_k=-1,
-                                                save_last=True,
-                                                filename='{epoch:02d}',
-                                                mode='max'),
+                      callbacks=checkpointer,
                       num_sanity_val_steps=0,
                       )
     checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}')
@@ -269,9 +279,3 @@ def run(args, device):
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else "cpu"
     run(args, device)
-
-
-
-
-
-
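
Reviewer note: after this patch the whole PerceptualModel is wrapped in a single
PeftModel, so pretrained LoRA weights are restored from one "{tag}lora" directory
instead of one "{name}_{tag}lora" directory per backbone. A minimal usage sketch
of the patched loading path follows; it is an illustration, not part of the diff,
and assumes the package's documented dreamsim() entry point (which returns a
(model, preprocess) pair) plus hypothetical image paths:

    import torch
    from PIL import Image
    from dreamsim import dreamsim

    # Loads the ensemble model and its single LoRA adapter from cache_dir,
    # per the reworked dreamsim() above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = dreamsim(pretrained=True, device=device, cache_dir="./models")

    # img_a.png / img_b.png are placeholder paths for illustration.
    img_a = preprocess(Image.open("img_a.png")).to(device)
    img_b = preprocess(Image.open("img_b.png")).to(device)
    with torch.no_grad():
        distance = model(img_a, img_b)  # lower distance = more perceptually similar

On the training side, running training/train.py with --save_mode adapter_only
writes only the per-epoch PEFT adapter directories produced by
__save_lora_weights and skips the Lightning ModelCheckpoint entirely.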