Skip to content

Commit

Permalink
fix model loading
Browse files: browse the repository at this point in the history
  • Loading branch information
stephanie-fu committed May 30, 2024
1 parent acccdb5 commit 1ecaf6e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 61 deletions.
19 changes: 0 additions & 19 deletions dreamsim/feature_extraction/vit_wrapper.py

This file was deleted.

16 changes: 6 additions & 10 deletions dreamsim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .feature_extraction.extractor import ViTExtractor
import yaml
from peft import PeftModel, LoraConfig, get_peft_model
from .feature_extraction.vit_wrapper import ViTConfig, ViTModel
from .config import dreamsim_args, dreamsim_weights
import os
import zipfile
Expand Down Expand Up @@ -216,17 +215,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma
model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",")
ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir,
normalize_embeds=normalize_embeds)
for extractor in ours_model.extractor_list:
lora_config = LoraConfig(**dreamsim_args['lora_config'])
model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config)
extractor.model = model

tag = "" if dreamsim_type == "ensemble" else "single_"
lora_config = LoraConfig(**dreamsim_args['lora_config'])
ours_model = get_peft_model(ours_model, lora_config)

tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}"
if pretrained:
for extractor, name in zip(ours_model.extractor_list, model_list):
load_dir = os.path.join(cache_dir, f"{name}_{tag}lora")
extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
extractor.model.eval().requires_grad_(False)
load_dir = os.path.join(cache_dir, f"{tag}lora")
ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device)

ours_model.eval().requires_grad_(False)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ lpips
numpy
open-clip-torch
pandas
peft==0.1.0
peft>=0.4.0
Pillow
pytorch-lightning
PyYAML
Expand Down
66 changes: 35 additions & 31 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
from peft import get_peft_model, LoraConfig, PeftModel
from dreamsim import PerceptualModel
from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig
import os
import configargparse

Expand All @@ -25,6 +24,9 @@ def parse_args():
parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)')
parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs')
parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints')
parser.add_argument('--save_mode', type=str, default="all", help='whether to save only LoRA adapter weights, '
'entire model, or both. Accepted '
'options: [adapter_only, entire_model, all]')

## Model options
parser.add_argument('--model_type', type=str, default='dino_vitb16',
Expand Down Expand Up @@ -63,10 +65,11 @@ def parse_args():


class LightningPerceptualModel(pl.LightningModule):
def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1,
def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16",
hidden_size: int = 1,
lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16,
lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1,
load_dir: str = "./models", device: str = "cuda",
load_dir: str = "./models", device: str = "cuda", save_mode: str = "all",
**kwargs):
super().__init__()
self.save_hyperparameters()
Expand All @@ -83,12 +86,16 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.train_data_len = train_data_len
self.save_mode = save_mode

self.__validate_save_mode()

self.started = False
self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)}
self.__reset_val_metrics()

self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride,
self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type,
stride=self.stride,
hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir,
device=device)
if self.use_lora:
Expand All @@ -99,7 +106,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} '
f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}')
f'| % Trainable: {pytorch_total_trainable_params / pytorch_total_params * 100}')

self.criterion = HingeLoss(margin=self.margin, device=device)

Expand Down Expand Up @@ -140,7 +147,8 @@ def on_train_epoch_start(self):

def on_train_epoch_end(self):
epoch = self.current_epoch + 1 if self.started else 0
self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch)
self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches,
epoch)
self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch)
if self.use_lora:
self.__save_lora_weights()
Expand Down Expand Up @@ -172,10 +180,14 @@ def configure_optimizers(self):
return [optimizer]

def load_lora_weights(self, checkpoint_root, epoch_load):
for extractor in self.perceptual_model.extractor_list:
load_dir = os.path.join(checkpoint_root,
f'epoch_{epoch_load}_{extractor.model_type}')
extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device)
if self.save_mode in {'adapter_only', 'all'}:
load_dir = os.path.join(checkpoint_root, f'epoch_{epoch_load}')
logging.info(f'Loading adapter weights from {load_dir}')
self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, load_dir).to(self.device)
else:
logging.info(f'Loading entire model from {checkpoint_root}')
sd = torch.load(os.path.join(checkpoint_root, f'epoch={epoch_load:02d}.ckpt'))['state_dict']
self.load_state_dict(sd, strict=True)

def __reset_val_metrics(self):
for k, v in self.val_metrics.items():
Expand All @@ -199,17 +211,14 @@ def __prep_linear_model(self):
self.perceptual_model.mlp.requires_grad_(True)

def __save_lora_weights(self):
for extractor in self.perceptual_model.extractor_list:
save_dir = os.path.join(self.trainer.callbacks[-1].dirpath,
f'epoch_{self.trainer.current_epoch}_{extractor.model_type}')
extractor.model.save_pretrained(save_dir, safe_serialization=False)
adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin'))
new_adapters_weights = dict()
if self.save_mode != 'entire_model':
save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, f'epoch_{self.trainer.current_epoch}')
self.perceptual_model.save_pretrained(save_dir)

for k, v in adapters_weights.items():
new_k = 'base_model.model.' + k
new_adapters_weights[new_k] = v
torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin'))
def __validate_save_mode(self):
    """Check that ``self.save_mode`` is a supported checkpoint-saving mode.

    Accepted values are ``'adapter_only'``, ``'entire_model'`` and ``'all'``.
    On success the selected mode is logged at INFO level.

    Raises:
        ValueError: if ``self.save_mode`` is not one of the accepted values.
    """
    save_options = {'adapter_only', 'entire_model', 'all'}
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, which would let an invalid mode through unnoticed.
    if self.save_mode not in save_options:
        # sorted() keeps the message stable; set iteration order is not.
        raise ValueError(
            f'save_mode must be one of {sorted(save_options)}, got {self.save_mode!r}'
        )
    logging.info(f'Using save mode: {self.save_mode}')


def run(args, device):
Expand All @@ -234,17 +243,18 @@ def run(args, device):
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False)
checkpointer = ModelCheckpoint(monitor='val_loss_ckpt',
save_top_k=-1,
save_last=True,
filename='{epoch:02d}',
mode='min') if args.save_mode != 'adapter_only' else None
trainer = Trainer(devices=1,
accelerator='gpu',
log_every_n_steps=10,
logger=logger,
max_epochs=args.epochs,
default_root_dir=exp_dir,
callbacks=ModelCheckpoint(monitor='val_loss_ckpt',
save_top_k=-1,
save_last=True,
filename='{epoch:02d}',
mode='max'),
callbacks=checkpointer,
num_sanity_val_steps=0,
)
checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}')
Expand All @@ -269,9 +279,3 @@ def run(args, device):
args = parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
run(args, device)






0 comments on commit 1ecaf6e

Please sign in to comment.