From 7cfa2a521c81cad5be2f504d7bf51a0af4b304f9 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Sun, 18 Feb 2024 19:50:04 +0000 Subject: [PATCH 01/15] add distillation and new backbones --- configs/distill_lora.yaml | 21 ++ configs/train_single_model_lora.yaml | 10 +- dataset/dataset.py | 4 +- dataset/download_dataset.sh | 4 +- dataset/nights | 1 + dreamsim/feature_extraction/extractor.py | 11 +- .../feature_extraction/load_synclr_as_dino.py | 16 + dreamsim/model.py | 9 +- training/distill.py | 355 ++++++++++++++++++ training/evaluate.py | 2 +- training/train.py | 19 +- util/precompute_sim.py | 69 ++++ 12 files changed, 503 insertions(+), 18 deletions(-) create mode 100644 configs/distill_lora.yaml create mode 120000 dataset/nights create mode 100644 dreamsim/feature_extraction/load_synclr_as_dino.py create mode 100644 training/distill.py create mode 100644 util/precompute_sim.py diff --git a/configs/distill_lora.yaml b/configs/distill_lora.yaml new file mode 100644 index 0000000..d8e145d --- /dev/null +++ b/configs/distill_lora.yaml @@ -0,0 +1,21 @@ +seed: 1234 +tag: lora_distill +log_dir: /home/fus/repos/dreamsim-dev/output/dev + +model_type: 'clip_vitb32' +feat_type: 'embedding' +stride: '32' +use_lora: True + +dataset_root: ./dataset/nights +num_workers: 4 + +lr: 0.0003 +weight_decay: 0.0 +batch_size: 32 +epochs: 15 +margin: 0.05 + +lora_r: 8 +lora_alpha: 16 +lora_dropout: 0 \ No newline at end of file diff --git a/configs/train_single_model_lora.yaml b/configs/train_single_model_lora.yaml index 5be8cc3..a44955c 100644 --- a/configs/train_single_model_lora.yaml +++ b/configs/train_single_model_lora.yaml @@ -1,9 +1,9 @@ seed: 1234 tag: lora_single -log_dir: ./output +log_dir: ./output/new_backbones -model_type: 'mae_vitb16' -feat_type: 'last_layer' +model_type: 'synclr_vitb16' +feat_type: 'cls' stride: '16' use_lora: True @@ -17,5 +17,5 @@ epochs: 8 margin: 0.05 lora_r: 16 -lora_alpha: 0.5 -lora_dropout: 0.3 \ No newline at end of file +lora_alpha: 16 +lora_dropout: 0.1 diff --git a/dataset/dataset.py b/dataset/dataset.py index c257902..efcf857 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -18,8 +18,8 @@ def __init__(self, root_dir: str, split: str = "train", load_size: int = 224, self.load_size = load_size self.interpolation = interpolation self.preprocess_fn = get_preprocess_fn(preprocess, self.load_size, self.interpolation) - - if self.split == "train" or self.split == "val": + + if self.split == "train" or self.split == "val" or self.split == "test": self.csv = self.csv[self.csv["split"] == split] elif split == 'test_imagenet': self.csv = self.csv[self.csv['split'] == 'test'] diff --git a/dataset/download_dataset.sh b/dataset/download_dataset.sh index 6562f1d..1c8c2db 100644 --- a/dataset/download_dataset.sh +++ b/dataset/download_dataset.sh @@ -1,6 +1,6 @@ #!/bin/bash -mkdir -p ./dataset -cd dataset +# mkdir -p ./dataset +cd /home/fus/data/ # Download NIGHTS dataset wget -O nights.zip https://data.csail.mit.edu/nights/nights.zip diff --git a/dataset/nights b/dataset/nights new file mode 120000 index 0000000..567f57c --- /dev/null +++ b/dataset/nights @@ -0,0 +1 @@ +/home/fus/data/nights \ No newline at end of file diff --git a/dreamsim/feature_extraction/extractor.py b/dreamsim/feature_extraction/extractor.py index c096782..9aa5afa 100644 --- a/dreamsim/feature_extraction/extractor.py +++ b/dreamsim/feature_extraction/extractor.py @@ -7,6 +7,7 @@ import os from .load_clip_as_dino import load_clip_as_dino from .load_open_clip_as_dino import load_open_clip_as_dino +from .load_synclr_as_dino import load_synclr_as_dino from .vision_transformer import DINOHead from .load_mae_as_vit import load_mae_as_vit @@ -62,7 +63,10 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module: :param load_dir: location of pretrained ViT checkpoints. :return: the model """ - if 'dino' in model_type: + if 'dinov2' in model_type: + torch.hub.set_dir(load_dir) + model = torch.hub.load('facebookresearch/dinov2', model_type) + elif 'dino' in model_type: torch.hub.set_dir(load_dir) model = torch.hub.load('facebookresearch/dino:main', model_type) if model_type == 'dino_vitb16': @@ -96,6 +100,11 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module: raise ValueError(f"Model {model_type} not supported") elif 'mae' in model_type: model = load_mae_as_vit(model_type, load_dir) + elif 'synclr' in model_type: + if model_type == 'synclr_vitb16': + model = load_synclr_as_dino(16, load_dir) + else: + raise ValueError(f"Model {model_type} not supported") else: raise ValueError(f"Model {model_type} not supported") return model diff --git a/dreamsim/feature_extraction/load_synclr_as_dino.py b/dreamsim/feature_extraction/load_synclr_as_dino.py new file mode 100644 index 0000000..85c6477 --- /dev/null +++ b/dreamsim/feature_extraction/load_synclr_as_dino.py @@ -0,0 +1,16 @@ +import torch +from .vision_transformer import vit_base, VisionTransformer +import os + + +def load_synclr_as_dino(patch_size, load_dir="./models", l14=False): + sd = torch.load(os.path.join(load_dir, f'synclr_vit_b_{patch_size}.pth'))['model'] + dino_vit = vit_base(patch_size=patch_size) + new_sd = dict() + + for k, v in sd.items(): + new_key = k[14:] # strip "module.visual" from key + new_sd[new_key] = v + + dino_vit.load_state_dict(new_sd) + return dino_vit diff --git a/dreamsim/model.py b/dreamsim/model.py index b84f872..5c8d7d7 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -129,6 +129,8 @@ def _get_mean(self, model_type): return (0.48145466, 0.4578275, 0.40821073) elif "mae" in model_type: return (0.485, 0.456, 0.406) + elif "synclr" in model_type: + return (0.485, 0.456, 0.406) def _get_std(self, model_type): if "dino" in model_type: @@ -139,6 +141,8 @@ def _get_std(self, model_type): return (0.26862954, 0.26130258, 0.27577711) elif "mae" in model_type: return (0.229, 0.224, 0.225) + elif "synclr" in model_type: + return (0.229, 0.224, 0.225) class MLP(torch.nn.Module): @@ -252,6 +256,7 @@ def normalize_embedding(embed): 'dino_vits16': {'cls': 384, 'last_layer': 384}, 'dino_vitb8': {'cls': 768, 'last_layer': 768}, 'dino_vitb16': {'cls': 768, 'last_layer': 768}, + 'dinov2_vitb14': {'cls': 768, 'last_layer': 768}, 'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512}, 'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, @@ -260,6 +265,6 @@ def normalize_embedding(embed): 'mae_vith14': {'cls': 1280, 'last_layer': 1280}, 'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768}, - 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768} + 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, + 'synclr_vitb16': {'cls': 768, 'last_layer': 768}, } - diff --git a/training/distill.py b/training/distill.py new file mode 100644 index 0000000..2d0656a --- /dev/null +++ b/training/distill.py @@ -0,0 +1,355 @@ +import logging +import yaml +import pytorch_lightning as pl +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint +from util.train_utils import Mean, HingeLoss, seed_worker +from util.utils import get_preprocess +from dataset.dataset import TwoAFCDataset +from torch.utils.data import DataLoader +import torch +from peft import get_peft_model, LoraConfig, PeftModel +from dreamsim import PerceptualModel +from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig +import os +import configargparse +from tqdm import tqdm +import pickle as pkl +import torch.nn as nn +from sklearn.decomposition import PCA + +torch.autograd.set_detect_anomaly(True) +def parse_args(): + parser = configargparse.ArgumentParser() + parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') + + ## Run options + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)') + parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs') + parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints') + + ## Model options + parser.add_argument('--model_type', type=str, default='dino_vitb16', + help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' + 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' + 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' + 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') + parser.add_argument('--feat_type', type=str, default='cls', + help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' + 'comma-separated list of features (same length as model_type). Accepted feature types: ' + '[cls, embedding, last_layer].') + parser.add_argument('--stride', type=str, default='16', + help='Stride of first convolution layer the model (should match patch size). If finetuning' + 'an ensemble, pass a comma-separated list (same length as model_type).') + parser.add_argument('--use_lora', type=bool, default=False, + help='Whether to train with LoRA finetuning [True] or with an MLP head [False].') + parser.add_argument('--hidden_size', type=int, default=1, help='Size of the MLP hidden layer.') + + ## Dataset options + parser.add_argument('--dataset_root', type=str, default="./dataset/nights", help='path to training dataset.') + parser.add_argument('--num_workers', type=int, default=4) + + ## Training options + parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training.') + parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay for training.') + parser.add_argument('--batch_size', type=int, default=4, help='Dataset batch size.') + parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.') + parser.add_argument('--margin', default=0.01, type=float, help='Margin for hinge loss') + + ## LoRA-specific options + parser.add_argument('--lora_r', type=int, default=8, help='LoRA attention dimension') + parser.add_argument('--lora_alpha', type=float, default=0.1, help='Alpha for attention scaling') + parser.add_argument('--lora_dropout', type=float, default=0.1, help='Dropout probability for LoRA layers') + + return parser.parse_args() + + +class LightningPerceptualModel(pl.LightningModule): + def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1, + lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16, + lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1, + load_dir: str = "./models", device: str = "cuda", + **kwargs): + super().__init__() + self.save_hyperparameters() + + self.feat_type = feat_type + self.model_type = model_type + self.stride = stride + self.hidden_size = hidden_size + self.lr = lr + self.use_lora = use_lora + self.margin = margin + self.weight_decay = weight_decay + self.lora_r = lora_r + self.lora_alpha = lora_alpha + self.lora_dropout = lora_dropout + self.train_data_len = train_data_len + + self.started = False + self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)} + self.__reset_val_metrics() + + self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride, + hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir, + device=device) + if self.use_lora: + self.__prep_lora_model() + else: + self.__prep_linear_model() + + pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) + pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) + print(pytorch_total_params) + print(pytorch_total_trainable_params) + + self.criterion = nn.L1Loss() + self.teacher = PerceptualModel(feat_type='cls,embedding,embedding', + model_type='dino_vitb16,clip_vitb16,open_clip_vitb16', + stride='16,16,16', + hidden_size=self.hidden_size, lora=False, load_dir=load_dir, + device=device) + + self.epoch_loss_train = 0.0 + self.train_num_correct = 0.0 + + self.automatic_optimization = False + + with open('precomputed_sims.pkl', 'rb') as f: + self.sims = pkl.load(f) + + # with open('precomputed_embeds.pkl', 'rb') as f: + # self.pca = pkl.load(f) + + def forward(self, img_ref, img_0, img_1): + _, embed_0, dist_0 = self.perceptual_model(img_ref, img_0) + embed_ref, embed_1, dist_1 = self.perceptual_model(img_ref, img_1) + return embed_ref, embed_0, embed_1, dist_0, dist_1 + + def training_step(self, batch, batch_idx): + img_ref, img_0, img_1, _, idx = batch + embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) + + # with torch.no_grad(): + # target_embed_ref = self.teacher.embed(img_ref) + # target_embed_0 = self.teacher.embed(img_0) + # target_embed_1 = self.teacher.embed(img_1) + # + # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device).float() + # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device).float() + # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device).float() + + target_0 = [self.sims['train'][i.item()][0] for i in idx] + target_dist_0 = torch.tensor(target_0, device=img_ref.device) + + target_1 = [self.sims['train'][i.item()][1] for i in idx] + target_dist_1 = torch.tensor(target_1, device=img_ref.device) + + opt = self.optimizers() + opt.zero_grad() + loss_0 = self.criterion(dist_0, target_dist_0) #/ target_dist_0.shape[0] + # loss_ref = self.criterion(embed_ref, target_embed_ref).mean() + # self.manual_backward(loss_0) + + loss_1 = self.criterion(dist_1, target_dist_1) #/ target_dist_1.shape[0] + # loss_0 = self.criterion(embed_0, target_embed_0).mean().float() + # self.manual_backward(loss_1) + + # loss_1 = self.criterion(embed_1, target_embed_1).mean().float() + # self.manual_backward(loss_1) + loss = (loss_0 + loss_1).mean() + self.manual_backward(loss) + opt.step() + + target = torch.lt(target_dist_1, target_dist_0) + decisions = torch.lt(dist_1, dist_0) + + # self.epoch_loss_train += loss_ref + self.epoch_loss_train += loss_0 + self.epoch_loss_train += loss_1 + self.train_num_correct += ((target >= 0.5) == decisions).sum() + return loss_0 + loss_1 + + def validation_step(self, batch, batch_idx): + img_ref, img_0, img_1, _, idx = batch + embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) + + # with torch.no_grad(): + # target_embed_ref = self.teacher.embed(img_ref) + # target_embed_0 = self.teacher.embed(img_0) + # target_embed_1 = self.teacher.embed(img_1) + # + # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device) + # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device) + # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device) + + target_0 = [self.sims['val'][i.item()][0] for i in idx] + target_1 = [self.sims['val'][i.item()][1] for i in idx] + + target_dist_0 = torch.tensor(target_0, device=img_ref.device) + target_dist_1 = torch.tensor(target_1, device=img_ref.device) + + target = torch.lt(target_dist_1, target_dist_0) + decisions = torch.lt(dist_1, dist_0) + + loss = self.criterion(dist_0, target_dist_0) + loss += self.criterion(dist_1, target_dist_1) + # loss = self.criterion(embed_ref, target_embed_ref).float() + # loss += self.criterion(embed_0, target_embed_0).float() + # loss += self.criterion(embed_1, target_embed_1).float() + loss = loss.mean() + + val_num_correct = ((target >= 0.5) == decisions).sum() + self.val_metrics['loss'].update(loss, target.shape[0]) + self.val_metrics['score'].update(val_num_correct, target.shape[0]) + return loss + + def on_train_epoch_start(self): + self.epoch_loss_train = 0.0 + self.train_num_correct = 0.0 + self.started = True + + def on_train_epoch_end(self): + epoch = self.current_epoch + 1 if self.started else 0 + self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch) + self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch) + if self.use_lora: + self.__save_lora_weights() + + def on_train_start(self): + for extractor in self.perceptual_model.extractor_list: + extractor.model.train() + + def on_validation_start(self): + for extractor in self.perceptual_model.extractor_list: + extractor.model.eval() + + def on_validation_epoch_start(self): + self.__reset_val_metrics() + + def on_validation_epoch_end(self): + epoch = self.current_epoch + 1 if self.started else 0 + score = self.val_metrics['score'].compute() + loss = self.val_metrics['loss'].compute() + + self.log(f'val_acc_ckpt', score, logger=False) + self.log(f'val_loss_ckpt', loss, logger=False) + # log for tensorboard + self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch) + self.logger.experiment.add_scalar(f'val_loss/', loss, epoch) + + return score + + def configure_optimizers(self): + params = list(self.perceptual_model.parameters()) + for extractor in self.perceptual_model.extractor_list: + params += list(extractor.model.parameters()) + for extractor, feat_type in zip(self.perceptual_model.extractor_list, self.perceptual_model.feat_type_list): + if feat_type == 'embedding': + params += [extractor.proj] + optimizer = torch.optim.Adam(params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) + return [optimizer] + + def load_lora_weights(self, checkpoint_root, epoch_load): + for extractor in self.perceptual_model.extractor_list: + load_dir = os.path.join(checkpoint_root, + f'epoch_{epoch_load}_{extractor.model_type}') + extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device) + + def __reset_val_metrics(self): + for k, v in self.val_metrics.items(): + v.reset() + + def __prep_lora_model(self): + for extractor in self.perceptual_model.extractor_list: + config = LoraConfig( + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout, + bias='none', + target_modules=['qkv'] + ) + extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()), + config).to(extractor.device) + extractor.model = extractor_model + + def __prep_linear_model(self): + for extractor in self.perceptual_model.extractor_list: + extractor.model.requires_grad_(False) + if self.feat_type == "embedding": + extractor.proj.requires_grad_(False) + self.perceptual_model.mlp.requires_grad_(True) + + def __save_lora_weights(self): + for extractor in self.perceptual_model.extractor_list: + save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, + f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') + extractor.model.save_pretrained(save_dir, safe_serialization=False) + adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin')) + new_adapters_weights = dict() + + for k, v in adapters_weights.items(): + new_k = 'base_model.model.' + k + new_adapters_weights[new_k] = v + torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin')) + + +def run(args, device): + tag = args.tag if len(args.tag) > 0 else "" + training_method = "lora" if args.use_lora else "mlp" + exp_dir = os.path.join(args.log_dir, + f'{tag}_{str(args.model_type)}_{str(args.feat_type)}_{str(training_method)}_' + + f'lr_{str(args.lr)}_batchsize_{str(args.batch_size)}_wd_{str(args.weight_decay)}' + f'_hiddensize_{str(args.hidden_size)}_margin_{str(args.margin)}' + ) + if args.use_lora: + exp_dir += f'_lorar_{str(args.lora_r)}_loraalpha_{str(args.lora_alpha)}_loradropout_{str(args.lora_dropout)}' + + seed_everything(args.seed) + g = torch.Generator() + g.manual_seed(args.seed) + + train_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="train", preprocess=get_preprocess(args.model_type)) + val_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="val", preprocess=get_preprocess(args.model_type)) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, + worker_init_fn=seed_worker, generator=g) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + + logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) + trainer = Trainer(devices=1, + accelerator='gpu', + log_every_n_steps=10, + logger=logger, + max_epochs=args.epochs, + default_root_dir=exp_dir, + callbacks=ModelCheckpoint(monitor='val_loss_ckpt', + save_top_k=-1, + save_last=True, + filename='{epoch:02d}', + mode='max'), + num_sanity_val_steps=0, + ) + checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}') + os.makedirs(checkpoint_root, exist_ok=True) + with open(os.path.join(checkpoint_root, 'config.yaml'), 'w') as f: + yaml.dump(args, f) + + logging.basicConfig(filename=os.path.join(checkpoint_root, 'exp.log'), level=logging.INFO, force=True) + logging.info("Arguments: ", vars(args)) + + model = LightningPerceptualModel(device=device, train_data_len=len(train_dataset), **vars(args)) + + logging.info("Validating before training") + trainer.validate(model, val_loader) + logging.info("Training") + trainer.fit(model, train_loader, val_loader) + + print("Done :)") + + +if __name__ == '__main__': + args = parse_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + run(args, device) diff --git a/training/evaluate.py b/training/evaluate.py index b58502b..b6805e4 100644 --- a/training/evaluate.py +++ b/training/evaluate.py @@ -38,7 +38,7 @@ def parse_args(): help='Which ViT model to evaluate. To evaluate an ensemble of models, pass a comma-separated' 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') + 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]') parser.add_argument('--baseline_feat_type', type=str, help='What type of feature to extract from the model. If evaluating an ensemble, pass a ' 'comma-separated list of features (same length as model_type). Accepted feature types: ' diff --git a/training/train.py b/training/train.py index 1686b3b..1800501 100644 --- a/training/train.py +++ b/training/train.py @@ -16,9 +16,9 @@ import configargparse from tqdm import tqdm -log = logging.getLogger("lightning.pytorch") -log.propagate = False -log.setLevel(logging.INFO) +# log = logging.getLogger("lightning.pytorch") +# log.propagate = False +# log.setLevel(logging.INFO) def parse_args(): @@ -36,7 +36,7 @@ def parse_args(): help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') + 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]') parser.add_argument('--feat_type', type=str, default='cls', help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' 'comma-separated list of features (same length as model_type). Accepted feature types: ' @@ -101,6 +101,11 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri else: self.__prep_linear_model() + pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) + pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) + print(pytorch_total_params) + print(pytorch_total_trainable_params) + self.criterion = HingeLoss(margin=self.margin, device=device) self.epoch_loss_train = 0.0 @@ -145,6 +150,10 @@ def on_train_epoch_end(self): if self.use_lora: self.__save_lora_weights() + def on_train_start(self): + for extractor in self.perceptual_model.extractor_list: + extractor.model.train() + def on_validation_start(self): for extractor in self.perceptual_model.extractor_list: extractor.model.eval() @@ -209,7 +218,7 @@ def __save_lora_weights(self): for extractor in self.perceptual_model.extractor_list: save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') - extractor.model.save_pretrained(save_dir) + extractor.model.save_pretrained(save_dir, safe_serialization=False) adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin')) new_adapters_weights = dict() diff --git a/util/precompute_sim.py b/util/precompute_sim.py new file mode 100644 index 0000000..9e162eb --- /dev/null +++ b/util/precompute_sim.py @@ -0,0 +1,69 @@ +from PIL import Image +from lightning_fabric import seed_everything +from torch.utils.data import DataLoader + +from dataset.dataset import TwoAFCDataset +from dreamsim import dreamsim +from torchvision import transforms +import torch +import os +from tqdm import tqdm + +from util.train_utils import seed_worker +from util.utils import get_preprocess +import numpy as np +import pickle as pkl +from sklearn.decomposition import PCA + +seed = 1234 +dataset_root = './dataset/nights' +model_type = 'dino_vitb16,clip_vitb16,open_clip_vitb16' +num_workers = 8 +batch_size = 32 + +seed_everything(seed) +g = torch.Generator() +g.manual_seed(seed) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model, preprocess = dreamsim(pretrained=True, device=device, cache_dir='./models_new') + +train_dataset = TwoAFCDataset(root_dir=dataset_root, split="train", preprocess=get_preprocess(model_type)) +val_dataset = TwoAFCDataset(root_dir=dataset_root, split="val", preprocess=get_preprocess(model_type)) +test_dataset = TwoAFCDataset(root_dir=dataset_root, split="test", preprocess=get_preprocess(model_type)) +train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, + worker_init_fn=seed_worker, generator=g) +val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) +test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) + +data = {'train': {}, 'val': {}, 'test': {}} +all_embeds = [] + +with torch.no_grad(): + for split, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]: + for img_ref, img_left, img_right, p, id in tqdm(loader): + img_ref = img_ref.to(device) + img_left = img_left.to(device) + img_right = img_right.to(device) + + embed_ref, embed_0, d0 = model(img_ref, img_left) + _, embed_1, d1 = model(img_ref, img_right) + # + # if split == 'train': + # all_embeds.append(embed_ref) + # all_embeds.append(embed_0) + # all_embeds.append(embed_1) + + for i in range(len(id)): + curr_id = id[i].item() + data[split][curr_id] = [d0[i].item(), d1[i].item()] + +# all_embeds = torch.cat(all_embeds).cpu() +# principal = PCA(n_components=512) +# principal.fit(all_embeds) + +with open('precomputed_sims.pkl', 'wb') as f: + pkl.dump(data, f) + +# with open('precomputed_embeds.pkl', 'wb') as f: +# pkl.dump(principal, f) \ No newline at end of file From bdbfeb03f122e9202611feef25f5b4a05f549530 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Tue, 28 May 2024 20:28:12 +0000 Subject: [PATCH 02/15] update .gitignore --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 68bc17f..ed570f8 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ +output/ +models/ +dataset/nights From 4174cfc80e31472c8cc3288390b1a000f5b2523b Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Wed, 29 May 2024 00:51:27 +0000 Subject: [PATCH 03/15] fix zero trainable parameter bug --- dreamsim/feature_extraction/extractor.py | 3 ++- dreamsim/model.py | 25 ++++++++++++------------ training/train.py | 9 ++------- training/train.sh | 4 ++++ util/constants.py | 5 +++++ 5 files changed, 25 insertions(+), 21 deletions(-) create mode 100644 training/train.sh create mode 100644 util/constants.py diff --git a/dreamsim/feature_extraction/extractor.py b/dreamsim/feature_extraction/extractor.py index 9aa5afa..15900f9 100644 --- a/dreamsim/feature_extraction/extractor.py +++ b/dreamsim/feature_extraction/extractor.py @@ -16,7 +16,7 @@ """ -class ViTExtractor: +class ViTExtractor(nn.Module): """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. We use the following notation in the documentation of the module's methods: @@ -38,6 +38,7 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, load_dir: st :param stride: stride of first convolution layer. small stride -> higher resolution. :param load_dir: location of pretrained ViT checkpoints. """ + super(ViTExtractor, self).__init__() self.model_type = model_type self.device = device self.model = ViTExtractor.create_model(model_type, load_dir) diff --git a/dreamsim/model.py b/dreamsim/model.py index 5c8d7d7..c0f2e8b 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -1,7 +1,10 @@ import torch import torch.nn.functional as F +from torch import nn from torchvision import transforms import os + +from util.constants import * from .feature_extraction.extractor import ViTExtractor import yaml from peft import PeftModel, LoraConfig, get_peft_model @@ -41,7 +44,7 @@ def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stri self.stride_list = [int(x) for x in stride.split(',')] self._validate_args() self.extract_feats_list = [] - self.extractor_list = [] + self.extractor_list = nn.ModuleList() self.embed_size = 0 self.hidden_size = hidden_size self.baseline = baseline @@ -122,27 +125,23 @@ def _preprocess(self, img, model_type): def _get_mean(self, model_type): if "dino" in model_type: - return (0.485, 0.456, 0.406) + return IMAGENET_DEFAULT_MEAN elif "open_clip" in model_type: - return (0.48145466, 0.4578275, 0.40821073) + return OPENAI_CLIP_MEAN elif "clip" in model_type: - return (0.48145466, 0.4578275, 0.40821073) + return OPENAI_CLIP_MEAN elif "mae" in model_type: - return (0.485, 0.456, 0.406) - elif "synclr" in model_type: - return (0.485, 0.456, 0.406) + return IMAGENET_DEFAULT_MEAN def _get_std(self, model_type): if "dino" in model_type: - return (0.229, 0.224, 0.225) + return IMAGENET_DEFAULT_STD elif "open_clip" in model_type: - return (0.26862954, 0.26130258, 0.27577711) + return OPENAI_CLIP_STD elif "clip" in model_type: - return (0.26862954, 0.26130258, 0.27577711) + return OPENAI_CLIP_STD elif "mae" in model_type: - return (0.229, 0.224, 0.225) - elif "synclr" in model_type: - return (0.229, 0.224, 0.225) + return IMAGENET_DEFAULT_STD class MLP(torch.nn.Module): diff --git a/training/train.py b/training/train.py index 1800501..70fd774 100644 --- a/training/train.py +++ b/training/train.py @@ -14,11 +14,6 @@ from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig import os import configargparse -from tqdm import tqdm - -# log = logging.getLogger("lightning.pytorch") -# log.propagate = False -# log.setLevel(logging.INFO) def parse_args(): @@ -103,8 +98,8 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) - print(pytorch_total_params) - print(pytorch_total_trainable_params) + print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} ' + f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params}') self.criterion = HingeLoss(margin=self.margin, device=device) diff --git a/training/train.sh b/training/train.sh new file mode 100644 index 0000000..133ab55 --- /dev/null +++ b/training/train.sh @@ -0,0 +1,4 @@ +python -m training.train --config configs/train_single_model_lora.yaml --model_type dino_vitb16 --feat_type 'cls' --stride '16' & +CUDA_VISIBLE_DEVICES=1 python -m training.train --config configs/train_single_model_lora.yaml --model_type clip_vitb32 --feat_type 'embedding' --stride '32' & +CUDA_VISIBLE_DEVICES=2 python -m training.train --config configs/train_single_model_lora.yaml --model_type open_clip_vitb32 --feat_type 'embedding' --stride '32' & +CUDA_VISIBLE_DEVICES=3 python -m training.train --config configs/train_ensemble_model_lora.yaml & \ No newline at end of file diff --git a/util/constants.py b/util/constants.py new file mode 100644 index 0000000..a9c6872 --- /dev/null +++ b/util/constants.py @@ -0,0 +1,5 @@ +# use timm names from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/constants.py +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) From c324f5627acd7e28c3cb2b9e1b77cdf1723b571f Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Wed, 29 May 2024 17:50:11 +0000 Subject: [PATCH 04/15] remove unnecessary looping --- training/train.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/training/train.py b/training/train.py index 70fd774..51de18d 100644 --- a/training/train.py +++ b/training/train.py @@ -99,7 +99,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} ' - f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params}') + f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}') self.criterion = HingeLoss(margin=self.margin, device=device) @@ -145,14 +145,6 @@ def on_train_epoch_end(self): if self.use_lora: self.__save_lora_weights() - def on_train_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.train() - - def on_validation_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.eval() - def on_validation_epoch_start(self): self.__reset_val_metrics() @@ -163,7 +155,7 @@ def on_validation_epoch_end(self): self.log(f'val_acc_ckpt', score, logger=False) self.log(f'val_loss_ckpt', loss, logger=False) - # log for tensorboard + self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch) self.logger.experiment.add_scalar(f'val_loss/', loss, epoch) @@ -190,17 +182,14 @@ def __reset_val_metrics(self): v.reset() def __prep_lora_model(self): - for extractor in self.perceptual_model.extractor_list: - config = LoraConfig( - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout, - bias='none', - target_modules=['qkv'] - ) - extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()), - config).to(extractor.device) - extractor.model = extractor_model + config = LoraConfig( + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout, + bias='none', + target_modules=['qkv'] + ) + self.perceptual_model = get_peft_model(self.perceptual_model, config) def __prep_linear_model(self): for extractor in self.perceptual_model.extractor_list: From 1ecaf6e423c73e7204615421ca0a17dbcb2f46e8 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Thu, 30 May 2024 21:44:30 +0000 Subject: [PATCH 05/15] fix model loading --- dreamsim/feature_extraction/vit_wrapper.py | 19 ------- dreamsim/model.py | 16 ++---- requirements.txt | 2 +- training/train.py | 66 ++++++++++++---------- 4 files changed, 42 insertions(+), 61 deletions(-) delete mode 100644 dreamsim/feature_extraction/vit_wrapper.py diff --git a/dreamsim/feature_extraction/vit_wrapper.py b/dreamsim/feature_extraction/vit_wrapper.py deleted file mode 100644 index 793ebe6..0000000 --- a/dreamsim/feature_extraction/vit_wrapper.py +++ /dev/null @@ -1,19 +0,0 @@ -from transformers import PretrainedConfig -from transformers import PreTrainedModel - - -class ViTConfig(PretrainedConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - -class ViTModel(PreTrainedModel): - config_class = ViTConfig - - def __init__(self, model, config): - super().__init__(config) - self.model = model - self.blocks = model.blocks - - def forward(self, x): - return self.model(x) diff --git a/dreamsim/model.py b/dreamsim/model.py index c0f2e8b..e7d3364 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -8,7 +8,6 @@ from .feature_extraction.extractor import ViTExtractor import yaml from peft import PeftModel, LoraConfig, get_peft_model -from .feature_extraction.vit_wrapper import ViTConfig, ViTModel from .config import dreamsim_args, dreamsim_weights import os import zipfile @@ -216,17 +215,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",") ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir, normalize_embeds=normalize_embeds) - for extractor in ours_model.extractor_list: - lora_config = LoraConfig(**dreamsim_args['lora_config']) - model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config) - extractor.model = model - tag = "" if dreamsim_type == "ensemble" else "single_" + lora_config = LoraConfig(**dreamsim_args['lora_config']) + ours_model = get_peft_model(ours_model, lora_config) + + tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}" if pretrained: - for extractor, name in zip(ours_model.extractor_list, model_list): - load_dir = os.path.join(cache_dir, f"{name}_{tag}lora") - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device) - extractor.model.eval().requires_grad_(False) + load_dir = os.path.join(cache_dir, f"{tag}lora") + ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device) ours_model.eval().requires_grad_(False) diff --git a/requirements.txt b/requirements.txt index b3946bc..01e366b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ lpips numpy open-clip-torch pandas -peft==0.1.0 +peft>=0.4.0 Pillow pytorch-lightning PyYAML diff --git a/training/train.py b/training/train.py index 51de18d..2d6f747 100644 --- a/training/train.py +++ b/training/train.py @@ -11,7 +11,6 @@ import torch from peft import get_peft_model, LoraConfig, PeftModel from dreamsim import PerceptualModel -from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig import os import configargparse @@ -25,6 +24,9 @@ def parse_args(): parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)') parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs') parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints') + parser.add_argument('--save_mode', type=str, default="all", help='whether to save only LoRA adapter weights, ' + 'entire model, or both. Accepted ' + 'options: [adapter_only, entire_model, all]') ## Model options parser.add_argument('--model_type', type=str, default='dino_vitb16', @@ -63,10 +65,11 @@ def parse_args(): class LightningPerceptualModel(pl.LightningModule): - def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1, + def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", + hidden_size: int = 1, lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16, lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1, - load_dir: str = "./models", device: str = "cuda", + load_dir: str = "./models", device: str = "cuda", save_mode: str = "all", **kwargs): super().__init__() self.save_hyperparameters() @@ -83,12 +86,16 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri self.lora_alpha = lora_alpha self.lora_dropout = lora_dropout self.train_data_len = train_data_len + self.save_mode = save_mode + + self.__validate_save_mode() self.started = False self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)} self.__reset_val_metrics() - self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride, + self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, + stride=self.stride, hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir, device=device) if self.use_lora: @@ -99,7 +106,7 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} ' - f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params * 100}') + f'| % Trainable: {pytorch_total_trainable_params / pytorch_total_params * 100}') self.criterion = HingeLoss(margin=self.margin, device=device) @@ -140,7 +147,8 @@ def on_train_epoch_start(self): def on_train_epoch_end(self): epoch = self.current_epoch + 1 if self.started else 0 - self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch) + self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, + epoch) self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch) if self.use_lora: self.__save_lora_weights() @@ -172,10 +180,14 @@ def configure_optimizers(self): return [optimizer] def load_lora_weights(self, checkpoint_root, epoch_load): - for extractor in self.perceptual_model.extractor_list: - load_dir = os.path.join(checkpoint_root, - f'epoch_{epoch_load}_{extractor.model_type}') - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device) + if self.save_mode in {'adapter_only', 'all'}: + load_dir = os.path.join(checkpoint_root, f'epoch_{epoch_load}') + logging.info(f'Loading adapter weights from {load_dir}') + self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, load_dir).to(self.device) + else: + logging.info(f'Loading entire model from {checkpoint_root}') + sd = torch.load(os.path.join(checkpoint_root, f'epoch={epoch_load:02d}.ckpt'))['state_dict'] + self.load_state_dict(sd, strict=True) def __reset_val_metrics(self): for k, v in self.val_metrics.items(): @@ -199,17 +211,14 @@ def __prep_linear_model(self): self.perceptual_model.mlp.requires_grad_(True) def __save_lora_weights(self): - for extractor in self.perceptual_model.extractor_list: - save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, - f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') - extractor.model.save_pretrained(save_dir, safe_serialization=False) - adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin')) - new_adapters_weights = dict() + if self.save_mode != 'entire_model': + save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, f'epoch_{self.trainer.current_epoch}') + self.perceptual_model.save_pretrained(save_dir) - for k, v in adapters_weights.items(): - new_k = 'base_model.model.' + k - new_adapters_weights[new_k] = v - torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin')) + def __validate_save_mode(self): + save_options = {'adapter_only', 'entire_model', 'all'} + assert self.save_mode in save_options, f'save_mode must be one of {save_options}, got {self.save_mode}' + logging.info(f'Using save mode: {self.save_mode}') def run(args, device): @@ -234,17 +243,18 @@ def run(args, device): val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) + checkpointer = ModelCheckpoint(monitor='val_loss_ckpt', + save_top_k=-1, + save_last=True, + filename='{epoch:02d}', + mode='min') if args.save_mode != 'adapter_only' else None trainer = Trainer(devices=1, accelerator='gpu', log_every_n_steps=10, logger=logger, max_epochs=args.epochs, default_root_dir=exp_dir, - callbacks=ModelCheckpoint(monitor='val_loss_ckpt', - save_top_k=-1, - save_last=True, - filename='{epoch:02d}', - mode='max'), + callbacks=checkpointer, num_sanity_val_steps=0, ) checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}') @@ -269,9 +279,3 @@ def run(args, device): args = parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" run(args, device) - - - - - - From 813c978eb7185ae36d3c9bace1099440d3abc61a Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Thu, 30 May 2024 14:49:27 -0700 Subject: [PATCH 06/15] Delete training/train.sh --- training/train.sh | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 training/train.sh diff --git a/training/train.sh b/training/train.sh deleted file mode 100644 index 133ab55..0000000 --- a/training/train.sh +++ /dev/null @@ -1,4 +0,0 @@ -python -m training.train --config configs/train_single_model_lora.yaml --model_type dino_vitb16 --feat_type 'cls' --stride '16' & -CUDA_VISIBLE_DEVICES=1 python -m training.train --config configs/train_single_model_lora.yaml --model_type clip_vitb32 --feat_type 'embedding' --stride '32' & -CUDA_VISIBLE_DEVICES=2 python -m training.train --config configs/train_single_model_lora.yaml --model_type open_clip_vitb32 --feat_type 'embedding' --stride '32' & -CUDA_VISIBLE_DEVICES=3 python -m training.train --config configs/train_ensemble_model_lora.yaml & \ No newline at end of file From 5e6b4a8ba0d2a18dd6a6ade55cec1a22b1f0c6d0 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Thu, 30 May 2024 14:49:35 -0700 Subject: [PATCH 07/15] Delete training/evaluate.py --- training/evaluate.py | 211 ------------------------------------------- 1 file changed, 211 deletions(-) delete mode 100644 training/evaluate.py diff --git a/training/evaluate.py b/training/evaluate.py deleted file mode 100644 index b6805e4..0000000 --- a/training/evaluate.py +++ /dev/null @@ -1,211 +0,0 @@ -from pytorch_lightning import seed_everything -import torch -from dataset.dataset import TwoAFCDataset -from util.utils import get_preprocess -from torch.utils.data import DataLoader -import os -import yaml -import logging -from train import LightningPerceptualModel -from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio -from DISTS_pytorch import DISTS -from dreamsim import PerceptualModel -from tqdm import tqdm -import pickle -import configargparse -from dreamsim import dreamsim - -log = logging.getLogger("lightning.pytorch") -log.propagate = False -log.setLevel(logging.ERROR) - - -def parse_args(): - parser = configargparse.ArgumentParser() - parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') - - ## Run options - parser.add_argument('--seed', type=int, default=1234) - - ## Checkpoint evaluation options - parser.add_argument('--eval_root', type=str, help="Path to experiment directory containing a checkpoint to " - "evaluate and the experiment config.yaml.") - parser.add_argument('--checkpoint_epoch', type=int, help='Epoch number of the checkpoint to evaluate.') - parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints.') - - ## Baseline evaluation options - parser.add_argument('--baseline_model', type=str, - help='Which ViT model to evaluate. To evaluate an ensemble of models, pass a comma-separated' - 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' - 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]') - parser.add_argument('--baseline_feat_type', type=str, - help='What type of feature to extract from the model. If evaluating an ensemble, pass a ' - 'comma-separated list of features (same length as model_type). Accepted feature types: ' - '[cls, embedding, last_layer].') - parser.add_argument('--baseline_stride', type=str, - help='Stride of first convolution layer the model (should match patch size). If finetuning' - 'an ensemble, pass a comma-separated list (same length as model_type).') - parser.add_argument('--baseline_output_path', type=str, help='Path to save evaluation results.') - - ## Dataset options - parser.add_argument('--nights_root', type=str, default='./dataset/nights', help='path to nights dataset.') - parser.add_argument('--num_workers', type=int, default=4) - parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.') - - return parser.parse_args() - - -def score_nights_dataset(model, test_loader, device): - logging.info("Evaluating NIGHTS dataset.") - d0s = [] - d1s = [] - targets = [] - with torch.no_grad(): - for i, (img_ref, img_left, img_right, target, idx) in tqdm(enumerate(test_loader), total=len(test_loader)): - img_ref, img_left, img_right, target = img_ref.to(device), img_left.to(device), \ - img_right.to(device), target.to(device) - - dist_0 = model(img_ref, img_left) - dist_1 = model(img_ref, img_right) - - if len(dist_0.shape) < 1: - dist_0 = dist_0.unsqueeze(0) - dist_1 = dist_1.unsqueeze(0) - dist_0 = dist_0.unsqueeze(1) - dist_1 = dist_1.unsqueeze(1) - target = target.unsqueeze(1) - - d0s.append(dist_0) - d1s.append(dist_1) - targets.append(target) - - d0s = torch.cat(d0s, dim=0) - d1s = torch.cat(d1s, dim=0) - targets = torch.cat(targets, dim=0) - scores = (d0s < d1s) * (1.0 - targets) + (d1s < d0s) * targets + (d1s == d0s) * 0.5 - twoafc_score = torch.mean(scores, dim=0) - logging.info(f"2AFC score: {str(twoafc_score)}") - return twoafc_score - - -def get_baseline_model(baseline_model, feat_type: str = "cls", stride: str = "16", - load_dir: str = "./models", device: str = "cuda"): - if baseline_model == 'psnr': - def psnr_func(im1, im2): - return -peak_signal_noise_ratio(im1, im2, data_range=1.0, dim=(1, 2, 3), reduction='none') - return psnr_func - - elif baseline_model == 'ssim': - def ssim_func(im1, im2): - return -structural_similarity_index_measure(im1, im2, data_range=1.0, reduction='none') - return ssim_func - - elif baseline_model == 'dists': - dists_metric = DISTS().to(device) - - def dists_func(im1, im2): - distances = dists_metric(im1, im2) - return distances - return dists_func - - elif baseline_model == 'lpips': - import lpips - lpips_fn = lpips.LPIPS(net='alex').eval() - - def lpips_func(im1, im2): - distances = lpips_fn(im1.to(device), im2.to(device)).reshape(-1) - return distances - return lpips_func - - elif 'clip' in baseline_model or 'dino' in baseline_model or "open_clip" in baseline_model or "mae" in baseline_model: - perceptual_model = PerceptualModel(feat_type=feat_type, model_type=baseline_model, stride=stride, - baseline=True, load_dir=load_dir, device=device) - for extractor in perceptual_model.extractor_list: - extractor.model.eval() - return perceptual_model - - elif baseline_model == "dreamsim": - dreamsim_model, preprocess = dreamsim(pretrained=True, cache_dir=load_dir) - return dreamsim_model - - else: - raise ValueError(f"Model {baseline_model} not supported.") - - -def run(args, device): - seed_everything(args.seed) - g = torch.Generator() - g.manual_seed(args.seed) - - if args.checkpoint_epoch is not None: - if args.baseline_model is not None: - raise ValueError("Cannot run baseline evaluation with a checkpoint.") - args_path = os.path.join(args.eval_root, "config.yaml") - logging.basicConfig(filename=os.path.join(args.eval_root, 'eval.log'), level=logging.INFO, force=True) - with open(args_path) as f: - logging.info(f"Loading checkpoint arguments from {args_path}") - eval_args = yaml.load(f, Loader=yaml.Loader) - - eval_args.load_dir = args.load_dir - model = LightningPerceptualModel(**vars(eval_args), device=device) - logging.info(f"Loading checkpoint from {args.eval_root} using epoch {args.checkpoint_epoch}") - - checkpoint_root = os.path.join(args.eval_root, "checkpoints") - checkpoint_path = os.path.join(checkpoint_root, f"epoch={(args.checkpoint_epoch):02d}.ckpt") - sd = torch.load(checkpoint_path) - model.load_state_dict(sd["state_dict"]) - if eval_args.use_lora: - model.load_lora_weights(checkpoint_root=checkpoint_root, epoch_load=args.checkpoint_epoch) - model = model.perceptual_model - for extractor in model.extractor_list: - extractor.model.eval() - model = model.to(device) - output_path = checkpoint_root - model_type = eval_args.model_type - - elif args.baseline_model is not None: - if not os.path.exists(args.baseline_output_path): - os.mkdir(args.baseline_output_path) - logging.basicConfig(filename=os.path.join(args.baseline_output_path, 'eval.log'), level=logging.INFO, - force=True) - model = get_baseline_model(args.baseline_model, args.baseline_feat_type, args.baseline_stride, args.load_dir, - device) - output_path = args.baseline_output_path - model_type = args.baseline_model - - else: - raise ValueError("Must specify one of checkpoint_path or baseline_model") - - eval_results = {} - - test_dataset_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_imagenet", - preprocess=get_preprocess(model_type)) - test_dataset_no_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_no_imagenet", - preprocess=get_preprocess(model_type)) - total_length = len(test_dataset_no_imagenet) + len(test_dataset_imagenet) - test_imagenet_loader = DataLoader(test_dataset_imagenet, batch_size=args.batch_size, - num_workers=args.num_workers, shuffle=False) - test_no_imagenet_loader = DataLoader(test_dataset_no_imagenet, batch_size=args.batch_size, - num_workers=args.num_workers, shuffle=False) - - imagenet_score = score_nights_dataset(model, test_imagenet_loader, device) - no_imagenet_score = score_nights_dataset(model, test_no_imagenet_loader, device) - - eval_results['nights_imagenet'] = imagenet_score.item() - eval_results['nights_no_imagenet'] = no_imagenet_score.item() - eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) + - no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length - logging.info(f"Combined 2AFC score: {str(eval_results['nights_total'])}") - - logging.info(f"Saving to {os.path.join(output_path, 'eval_results.pkl')}") - with open(os.path.join(output_path, 'eval_results.pkl'), "wb") as f: - pickle.dump(eval_results, f) - - print("Done :)") - - -if __name__ == "__main__": - args = parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" - run(args, device) \ No newline at end of file From dd35e81867dbe04a5ad83e35923d08df5119088d Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Fri, 28 Jun 2024 18:10:47 -0400 Subject: [PATCH 08/15] eval pipeline --- configs/eval.yaml | 17 ++++ evaluation/eval_datasets.py | 134 ++++++++++++++++++++++++++++++++ evaluation/eval_percep.py | 149 ++++++++++++++++++++++++++++++++++++ evaluation/eval_util.py | 122 +++++++++++++++++++++++++++++ evaluation/score.py | 136 ++++++++++++++++++++++++++++++++ training/train.py | 10 ++- 6 files changed, 564 insertions(+), 4 deletions(-) create mode 100644 configs/eval.yaml create mode 100644 evaluation/eval_datasets.py create mode 100644 evaluation/eval_percep.py create mode 100644 evaluation/eval_util.py create mode 100644 evaluation/score.py diff --git a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 0000000..422b8b3 --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,17 @@ +eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/checkpoints/clip_vitb32_lora/" +eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/config.yaml" +load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" + +baseline_model: "clip_vitb32" +baseline_feat_type: "cls" +baseline_stride: "32" + +nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" +bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" +things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" +things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" +df2_root: "/data/vision/phillipi/perception/data/df2_org3/" +df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/evaluation/eval_datasets.py b/evaluation/eval_datasets.py new file mode 100644 index 0000000..aa0f508 --- /dev/null +++ b/evaluation/eval_datasets.py @@ -0,0 +1,134 @@ +from torch.utils.data import Dataset +from util.utils import get_preprocess_fn +from torchvision import transforms +import pandas as pd +import numpy as np +from PIL import Image +import os +from typing import Callable +import torch +import glob + +IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] + + +class ThingsDataset(Dataset): + def __init__(self, root_dir: str, txt_file: str, preprocess: str, load_size: int = 224, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + with open(txt_file, "r") as f: + self.txt = f.readlines() + self.dataset_root = root_dir + self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) + + def __len__(self): + return len(self.txt) + + def __getitem__(self, idx): + im_1, im_2, im_3 = self.txt[idx].split() + + im_1 = Image.open(os.path.join(self.dataset_root, f"{im_1}.png")) + im_2 = Image.open(os.path.join(self.dataset_root, f"{im_2}.png")) + im_3 = Image.open(os.path.join(self.dataset_root, f"{im_3}.png")) + + im_1 = self.preprocess_fn(im_1) + im_2 = self.preprocess_fn(im_2) + im_3 = self.preprocess_fn(im_3) + + return im_1, im_2, im_3 + + + +class BAPPSDataset(Dataset): + def __init__(self, root_dir: str, preprocess: str, load_size: int = 224, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + data_types = ["cnn", "traditional", "color", "deblur", "superres", "frameinterp"] + + self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) + self.judge_paths = [] + self.p0_paths = [] + self.p1_paths = [] + self.ref_paths = [] + + for dt in data_types: + list_dir = os.path.join(os.path.join(root_dir, dt), "judge") + for fname in os.scandir(list_dir): + self.judge_paths.append(os.path.join(list_dir, fname.name)) + self.p0_paths.append(os.path.join(os.path.join(os.path.join(root_dir, dt), "p0"), fname.name.split(".")[0] + ".png")) + self.p1_paths.append( + os.path.join(os.path.join(os.path.join(root_dir, dt), "p1"), fname.name.split(".")[0] + ".png")) + self.ref_paths.append( + os.path.join(os.path.join(os.path.join(root_dir, dt), "ref"), fname.name.split(".")[0] + ".png")) + + def __len__(self): + return len(self.judge_paths) + + def __getitem__(self, idx): + judge = np.load(self.judge_paths[idx]) + im_left = self.preprocess_fn(Image.open(self.p0_paths[idx])) + im_right = self.preprocess_fn(Image.open(self.p1_paths[idx])) + im_ref = self.preprocess_fn(Image.open(self.ref_paths[idx])) + return im_ref, im_left, im_right, judge + +class DF2Dataset(torch.utils.data.Dataset): + def __init__(self, root_dir, split: str, preprocess: str, load_size: int = 224, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + + self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) + # self.preprocess_fn=preprocess + self.paths = get_paths(os.path.join(root_dir, split)) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + im_path = self.paths[idx] + img = Image.open(im_path) + img = self.preprocess_fn(img) + return img, im_path + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + +def get_paths(path): + all_paths = [] + for ext in IMAGE_EXTENSIONS: + all_paths += glob.glob(os.path.join(path, f"**.{ext}")) + return all_paths + +# class ImageDataset(torch.utils.data.Dataset): +# def __init__(self, root, class_to_idx, transform=None, ret_path=False): +# """ +# :param root: Dataset root. Should follow the structure class1/0.jpg...n.jpg, class2/0.jpg...n.jpg +# :param class_to_idx: dictionary mapping the classnames to integers. +# :param transform: +# :param ret_path: boolean indicating whether to return the image path or not (useful for KNN for plotting nearest neighbors) +# """ + +# self.transform = transform +# self.label_to_idx = class_to_idx + +# self.paths = [] +# self.labels = [] +# for cls in class_to_idx: +# cls_paths = get_paths(os.path.join(root, cls)) +# self.paths += cls_paths +# self.labels += [self.label_to_idx[cls] for _ in cls_paths] + +# self.ret_path = ret_path + +# def __len__(self): +# return len(self.paths) + +# def __getitem__(self, idx): +# im_path, label = self.paths[idx], self.labels[idx] +# img = pil_loader(im_path) + +# if self.transform is not None: +# img = self.transform(img) +# if not self.ret_path: +# return img, label +# else: +# return img, label, im_path diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py new file mode 100644 index 0000000..8939c05 --- /dev/null +++ b/evaluation/eval_percep.py @@ -0,0 +1,149 @@ +from pytorch_lightning import seed_everything +import torch +from dataset.dataset import TwoAFCDataset +from util.utils import get_preprocess +from torch.utils.data import DataLoader +import os +import yaml +import logging +from training.train import LightningPerceptualModel +from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset, score_df2_dataset +from evaluation.eval_datasets import ThingsDataset, BAPPSDataset, DF2Dataset +from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio +from DISTS_pytorch import DISTS +from dreamsim import PerceptualModel +from tqdm import tqdm +import pickle +import configargparse +from dreamsim import dreamsim +import clip +from torchvision import transforms + +log = logging.getLogger("lightning.pytorch") +log.propagate = False +log.setLevel(logging.ERROR) + +def parse_args(): + parser = configargparse.ArgumentParser() + parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') + + ## Run options + parser.add_argument('--seed', type=int, default=1234) + + ## Checkpoint evaluation options + parser.add_argument('--eval_checkpoint', type=str, help="Path to a checkpoint root.") + parser.add_argument('--eval_checkpoint_cfg', type=str, help="Path to checkpoint config.") + parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints.') + + ## Baseline evaluation options + parser.add_argument('--baseline_model', type=str, default=None) + parser.add_argument('--baseline_feat_type', type=str, default=None) + parser.add_argument('--baseline_stride', type=str, default=None) + + ## Dataset options + parser.add_argument('--nights_root', type=str, default='./dataset/nights', help='path to nights dataset.') + parser.add_argument('--bapps_root', type=str, default='./dataset/bapps', help='path to bapps images.') + parser.add_argument('--things_root', type=str, default='./dataset/things/things_imgs', help='path to things images.') + parser.add_argument('--things_file', type=str, default='./dataset/things/things_trainset.txt', help='path to things info file.') + parser.add_argument('--df2_root', type=str, default='./dataset/df2', help='path to df2 root.') + parser.add_argument('--df2_gt', type=str, default='./dataset/df2/df2_gt.json', help='path to df2 ground truth json.') + parser.add_argument('--num_workers', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.') + + return parser.parse_args() + +def load_dreamsim_model(args, device="cuda"): + with open(os.path.join(args.eval_checkpoint_cfg), "r") as f: + cfg = yaml.load(f, Loader=yaml.Loader) + + model_cfg = vars(cfg) + model_cfg['load_dir'] = args.load_dir + model = LightningPerceptualModel(**model_cfg) + model.load_lora_weights(args.eval_checkpoint) + model = model.perceptual_model.to(device) + preprocess = "DEFAULT" + return model, preprocess + + +def load_baseline_model(args, device="cuda"): + model = PerceptualModel(model_type=args.baseline_model, feat_type=args.baseline_feat_type, stride=args.baseline_stride, baseline=True, load_dir=args.load_dir) + model = model.to(device) + preprocess = "DEFAULT" + return model, preprocess + # clip_transform = transforms.Compose([ + # transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC), + # # transforms.CenterCrop(224), + # lambda x: x.convert('RGB'), + # transforms.ToTensor(), + # transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + # ]) + # model, preprocess = clip.load("ViT-B/32", device=device) + # model.visual.ln_post = torch.nn.Identity() + # return model, clip_transform + +def eval_nights(model, preprocess, device, args): + eval_results = {} + val_dataset = TwoAFCDataset(root_dir=args.nights_root, split="val", preprocess=preprocess) + test_dataset_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_imagenet", preprocess=preprocess) + test_dataset_no_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_no_imagenet", preprocess=preprocess) + total_length = len(test_dataset_no_imagenet) + len(test_dataset_imagenet) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + test_imagenet_loader = DataLoader(test_dataset_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + test_no_imagenet_loader = DataLoader(test_dataset_no_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + + val_score = score_nights_dataset(model, val_loader, device) + imagenet_score = score_nights_dataset(model, test_imagenet_loader, device) + no_imagenet_score = score_nights_dataset(model, test_no_imagenet_loader, device) + + eval_results['nights_val'] = val_score.item() + eval_results['nights_imagenet'] = imagenet_score.item() + eval_results['nights_no_imagenet'] = no_imagenet_score.item() + eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) + + no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length + logging.info(f"Combined 2AFC score: {str(eval_results['nights_total'])}") + return eval_results + +def eval_bapps(model, preprocess, device, args): + test_dataset_bapps = BAPPSDataset(root_dir=args.bapps_root, preprocess=preprocess) + test_loader_bapps = DataLoader(test_dataset_bapps, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + bapps_score = score_bapps_dataset(model, test_loader_bapps, device) + return {"bapps_total": bapps_score} + +def eval_things(model, preprocess, device, args): + test_dataset_things = ThingsDataset(root_dir=args.things_root, txt_file=args.things_file, preprocess=preprocess) + test_loader_things = DataLoader(test_dataset_things, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + things_score = score_things_dataset(model, test_loader_things, device) + return {"things_total": things_score} + +def eval_df2(model, preprocess, device, args): + train_dataset = DF2Dataset(root_dir=args.df2_root, split="gallery", preprocess=preprocess) + test_dataset = DF2Dataset(root_dir=args.df2_root, split="customer", preprocess=preprocess) + train_loader_df2 = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) + test_loader_df2 = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) + df2_score = score_df2_dataset(model, train_loader_df2, test_loader_df2, args.df2_gt, device) + return {"df2_total": df2_score} + +def run(args, device): + logging.basicConfig(filename=os.path.join(args.eval_checkpoint, 'eval.log'), level=logging.INFO, force=True) + seed_everything(args.seed) + g = torch.Generator() + g.manual_seed(args.seed) + + eval_model, preprocess = load_dreamsim_model(args) + nights_results = eval_nights(eval_model, preprocess, device, args) + bapps_results = eval_bapps(eval_model, preprocess, device, args) + things_results = eval_things(eval_model, preprocess, device, args) + df2_results = eval_df2(eval_model, preprocess, device, args) + + if "baseline_model" in args: + baseline_model, baseline_preprocess = load_baseline_model(args) + nights_results = eval_nights(baseline_model, baseline_preprocess, device, args) + bapps_results = eval_bapps(baseline_model, baseline_preprocess, device, args) + things_results = eval_things(baseline_model, baseline_preprocess, device, args) + df2_results = eval_df2(baseline_model, baseline_preprocess, device, args) + +if __name__ == '__main__': + args = parse_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + run(args, device) + \ No newline at end of file diff --git a/evaluation/eval_util.py b/evaluation/eval_util.py new file mode 100644 index 0000000..2ff4167 --- /dev/null +++ b/evaluation/eval_util.py @@ -0,0 +1,122 @@ +from torchvision import transforms +import glob +import os +from scripts.util import rescale + +IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] + +norms = { + "dino": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + "mae": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + "clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + "open_clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + "synclr": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + "resnet": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), +} + +dreamsim_transform = transforms.Compose([ + transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor() + ]) + +dino_transform = transforms.Compose([ + transforms.Resize(256, interpolation=3), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norms['dino'], + ]) + +dinov2_transform = transforms.Compose([ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norms['dino'], + ]) + +mae_transform = transforms.Compose([ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norms['mae'], + ]) + +simclrv2_transform = transforms.Compose([ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + ]) + +synclr_transform = transforms.Compose([ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norms['synclr'], + ]) + +clip_transform = transforms.Compose([ + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norms['clip'], +]) + +# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html +resnet_transform = transforms.Compose([ + transforms.Resize(232, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + rescale, + norms['resnet'], +]) + +open_clip_transform = clip_transform + +vanilla_transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.ToTensor() + ]) + +def get_val_transform(model_type): + if "dino" in model_type: + return dino_transform + elif "mae" in model_type: + return mae_transform + elif "clip" in model_type: + return clip_transform + elif "open_clip" in model_type: + return open_clip_transform + else: + return vanilla_transform + + +def get_train_transform(model_type): + if "mae" in model_type: + norm = norms["mae"] + elif "clip" in model_type: + norm = norms["clip"] + elif "open_clip" in model_type: + norm = norms["open_clip"] + else: + norm = norms["dino"] + + return transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + norm, + ]) + + +def get_paths(path): + all_paths = [] + for ext in IMAGE_EXTENSIONS: + all_paths += glob.glob(os.path.join(path, f"**.{ext}")) + return all_paths diff --git a/evaluation/score.py b/evaluation/score.py new file mode 100644 index 0000000..2dc729b --- /dev/null +++ b/evaluation/score.py @@ -0,0 +1,136 @@ +import torch +import os +from tqdm import tqdm +import logging +import numpy as np +import json +import torch.nn.functional as F + +def score_nights_dataset(model, test_loader, device): + logging.info("Evaluating NIGHTS dataset.") + d0s = [] + d1s = [] + targets = [] + with torch.no_grad(): + for i, (img_ref, img_left, img_right, target, idx) in tqdm(enumerate(test_loader), total=len(test_loader)): + img_ref, img_left, img_right, target = img_ref.to(device), img_left.to(device), \ + img_right.to(device), target.to(device) + + dist_0 = model(img_ref, img_left) + dist_1 = model(img_ref, img_right) + + if len(dist_0.shape) < 1: + dist_0 = dist_0.unsqueeze(0) + dist_1 = dist_1.unsqueeze(0) + dist_0 = dist_0.unsqueeze(1) + dist_1 = dist_1.unsqueeze(1) + target = target.unsqueeze(1) + + d0s.append(dist_0) + d1s.append(dist_1) + targets.append(target) + + d0s = torch.cat(d0s, dim=0) + d1s = torch.cat(d1s, dim=0) + targets = torch.cat(targets, dim=0) + scores = (d0s < d1s) * (1.0 - targets) + (d1s < d0s) * targets + (d1s == d0s) * 0.5 + twoafc_score = torch.mean(scores, dim=0) + print(f"2AFC score: {str(twoafc_score)}") + return twoafc_score + +def score_things_dataset(model, test_loader, device): + logging.info("Evaluating Things dataset.") + count = 0 + with torch.no_grad(): + for i, (img_1, img_2, img_3) in tqdm(enumerate(test_loader), total=len(test_loader)): + img_1, img_2, img_3 = img_1.to(device), img_2.to(device), img_3.to(device) + + dist_1_2 = model(img_1, img_2) + dist_1_3 = model(img_1, img_3) + dist_2_3 = model(img_2, img_3) + + le_1_3 = torch.le(dist_1_2, dist_1_3) + le_2_3 = torch.le(dist_1_2, dist_2_3) + + count += sum(torch.logical_and(le_1_3, le_2_3)) + count = count.detach().cpu().numpy() + accs = count / len(full_dataset) + print(f"Things accs: {str(accs)}") + return accs + +def score_bapps_dataset(model, test_loader, device): + logging.info("Evaluating BAPPS dataset.") + + d0s = [] + d1s = [] + ps = [] + with torch.no_grad(): + for i, (im_ref, im_left, im_right, p) in tqdm(enumerate(test_loader), total=len(test_loader)): + im_ref, im_left, im_right, p = im_ref.to(device), im_left.to(device), im_right.to(device), p.to(device) + d0 = model(im_ref, im_left) + d1 = model(im_ref, im_right) + d0s.append(d0) + d1s.append(d1) + ps.append(p.squeeze()) + d0s = torch.cat(d0s, dim=0) + d1s = torch.cat(d1s, dim=0) + ps = torch.cat(ps, dim=0) + scores = (d0s < d1s) * (1.0 - ps) + (d1s < d0s) * ps + (d1s == d0s) * 0.5 + final_score = torch.mean(scores, dim=0) + print(f"BAPPS score: {str(final_score)}") + return final_score + +def score_df2_dataset(model, train_loader, test_loader, gt_path, device): + + def extract_feats(model, dataloader): + embeds = [] + paths = [] + for im, path in tqdm(dataloader): + im = im.to(device) + paths.append(path) + with torch.no_grad(): + out = model.embed(im).squeeze() + embeds.append(out.to("cpu")) + embeds = torch.vstack(embeds).numpy() + paths = np.concatenate(paths) + return embeds, paths + + train_embeds, train_paths = extract_feats(model, train_loader) + train_embeds = torch.from_numpy(train_embeds).to('cuda') + test_embeds, test_paths = extract_feats(model, test_loader) + test_embeds = torch.from_numpy(test_embeds).to('cuda') + + with open(gt_path, "r") as f: + gt = json.load(f) + + ks = [1, 3, 5] + all_results = {} + + relevant = {k: 0 for k in ks} + retrieved = {k: 0 for k in ks} + recall = {k: 0 for k in ks} + + for i in tqdm(range(test_embeds.shape[0]), total=test_embeds.shape[0]): + sim = F.cosine_similarity(test_embeds[i, :], train_embeds, dim=-1) + ranks = torch.argsort(-sim).cpu() + + query_path = test_paths[i] + total_relevant = len(gt[query_path]) + gt_retrievals = gt[query_path] + for k in ks: + if k > 1: + k_retrieved = int(len([x for x in train_paths[ranks.cpu()[:k]] if x in gt_retrievals]) >0) + else: + k_retrieved = int(train_paths[ranks.cpu()[:k]] in gt_retrievals) + + relevant[k] += total_relevant + retrieved[k] += k_retrieved + + for k in ks: + recall[k] = retrieved[k] / test_embeds.shape[0] + + print(f"DF2 recall@k: {str(recall)}") + return recall + + + diff --git a/training/train.py b/training/train.py index 2d6f747..9dc7767 100644 --- a/training/train.py +++ b/training/train.py @@ -179,11 +179,13 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) return [optimizer] - def load_lora_weights(self, checkpoint_root, epoch_load): + def load_lora_weights(self, checkpoint_root, epoch_load=None): if self.save_mode in {'adapter_only', 'all'}: - load_dir = os.path.join(checkpoint_root, f'epoch_{epoch_load}') - logging.info(f'Loading adapter weights from {load_dir}') - self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, load_dir).to(self.device) + if epoch_load is not None: + checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}') + + logging.info(f'Loading adapter weights from {checkpoint_root}') + self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device) else: logging.info(f'Loading entire model from {checkpoint_root}') sd = torch.load(os.path.join(checkpoint_root, f'epoch={epoch_load:02d}.ckpt'))['state_dict'] From 2dc509eca2e68a36223f3269a868b4572d893c10 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Tue, 16 Jul 2024 05:09:49 +0000 Subject: [PATCH 09/15] add new dataset download scripts --- configs/distill_lora.yaml | 21 --------------- configs/train_ensemble_model_lora.yaml | 2 +- configs/train_single_model_lora.yaml | 6 ++--- dataset/download_chunked_dataset.sh | 14 ++++++++++ dataset/download_jnd_dataset.sh | 6 +++++ dataset/download_unfiltered_dataset.sh | 20 ++++++++++++++ dreamsim/config.py | 7 ----- .../feature_extraction/vision_transformer.py | 12 ++++++--- dreamsim/model.py | 27 +++++++++++++------ requirements.txt | 2 +- training/train.py | 4 +-- 11 files changed, 74 insertions(+), 47 deletions(-) delete mode 100644 configs/distill_lora.yaml create mode 100644 dataset/download_chunked_dataset.sh create mode 100644 dataset/download_jnd_dataset.sh create mode 100644 dataset/download_unfiltered_dataset.sh diff --git a/configs/distill_lora.yaml b/configs/distill_lora.yaml deleted file mode 100644 index d8e145d..0000000 --- a/configs/distill_lora.yaml +++ /dev/null @@ -1,21 +0,0 @@ -seed: 1234 -tag: lora_distill -log_dir: /home/fus/repos/dreamsim-dev/output/dev - -model_type: 'clip_vitb32' -feat_type: 'embedding' -stride: '32' -use_lora: True - -dataset_root: ./dataset/nights -num_workers: 4 - -lr: 0.0003 -weight_decay: 0.0 -batch_size: 32 -epochs: 15 -margin: 0.05 - -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0 \ No newline at end of file diff --git a/configs/train_ensemble_model_lora.yaml b/configs/train_ensemble_model_lora.yaml index f74eea9..aad34be 100644 --- a/configs/train_ensemble_model_lora.yaml +++ b/configs/train_ensemble_model_lora.yaml @@ -17,5 +17,5 @@ epochs: 6 margin: 0.05 lora_r: 16 -lora_alpha: 0.5 +lora_alpha: 8 lora_dropout: 0.3 \ No newline at end of file diff --git a/configs/train_single_model_lora.yaml b/configs/train_single_model_lora.yaml index a44955c..05b0888 100644 --- a/configs/train_single_model_lora.yaml +++ b/configs/train_single_model_lora.yaml @@ -2,7 +2,7 @@ seed: 1234 tag: lora_single log_dir: ./output/new_backbones -model_type: 'synclr_vitb16' +model_type: 'dino_vitb16' feat_type: 'cls' stride: '16' use_lora: True @@ -17,5 +17,5 @@ epochs: 8 margin: 0.05 lora_r: 16 -lora_alpha: 16 -lora_dropout: 0.1 +lora_alpha: 32 +lora_dropout: 0.2 diff --git a/dataset/download_chunked_dataset.sh b/dataset/download_chunked_dataset.sh new file mode 100644 index 0000000..197feab --- /dev/null +++ b/dataset/download_chunked_dataset.sh @@ -0,0 +1,14 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset +for i in $(seq -f "%03g" 0 99); do + wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip + unzip -q $i.zip -d ref/ && rm $i.zip + wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip + unzip -q $i.zip -d distort/ && rm $i.zip +done diff --git a/dataset/download_jnd_dataset.sh b/dataset/download_jnd_dataset.sh new file mode 100644 index 0000000..04db9cc --- /dev/null +++ b/dataset/download_jnd_dataset.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +# Download JND data for NIGHTS dataset +wget https://data.csail.mit.edu/nights/jnd_votes.csv diff --git a/dataset/download_unfiltered_dataset.sh b/dataset/download_unfiltered_dataset.sh new file mode 100644 index 0000000..8ce87d5 --- /dev/null +++ b/dataset/download_unfiltered_dataset.sh @@ -0,0 +1,20 @@ +#!/bin/bash +mkdir -p ./dataset_unfiltered +cd dataset_unfiltered + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset + +# store those in a list and loop through wget and unzip and rm +for i in {0..99..25} +do + wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip + unzip -q ref_${i}_$(($i+24)).zip -d ref + rm ref_*.zip + + wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip + unzip -q distort_${i}_$(($i+24)).zip -d distort + rm distort_*.zip +done diff --git a/dreamsim/config.py b/dreamsim/config.py index a143a9e..ede32e0 100644 --- a/dreamsim/config.py +++ b/dreamsim/config.py @@ -25,13 +25,6 @@ "lora": True } }, - "lora_config": { - "r": 16, - "lora_alpha": 0.5, - "lora_dropout": 0.3, - "bias": "none", - "target_modules": ['qkv'] - }, "img_size": 224 } diff --git a/dreamsim/feature_extraction/vision_transformer.py b/dreamsim/feature_extraction/vision_transformer.py index 36143e9..b2fe093 100644 --- a/dreamsim/feature_extraction/vision_transformer.py +++ b/dreamsim/feature_extraction/vision_transformer.py @@ -16,7 +16,7 @@ # This version was taken from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # On Jan 24th, 2022 # Git hash of last commit: 4b96393c4c877d127cff9f077468e4a1cc2b5e2d - + """ Mostly copy-paste from timm library. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import torch.nn as nn trunc_normal_ = lambda *args, **kwargs: None - + def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: @@ -43,6 +43,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ + def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob @@ -121,6 +122,7 @@ def forward(self, x, return_attention=False): class PatchEmbed(nn.Module): """ Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) @@ -138,6 +140,7 @@ def forward(self, x): class VisionTransformer(nn.Module): """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): @@ -239,7 +242,8 @@ def get_intermediate_layers(self, x, n=1): class DINOHead(nn.Module): - def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, + bottleneck_dim=256): super().__init__() nlayers = max(nlayers, 1) if nlayers == 1: @@ -273,7 +277,7 @@ def forward(self, x): x = nn.functional.normalize(x, dim=-1, p=2) x = self.last_layer(x) return x - + def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer( diff --git a/dreamsim/model.py b/dreamsim/model.py index e7d3364..b79b8bb 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -1,3 +1,5 @@ +import json + import torch import torch.nn.functional as F from torch import nn @@ -7,11 +9,19 @@ from util.constants import * from .feature_extraction.extractor import ViTExtractor import yaml +import peft from peft import PeftModel, LoraConfig, get_peft_model from .config import dreamsim_args, dreamsim_weights import os import zipfile +from packaging import version + +peft_version = version.parse(peft.__version__) +min_version = version.parse('0.2.0') +if peft_version < min_version: + raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. " + "Please update peft with 'pip install --upgrade peft'.") class PerceptualModel(torch.nn.Module): def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1, @@ -165,9 +175,8 @@ def download_weights(cache_dir, dreamsim_type): """ dreamsim_required_ckpts = { - "ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora", - "open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora", - "clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"], + "ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar", + "clip_vitb16_pretrain.pth.tar", "ensemble_lora"], "dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"], "open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"], "clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"] @@ -216,10 +225,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir, normalize_embeds=normalize_embeds) - lora_config = LoraConfig(**dreamsim_args['lora_config']) + tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_" + + with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f: + adapter_config = json.load(f) + lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules'] + lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys}) ours_model = get_peft_model(ours_model, lora_config) - tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}" if pretrained: load_dir = os.path.join(cache_dir, f"{tag}lora") ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device) @@ -251,7 +264,6 @@ def normalize_embedding(embed): 'dino_vits16': {'cls': 384, 'last_layer': 384}, 'dino_vitb8': {'cls': 768, 'last_layer': 768}, 'dino_vitb16': {'cls': 768, 'last_layer': 768}, - 'dinov2_vitb14': {'cls': 768, 'last_layer': 768}, 'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512}, 'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, @@ -260,6 +272,5 @@ def normalize_embedding(embed): 'mae_vith14': {'cls': 1280, 'last_layer': 1280}, 'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768}, - 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, - 'synclr_vitb16': {'cls': 768, 'last_layer': 768}, + 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768} } diff --git a/requirements.txt b/requirements.txt index 01e366b..7fa525a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ lpips numpy open-clip-torch pandas -peft>=0.4.0 +peft>=0.2.0 Pillow pytorch-lightning PyYAML diff --git a/training/train.py b/training/train.py index 9dc7767..7814edb 100644 --- a/training/train.py +++ b/training/train.py @@ -33,7 +33,7 @@ def parse_args(): help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]') + 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') parser.add_argument('--feat_type', type=str, default='cls', help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' 'comma-separated list of features (same length as model_type). Accepted feature types: ' @@ -183,7 +183,7 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None): if self.save_mode in {'adapter_only', 'all'}: if epoch_load is not None: checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}') - + logging.info(f'Loading adapter weights from {checkpoint_root}') self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device) else: From 9ae211b52ac747aa8f1c403cbd0c5ed91c745853 Mon Sep 17 00:00:00 2001 From: Stephanie Fu Date: Tue, 16 Jul 2024 05:10:58 +0000 Subject: [PATCH 10/15] update readme --- README.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9944e47..4dc232e 100644 --- a/README.md +++ b/README.md @@ -22,16 +22,14 @@ DreamSim is a new metric for perceptual image similarity that bridges the gap be * DreamSim variants trained for higher resolutions * Compatibility with the most recent version of PEFT -## 🚀 Updates +## 🚀 Newest Updates +**X/XX/24:** Released new versions of the ensemble and single-branch DreamSim models compatible with `peft>=0.2.0`. -**7/14/23:** Released three variants of DreamSim that each only use one finetuned ViT model instead of the full ensemble. These single-branch models provide a ~3x speedup over the full ensemble. - - -Here's how they compare to the full ensemble on NIGHTS (2AFC agreement): -* **Ensemble:** 96.2% -* **OpenCLIP-ViTB/32:** 95.5% -* **DINO-ViTB/16:** 94.6% -* **CLIP-ViTB/32:** 93.9% +Here's how they perform on the NIGHTS validation set: +* **Ensemble:** 96.9% +* **OpenCLIP-ViTB/32:** 95.6% +* **DINO-ViTB/16:** 95.7% +* **CLIP-ViTB/32:** 95.0% ## Table of Contents * [Requirements](#requirements) @@ -96,7 +94,7 @@ distance = model(img1, img2) # The model takes an RGB image from [0, 1], size ba To run on example images, run `demo.py`. The script should produce distances (0.424, 0.34). -### (new!) Single-branch models +### Single-branch models By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. To load a single-branch model, use the `dreamsim_type` argument. For example: @@ -153,7 +151,15 @@ DreamSim is trained by fine-tuning on the NIGHTS dataset. For details on the dat Run `./dataset/download_dataset.sh` to download and unzip the NIGHTS dataset into `./dataset/nights`. The unzipped dataset size is 58 GB. -**(new!) Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo:** +Having trouble with the large file sizes? Run `./dataset/download_chunked_dataset.sh` to download the NIGHTS dataset split into 200 smaller zip files. The output of this script is identical to `download_dataset.sh`. + +### (new!) Download the entire 100k pre-filtered NIGHTS dataset +We only use the 20k unanimous triplets for training and evaluation, but release all 100k triplets (many with few and/or split votes) for research purposes. Run `./dataset/download_unfiltered_dataset.sh` to download and unzip this unfiltered version of the NIGHTS dataset into `./dataset/nights_unfiltered`. The unzipped dataset size is 289 GB. + +### (new!) Download the JND data +Download the just-noticeable difference (JND) votes by running `./dataset/download_jnd_dataset.sh`. The CSV will be downloaded to `./dataset/jnd_votes.csv`. Check out the [Colab](https://colab.research.google.com/drive/1taEOMzFE9g81D9AwH27Uhy2U82tQGAVI?usp=sharing) for an example of loading a JND trial. + +### Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo: [![FiftyOne](https://img.shields.io/badge/FiftyOne-blue.svg?logo=)](https://try.fiftyone.ai/datasets/nights/samples) ## Experiments From 9fd48bf46a9e0c3fa4488ae32881d597fdf2b12b Mon Sep 17 00:00:00 2001 From: Shobhita Sundaram Date: Tue, 30 Jul 2024 11:41:54 -0400 Subject: [PATCH 11/15] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 4dc232e..ef3edc1 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,10 @@ Current metrics for perceptual image similarity operate at the level of pixels a DreamSim is a new metric for perceptual image similarity that bridges the gap between "low-level" metrics (e.g. LPIPS, PSNR, SSIM) and "high-level" measures (e.g. CLIP). Our model was trained by concatenating CLIP, OpenCLIP, and DINO embeddings, and then finetuning on human perceptual judgements. We gathered these judgements on a dataset of ~20k image triplets, generated by diffusion models. Our model achieves better alignment with human similarity judgements than existing metrics, and can be used for downstream applications such as image retrieval. ## 🕰️ Coming soon -* JND Dataset release +* ✅ JND Dataset release +* ✅ Compatibility with the most recent version of PEFT * Distilled DreamSim models (i.e. smaller models distilled from the main ensemble) * DreamSim variants trained for higher resolutions -* Compatibility with the most recent version of PEFT ## 🚀 Newest Updates **X/XX/24:** Released new versions of the ensemble and single-branch DreamSim models compatible with `peft>=0.2.0`. @@ -36,7 +36,7 @@ Here's how they perform on the NIGHTS validation set: * [Setup](#setup) * [Usage](#usage) * [Quickstart](#quickstart-perceptual-similarity-metric) - * [Single-branch models](#new-single-branch-models) + * [Single-branch models](#single-branch-models) * [Feature extraction](#feature-extraction) * [Image retrieval](#image-retrieval) * [Perceptual loss function](#perceptual-loss-function) @@ -95,9 +95,9 @@ distance = model(img1, img2) # The model takes an RGB image from [0, 1], size ba To run on example images, run `demo.py`. The script should produce distances (0.424, 0.34). ### Single-branch models -By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. +By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. **The single-branch models provide a ~3x speedup over the ensemble.** -To load a single-branch model, use the `dreamsim_type` argument. For example: +The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. To load a single-branch model, use the `dreamsim_type` argument. For example: ``` dreamsim_dino_model, preprocess = dreamsim(pretrained=True, dreamsim_type="dino_vitb16") ``` From 1d9790e42e2282fbedb6beba06785d70bd460739 Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Tue, 30 Jul 2024 13:52:52 -0400 Subject: [PATCH 12/15] updated configs and evaluation pipeline --- configs/eval.yaml | 2 ++ configs/eval_dino.yaml | 19 ++++++++++ configs/eval_ensemble.yaml | 19 ++++++++++ configs/eval_open_clip.yaml | 19 ++++++++++ evaluation/eval_percep.py | 70 ++++++++++++++++++++++--------------- evaluation/score.py | 8 ++--- setup.py | 2 +- 7 files changed, 106 insertions(+), 33 deletions(-) create mode 100644 configs/eval_dino.yaml create mode 100644 configs/eval_ensemble.yaml create mode 100644 configs/eval_open_clip.yaml diff --git a/configs/eval.yaml b/configs/eval.yaml index 422b8b3..7f16b0c 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -1,3 +1,5 @@ +tag: "clip" + eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/checkpoints/clip_vitb32_lora/" eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/config.yaml" load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" diff --git a/configs/eval_dino.yaml b/configs/eval_dino.yaml new file mode 100644 index 0000000..acf0012 --- /dev/null +++ b/configs/eval_dino.yaml @@ -0,0 +1,19 @@ +tag: "dino" + +eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_dino_vitb16_cls_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/checkpoints/dino_vitb16_lora/" +eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_dino_vitb16_cls_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/config.yaml" +load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" + +baseline_model: "dino_vitb16" +baseline_feat_type: "cls" +baseline_stride: "16" + +nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" +bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" +things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" +things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" +df2_root: "/data/vision/phillipi/perception/data/df2_org3/" +df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/configs/eval_ensemble.yaml b/configs/eval_ensemble.yaml new file mode 100644 index 0000000..99217bd --- /dev/null +++ b/configs/eval_ensemble.yaml @@ -0,0 +1,19 @@ +tag: "open_clip" + +eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_open_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/checkpoints/open_clip_vitb32_lora/" +eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_open_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/config.yaml" +load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" + +baseline_model: "open_clip_vitb32" +baseline_feat_type: "embedding" +baseline_stride: "32" + +nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" +bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" +things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" +things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" +df2_root: "/data/vision/phillipi/perception/data/df2_org3/" +df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/configs/eval_open_clip.yaml b/configs/eval_open_clip.yaml new file mode 100644 index 0000000..b689077 --- /dev/null +++ b/configs/eval_open_clip.yaml @@ -0,0 +1,19 @@ +tag: "ensemble" + +eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_ensemble_dino_vitb16,clip_vitb16,open_clip_vitb16_cls,embedding,embedding_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_1.0_loradropout_0.3/lightning_logs/version_0/checkpoints/ensemble_lora/" +eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_ensemble_dino_vitb16,clip_vitb16,open_clip_vitb16_cls,embedding,embedding_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_1.0_loradropout_0.3/lightning_logs/version_0/config.yaml" +load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" + +baseline_model: "dino_vitb16,clip_vitb16,open_clip_vitb16" +baseline_feat_type: "cls,embedding,embedding" +baseline_stride: "16,16,16" + +nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" +bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" +things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" +things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" +df2_root: "/data/vision/phillipi/perception/data/df2_org3/" +df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py index 8939c05..3be4cba 100644 --- a/evaluation/eval_percep.py +++ b/evaluation/eval_percep.py @@ -6,6 +6,7 @@ import os import yaml import logging +import json from training.train import LightningPerceptualModel from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset, score_df2_dataset from evaluation.eval_datasets import ThingsDataset, BAPPSDataset, DF2Dataset @@ -29,6 +30,9 @@ def parse_args(): ## Run options parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--output', type=str, default="./eval_outputs", help="Dir to save results in.") + parser.add_argument('--tag', type=str, default="", help="Exp name for saving results") + ## Checkpoint evaluation options parser.add_argument('--eval_checkpoint', type=str, help="Path to a checkpoint root.") @@ -41,12 +45,12 @@ def parse_args(): parser.add_argument('--baseline_stride', type=str, default=None) ## Dataset options - parser.add_argument('--nights_root', type=str, default='./dataset/nights', help='path to nights dataset.') - parser.add_argument('--bapps_root', type=str, default='./dataset/bapps', help='path to bapps images.') - parser.add_argument('--things_root', type=str, default='./dataset/things/things_imgs', help='path to things images.') - parser.add_argument('--things_file', type=str, default='./dataset/things/things_trainset.txt', help='path to things info file.') - parser.add_argument('--df2_root', type=str, default='./dataset/df2', help='path to df2 root.') - parser.add_argument('--df2_gt', type=str, default='./dataset/df2/df2_gt.json', help='path to df2 ground truth json.') + parser.add_argument('--nights_root', type=str, default=None, help='path to nights dataset.') + parser.add_argument('--bapps_root', type=str, default=None, help='path to bapps images.') + parser.add_argument('--things_root', type=str, default=None, help='path to things images.') + parser.add_argument('--things_file', type=str, default=None, help='path to things info file.') + parser.add_argument('--df2_root', type=str, default=None, help='path to df2 root.') + parser.add_argument('--df2_gt', type=str, default=None, help='path to df2 ground truth json.') parser.add_argument('--num_workers', type=int, default=16) parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.') @@ -70,16 +74,6 @@ def load_baseline_model(args, device="cuda"): model = model.to(device) preprocess = "DEFAULT" return model, preprocess - # clip_transform = transforms.Compose([ - # transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC), - # # transforms.CenterCrop(224), - # lambda x: x.convert('RGB'), - # transforms.ToTensor(), - # transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - # ]) - # model, preprocess = clip.load("ViT-B/32", device=device) - # model.visual.ln_post = torch.nn.Identity() - # return model, clip_transform def eval_nights(model, preprocess, device, args): eval_results = {} @@ -100,19 +94,24 @@ def eval_nights(model, preprocess, device, args): eval_results['nights_no_imagenet'] = no_imagenet_score.item() eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) + no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length - logging.info(f"Combined 2AFC score: {str(eval_results['nights_total'])}") + logging.info(f"NIGHTS (validation 2AFC): {str(eval_results['nights_val'])}") + logging.info(f"NIGHTS (imagenet 2AFC): {str(eval_results['nights_imagenet'])}") + logging.info(f"NIGHTS (no-imagenet 2AFC): {str(eval_results['nights_no_imagenet'])}") + logging.info(f"NIGHTS (total 2AFC): {str(eval_results['nights_total'])}") return eval_results def eval_bapps(model, preprocess, device, args): test_dataset_bapps = BAPPSDataset(root_dir=args.bapps_root, preprocess=preprocess) test_loader_bapps = DataLoader(test_dataset_bapps, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) bapps_score = score_bapps_dataset(model, test_loader_bapps, device) + logging.info(f"BAPPS (total 2AFC): {str(bapps_score)}") return {"bapps_total": bapps_score} def eval_things(model, preprocess, device, args): test_dataset_things = ThingsDataset(root_dir=args.things_root, txt_file=args.things_file, preprocess=preprocess) test_loader_things = DataLoader(test_dataset_things, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) things_score = score_things_dataset(model, test_loader_things, device) + logging.info(f"THINGS (total 2AFC): {things_score}") return {"things_total": things_score} def eval_df2(model, preprocess, device, args): @@ -121,26 +120,41 @@ def eval_df2(model, preprocess, device, args): train_loader_df2 = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) test_loader_df2 = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) df2_score = score_df2_dataset(model, train_loader_df2, test_loader_df2, args.df2_gt, device) + logging.info(f"DF2 (total recall@k): {str(recall)}") return {"df2_total": df2_score} +def full_eval(eval_model, preprocess, device, args): + results = {} + if args.nights_root is not None: + results['ckpt_nights'] = eval_nights(eval_model, preprocess, device, args) + if args.bapps_root is not None: + results['ckpt_bapps'] = bapps_results = eval_bapps(eval_model, preprocess, device, args) + if args.things_root is not None: + results['ckpt_things'] = eval_things(eval_model, preprocess, device, args) + if args.df2_root is not None: + results['ckpt_df2_root'] = eval_df2(eval_model, preprocess, device, args) + return results + def run(args, device): logging.basicConfig(filename=os.path.join(args.eval_checkpoint, 'eval.log'), level=logging.INFO, force=True) seed_everything(args.seed) g = torch.Generator() g.manual_seed(args.seed) + + os.makedirs(args.output, exist_ok=True) - eval_model, preprocess = load_dreamsim_model(args) - nights_results = eval_nights(eval_model, preprocess, device, args) - bapps_results = eval_bapps(eval_model, preprocess, device, args) - things_results = eval_things(eval_model, preprocess, device, args) - df2_results = eval_df2(eval_model, preprocess, device, args) - - if "baseline_model" in args: + full_results = {} + if args.eval_checkpoint is not None: + eval_model, preprocess = load_dreamsim_model(args) + full_results['ckpt'] = full_eval(eval_model, preprocess, device, args) + if args.baseline_model is not None: baseline_model, baseline_preprocess = load_baseline_model(args) - nights_results = eval_nights(baseline_model, baseline_preprocess, device, args) - bapps_results = eval_bapps(baseline_model, baseline_preprocess, device, args) - things_results = eval_things(baseline_model, baseline_preprocess, device, args) - df2_results = eval_df2(baseline_model, baseline_preprocess, device, args) + full_results['baseline'] = full_eval(baseline_model, baseline_preprocess, device, args) + + tag = args.tag + "_" if len(args.tag) > 0 else "" + with open(os.path.join(args.output, f"{tag}eval_results.json"), "w") as f: + json.dump(full_results, f) + if __name__ == '__main__': args = parse_args() diff --git a/evaluation/score.py b/evaluation/score.py index 2dc729b..1c23338 100644 --- a/evaluation/score.py +++ b/evaluation/score.py @@ -41,6 +41,7 @@ def score_nights_dataset(model, test_loader, device): def score_things_dataset(model, test_loader, device): logging.info("Evaluating Things dataset.") count = 0 + total = 0 with torch.no_grad(): for i, (img_1, img_2, img_3) in tqdm(enumerate(test_loader), total=len(test_loader)): img_1, img_2, img_3 = img_1.to(device), img_2.to(device), img_3.to(device) @@ -53,9 +54,10 @@ def score_things_dataset(model, test_loader, device): le_2_3 = torch.le(dist_1_2, dist_2_3) count += sum(torch.logical_and(le_1_3, le_2_3)) + total += len(torch.logical_and(le_1_3, le_2_3)) count = count.detach().cpu().numpy() - accs = count / len(full_dataset) - print(f"Things accs: {str(accs)}") + total = total.detach().cpu().numpy() + accs = count / total return accs def score_bapps_dataset(model, test_loader, device): @@ -77,7 +79,6 @@ def score_bapps_dataset(model, test_loader, device): ps = torch.cat(ps, dim=0) scores = (d0s < d1s) * (1.0 - ps) + (d1s < d0s) * ps + (d1s == d0s) * 0.5 final_score = torch.mean(scores, dim=0) - print(f"BAPPS score: {str(final_score)}") return final_score def score_df2_dataset(model, train_loader, test_loader, gt_path, device): @@ -129,7 +130,6 @@ def extract_feats(model, dataloader): for k in ks: recall[k] = retrieved[k] / test_embeds.shape[0] - print(f"DF2 recall@k: {str(recall)}") return recall diff --git a/setup.py b/setup.py index 8485f1a..0178b7c 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ install_requires=[ "numpy", "open-clip-torch", - "peft==0.1.0", + "peft", "Pillow", "torch", "timm", From d243d99ac1c426e53828e01712d2b9bc846458e2 Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Tue, 30 Jul 2024 14:07:34 -0400 Subject: [PATCH 13/15] clean for release --- configs/eval.yaml | 19 ------------ configs/eval_baseline.yaml | 8 ----- configs/eval_checkpoint.yaml | 6 ---- configs/eval_dino.yaml | 19 ------------ configs/eval_ensemble.yaml | 22 +++++++------- configs/eval_open_clip.yaml | 19 ------------ configs/eval_single_clip.yaml | 17 +++++++++++ evaluation/eval_datasets.py | 52 --------------------------------- evaluation/eval_percep.py | 15 ++-------- evaluation/score.py | 55 ----------------------------------- 10 files changed, 29 insertions(+), 203 deletions(-) delete mode 100644 configs/eval.yaml delete mode 100644 configs/eval_baseline.yaml delete mode 100644 configs/eval_checkpoint.yaml delete mode 100644 configs/eval_dino.yaml delete mode 100644 configs/eval_open_clip.yaml create mode 100644 configs/eval_single_clip.yaml diff --git a/configs/eval.yaml b/configs/eval.yaml deleted file mode 100644 index 7f16b0c..0000000 --- a/configs/eval.yaml +++ /dev/null @@ -1,19 +0,0 @@ -tag: "clip" - -eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/checkpoints/clip_vitb32_lora/" -eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/config.yaml" -load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" - -baseline_model: "clip_vitb32" -baseline_feat_type: "cls" -baseline_stride: "32" - -nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" -bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" -things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" -things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" -df2_root: "/data/vision/phillipi/perception/data/df2_org3/" -df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" - -batch_size: 256 -num_workers: 10 \ No newline at end of file diff --git a/configs/eval_baseline.yaml b/configs/eval_baseline.yaml deleted file mode 100644 index ed202ba..0000000 --- a/configs/eval_baseline.yaml +++ /dev/null @@ -1,8 +0,0 @@ -seed: 1234 -baseline_model: dreamsim -baseline_feat_type: cls,embedding,embedding -baseline_stride: 16,16,16 -baseline_output_path: "outputs" -nights_root: ./dataset/nights -num_workers: 10 -batch_size: 16 \ No newline at end of file diff --git a/configs/eval_checkpoint.yaml b/configs/eval_checkpoint.yaml deleted file mode 100644 index ed5030a..0000000 --- a/configs/eval_checkpoint.yaml +++ /dev/null @@ -1,6 +0,0 @@ -seed: 1234 -eval_root: "output/experiment_dir/lightning_logs/version_0" -checkpoint_epoch: 7 -nights_root: ./dataset/nights -num_workers: 10 -batch_size: 16 \ No newline at end of file diff --git a/configs/eval_dino.yaml b/configs/eval_dino.yaml deleted file mode 100644 index acf0012..0000000 --- a/configs/eval_dino.yaml +++ /dev/null @@ -1,19 +0,0 @@ -tag: "dino" - -eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_dino_vitb16_cls_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/checkpoints/dino_vitb16_lora/" -eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_dino_vitb16_cls_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/config.yaml" -load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" - -baseline_model: "dino_vitb16" -baseline_feat_type: "cls" -baseline_stride: "16" - -nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" -bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" -things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" -things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" -df2_root: "/data/vision/phillipi/perception/data/df2_org3/" -df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" - -batch_size: 256 -num_workers: 10 \ No newline at end of file diff --git a/configs/eval_ensemble.yaml b/configs/eval_ensemble.yaml index 99217bd..beac463 100644 --- a/configs/eval_ensemble.yaml +++ b/configs/eval_ensemble.yaml @@ -1,19 +1,17 @@ tag: "open_clip" -eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_open_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/checkpoints/open_clip_vitb32_lora/" -eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_open_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_32.0_loradropout_0.2/lightning_logs/version_0/config.yaml" -load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" +eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/" +eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml" +load_dir: "./models" -baseline_model: "open_clip_vitb32" -baseline_feat_type: "embedding" -baseline_stride: "32" +baseline_model: "dino_vitb16,clip_vitb16,open_clip_vitb16" +baseline_feat_type: "cls,embedding,embedding" +baseline_stride: "16,16,16" -nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" -bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" -things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" -things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" -df2_root: "/data/vision/phillipi/perception/data/df2_org3/" -df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" +nights_root: "./data/nights" +bapps_root: "./data/2afc/val" +things_root: "./data/things/things_src_images" +things_file: "./data/things/things_valset.txt" batch_size: 256 num_workers: 10 \ No newline at end of file diff --git a/configs/eval_open_clip.yaml b/configs/eval_open_clip.yaml deleted file mode 100644 index b689077..0000000 --- a/configs/eval_open_clip.yaml +++ /dev/null @@ -1,19 +0,0 @@ -tag: "ensemble" - -eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_ensemble_dino_vitb16,clip_vitb16,open_clip_vitb16_cls,embedding,embedding_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_1.0_loradropout_0.3/lightning_logs/version_0/checkpoints/ensemble_lora/" -eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_ensemble_dino_vitb16,clip_vitb16,open_clip_vitb16_cls,embedding,embedding_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_1.0_loradropout_0.3/lightning_logs/version_0/config.yaml" -load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models" - -baseline_model: "dino_vitb16,clip_vitb16,open_clip_vitb16" -baseline_feat_type: "cls,embedding,embedding" -baseline_stride: "16,16,16" - -nights_root: "/vision-nfs/isola/projects/shobhita/data/nights" -bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val" -things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images" -things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt" -df2_root: "/data/vision/phillipi/perception/data/df2_org3/" -df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json" - -batch_size: 256 -num_workers: 10 \ No newline at end of file diff --git a/configs/eval_single_clip.yaml b/configs/eval_single_clip.yaml new file mode 100644 index 0000000..78ae4a1 --- /dev/null +++ b/configs/eval_single_clip.yaml @@ -0,0 +1,17 @@ +tag: "clip" + +eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/" +eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml" +load_dir: "./models" + +baseline_model: "clip_vitb32" +baseline_feat_type: "cls" +baseline_stride: "32" + +nights_root: "./data/nights" +bapps_root: "./data/2afc/val" +things_root: "./data/things/things_src_images" +things_file: "./data/things/things_valset.txt" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/evaluation/eval_datasets.py b/evaluation/eval_datasets.py index aa0f508..db0bd74 100644 --- a/evaluation/eval_datasets.py +++ b/evaluation/eval_datasets.py @@ -69,23 +69,6 @@ def __getitem__(self, idx): im_ref = self.preprocess_fn(Image.open(self.ref_paths[idx])) return im_ref, im_left, im_right, judge -class DF2Dataset(torch.utils.data.Dataset): - def __init__(self, root_dir, split: str, preprocess: str, load_size: int = 224, - interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): - - self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) - # self.preprocess_fn=preprocess - self.paths = get_paths(os.path.join(root_dir, split)) - - def __len__(self): - return len(self.paths) - - def __getitem__(self, idx): - im_path = self.paths[idx] - img = Image.open(im_path) - img = self.preprocess_fn(img) - return img, im_path - def pil_loader(path): # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, 'rb') as f: @@ -97,38 +80,3 @@ def get_paths(path): for ext in IMAGE_EXTENSIONS: all_paths += glob.glob(os.path.join(path, f"**.{ext}")) return all_paths - -# class ImageDataset(torch.utils.data.Dataset): -# def __init__(self, root, class_to_idx, transform=None, ret_path=False): -# """ -# :param root: Dataset root. Should follow the structure class1/0.jpg...n.jpg, class2/0.jpg...n.jpg -# :param class_to_idx: dictionary mapping the classnames to integers. -# :param transform: -# :param ret_path: boolean indicating whether to return the image path or not (useful for KNN for plotting nearest neighbors) -# """ - -# self.transform = transform -# self.label_to_idx = class_to_idx - -# self.paths = [] -# self.labels = [] -# for cls in class_to_idx: -# cls_paths = get_paths(os.path.join(root, cls)) -# self.paths += cls_paths -# self.labels += [self.label_to_idx[cls] for _ in cls_paths] - -# self.ret_path = ret_path - -# def __len__(self): -# return len(self.paths) - -# def __getitem__(self, idx): -# im_path, label = self.paths[idx], self.labels[idx] -# img = pil_loader(im_path) - -# if self.transform is not None: -# img = self.transform(img) -# if not self.ret_path: -# return img, label -# else: -# return img, label, im_path diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py index 3be4cba..349e0b8 100644 --- a/evaluation/eval_percep.py +++ b/evaluation/eval_percep.py @@ -8,8 +8,8 @@ import logging import json from training.train import LightningPerceptualModel -from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset, score_df2_dataset -from evaluation.eval_datasets import ThingsDataset, BAPPSDataset, DF2Dataset +from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset +from evaluation.eval_datasets import ThingsDataset, BAPPSDataset from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio from DISTS_pytorch import DISTS from dreamsim import PerceptualModel @@ -114,15 +114,6 @@ def eval_things(model, preprocess, device, args): logging.info(f"THINGS (total 2AFC): {things_score}") return {"things_total": things_score} -def eval_df2(model, preprocess, device, args): - train_dataset = DF2Dataset(root_dir=args.df2_root, split="gallery", preprocess=preprocess) - test_dataset = DF2Dataset(root_dir=args.df2_root, split="customer", preprocess=preprocess) - train_loader_df2 = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) - test_loader_df2 = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,pin_memory=True) - df2_score = score_df2_dataset(model, train_loader_df2, test_loader_df2, args.df2_gt, device) - logging.info(f"DF2 (total recall@k): {str(recall)}") - return {"df2_total": df2_score} - def full_eval(eval_model, preprocess, device, args): results = {} if args.nights_root is not None: @@ -131,8 +122,6 @@ def full_eval(eval_model, preprocess, device, args): results['ckpt_bapps'] = bapps_results = eval_bapps(eval_model, preprocess, device, args) if args.things_root is not None: results['ckpt_things'] = eval_things(eval_model, preprocess, device, args) - if args.df2_root is not None: - results['ckpt_df2_root'] = eval_df2(eval_model, preprocess, device, args) return results def run(args, device): diff --git a/evaluation/score.py b/evaluation/score.py index 1c23338..bc0b8b9 100644 --- a/evaluation/score.py +++ b/evaluation/score.py @@ -56,7 +56,6 @@ def score_things_dataset(model, test_loader, device): count += sum(torch.logical_and(le_1_3, le_2_3)) total += len(torch.logical_and(le_1_3, le_2_3)) count = count.detach().cpu().numpy() - total = total.detach().cpu().numpy() accs = count / total return accs @@ -80,57 +79,3 @@ def score_bapps_dataset(model, test_loader, device): scores = (d0s < d1s) * (1.0 - ps) + (d1s < d0s) * ps + (d1s == d0s) * 0.5 final_score = torch.mean(scores, dim=0) return final_score - -def score_df2_dataset(model, train_loader, test_loader, gt_path, device): - - def extract_feats(model, dataloader): - embeds = [] - paths = [] - for im, path in tqdm(dataloader): - im = im.to(device) - paths.append(path) - with torch.no_grad(): - out = model.embed(im).squeeze() - embeds.append(out.to("cpu")) - embeds = torch.vstack(embeds).numpy() - paths = np.concatenate(paths) - return embeds, paths - - train_embeds, train_paths = extract_feats(model, train_loader) - train_embeds = torch.from_numpy(train_embeds).to('cuda') - test_embeds, test_paths = extract_feats(model, test_loader) - test_embeds = torch.from_numpy(test_embeds).to('cuda') - - with open(gt_path, "r") as f: - gt = json.load(f) - - ks = [1, 3, 5] - all_results = {} - - relevant = {k: 0 for k in ks} - retrieved = {k: 0 for k in ks} - recall = {k: 0 for k in ks} - - for i in tqdm(range(test_embeds.shape[0]), total=test_embeds.shape[0]): - sim = F.cosine_similarity(test_embeds[i, :], train_embeds, dim=-1) - ranks = torch.argsort(-sim).cpu() - - query_path = test_paths[i] - total_relevant = len(gt[query_path]) - gt_retrievals = gt[query_path] - for k in ks: - if k > 1: - k_retrieved = int(len([x for x in train_paths[ranks.cpu()[:k]] if x in gt_retrievals]) >0) - else: - k_retrieved = int(train_paths[ranks.cpu()[:k]] in gt_retrievals) - - relevant[k] += total_relevant - retrieved[k] += k_retrieved - - for k in ks: - recall[k] = retrieved[k] / test_embeds.shape[0] - - return recall - - - From becf40b0cb1df6c2f2cd68e436955f6dcb8265b0 Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Sun, 18 Aug 2024 15:14:48 -0400 Subject: [PATCH 14/15] further cleaning --- evaluation/eval_datasets.py | 20 +- evaluation/eval_percep.py | 20 +- evaluation/eval_util.py | 122 ------------- training/distill.py | 355 ------------------------------------ training/train.py | 10 +- util/precompute_sim.py | 69 ------- 6 files changed, 22 insertions(+), 574 deletions(-) delete mode 100644 evaluation/eval_util.py delete mode 100644 training/distill.py delete mode 100644 util/precompute_sim.py diff --git a/evaluation/eval_datasets.py b/evaluation/eval_datasets.py index db0bd74..1c410a3 100644 --- a/evaluation/eval_datasets.py +++ b/evaluation/eval_datasets.py @@ -1,18 +1,18 @@ +import os +import glob +import numpy as np +from PIL import Image from torch.utils.data import Dataset from util.utils import get_preprocess_fn from torchvision import transforms -import pandas as pd -import numpy as np -from PIL import Image -import os -from typing import Callable -import torch -import glob IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] - class ThingsDataset(Dataset): + """ + txt_file is expected to be the things_valset.txt list of triplets from the THINGS dataset. + root_dir is expected to be a directory of THINGS images. + """ def __init__(self, root_dir: str, txt_file: str, preprocess: str, load_size: int = 224, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): with open(txt_file, "r") as f: @@ -37,10 +37,12 @@ def __getitem__(self, idx): return im_1, im_2, im_3 - class BAPPSDataset(Dataset): def __init__(self, root_dir: str, preprocess: str, load_size: int = 224, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + """ + root_dir is expected to be the default validation folder of the BAPPS dataset. + """ data_types = ["cnn", "traditional", "color", "deblur", "superres", "frameinterp"] self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py index 349e0b8..fb5f50e 100644 --- a/evaluation/eval_percep.py +++ b/evaluation/eval_percep.py @@ -1,24 +1,16 @@ -from pytorch_lightning import seed_everything -import torch -from dataset.dataset import TwoAFCDataset -from util.utils import get_preprocess -from torch.utils.data import DataLoader import os import yaml import logging import json +import torch +import configargparse +from torch.utils.data import DataLoader +from pytorch_lightning import seed_everything +from dreamsim import PerceptualModel +from dataset.dataset import TwoAFCDataset from training.train import LightningPerceptualModel from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset from evaluation.eval_datasets import ThingsDataset, BAPPSDataset -from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio -from DISTS_pytorch import DISTS -from dreamsim import PerceptualModel -from tqdm import tqdm -import pickle -import configargparse -from dreamsim import dreamsim -import clip -from torchvision import transforms log = logging.getLogger("lightning.pytorch") log.propagate = False diff --git a/evaluation/eval_util.py b/evaluation/eval_util.py deleted file mode 100644 index 2ff4167..0000000 --- a/evaluation/eval_util.py +++ /dev/null @@ -1,122 +0,0 @@ -from torchvision import transforms -import glob -import os -from scripts.util import rescale - -IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] - -norms = { - "dino": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "mae": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - "open_clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - "synclr": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "resnet": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), -} - -dreamsim_transform = transforms.Compose([ - transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor() - ]) - -dino_transform = transforms.Compose([ - transforms.Resize(256, interpolation=3), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['dino'], - ]) - -dinov2_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['dino'], - ]) - -mae_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['mae'], - ]) - -simclrv2_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - ]) - -synclr_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['synclr'], - ]) - -clip_transform = transforms.Compose([ - transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['clip'], -]) - -# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html -resnet_transform = transforms.Compose([ - transforms.Resize(232, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - rescale, - norms['resnet'], -]) - -open_clip_transform = clip_transform - -vanilla_transform = transforms.Compose([ - transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor() - ]) - -def get_val_transform(model_type): - if "dino" in model_type: - return dino_transform - elif "mae" in model_type: - return mae_transform - elif "clip" in model_type: - return clip_transform - elif "open_clip" in model_type: - return open_clip_transform - else: - return vanilla_transform - - -def get_train_transform(model_type): - if "mae" in model_type: - norm = norms["mae"] - elif "clip" in model_type: - norm = norms["clip"] - elif "open_clip" in model_type: - norm = norms["open_clip"] - else: - norm = norms["dino"] - - return transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norm, - ]) - - -def get_paths(path): - all_paths = [] - for ext in IMAGE_EXTENSIONS: - all_paths += glob.glob(os.path.join(path, f"**.{ext}")) - return all_paths diff --git a/training/distill.py b/training/distill.py deleted file mode 100644 index 2d0656a..0000000 --- a/training/distill.py +++ /dev/null @@ -1,355 +0,0 @@ -import logging -import yaml -import pytorch_lightning as pl -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import ModelCheckpoint -from util.train_utils import Mean, HingeLoss, seed_worker -from util.utils import get_preprocess -from dataset.dataset import TwoAFCDataset -from torch.utils.data import DataLoader -import torch -from peft import get_peft_model, LoraConfig, PeftModel -from dreamsim import PerceptualModel -from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig -import os -import configargparse -from tqdm import tqdm -import pickle as pkl -import torch.nn as nn -from sklearn.decomposition import PCA - -torch.autograd.set_detect_anomaly(True) -def parse_args(): - parser = configargparse.ArgumentParser() - parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') - - ## Run options - parser.add_argument('--seed', type=int, default=1234) - parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)') - parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs') - parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints') - - ## Model options - parser.add_argument('--model_type', type=str, default='dino_vitb16', - help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' - 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' - 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') - parser.add_argument('--feat_type', type=str, default='cls', - help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' - 'comma-separated list of features (same length as model_type). Accepted feature types: ' - '[cls, embedding, last_layer].') - parser.add_argument('--stride', type=str, default='16', - help='Stride of first convolution layer the model (should match patch size). If finetuning' - 'an ensemble, pass a comma-separated list (same length as model_type).') - parser.add_argument('--use_lora', type=bool, default=False, - help='Whether to train with LoRA finetuning [True] or with an MLP head [False].') - parser.add_argument('--hidden_size', type=int, default=1, help='Size of the MLP hidden layer.') - - ## Dataset options - parser.add_argument('--dataset_root', type=str, default="./dataset/nights", help='path to training dataset.') - parser.add_argument('--num_workers', type=int, default=4) - - ## Training options - parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training.') - parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay for training.') - parser.add_argument('--batch_size', type=int, default=4, help='Dataset batch size.') - parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.') - parser.add_argument('--margin', default=0.01, type=float, help='Margin for hinge loss') - - ## LoRA-specific options - parser.add_argument('--lora_r', type=int, default=8, help='LoRA attention dimension') - parser.add_argument('--lora_alpha', type=float, default=0.1, help='Alpha for attention scaling') - parser.add_argument('--lora_dropout', type=float, default=0.1, help='Dropout probability for LoRA layers') - - return parser.parse_args() - - -class LightningPerceptualModel(pl.LightningModule): - def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1, - lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16, - lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1, - load_dir: str = "./models", device: str = "cuda", - **kwargs): - super().__init__() - self.save_hyperparameters() - - self.feat_type = feat_type - self.model_type = model_type - self.stride = stride - self.hidden_size = hidden_size - self.lr = lr - self.use_lora = use_lora - self.margin = margin - self.weight_decay = weight_decay - self.lora_r = lora_r - self.lora_alpha = lora_alpha - self.lora_dropout = lora_dropout - self.train_data_len = train_data_len - - self.started = False - self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)} - self.__reset_val_metrics() - - self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride, - hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir, - device=device) - if self.use_lora: - self.__prep_lora_model() - else: - self.__prep_linear_model() - - pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) - pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) - print(pytorch_total_params) - print(pytorch_total_trainable_params) - - self.criterion = nn.L1Loss() - self.teacher = PerceptualModel(feat_type='cls,embedding,embedding', - model_type='dino_vitb16,clip_vitb16,open_clip_vitb16', - stride='16,16,16', - hidden_size=self.hidden_size, lora=False, load_dir=load_dir, - device=device) - - self.epoch_loss_train = 0.0 - self.train_num_correct = 0.0 - - self.automatic_optimization = False - - with open('precomputed_sims.pkl', 'rb') as f: - self.sims = pkl.load(f) - - # with open('precomputed_embeds.pkl', 'rb') as f: - # self.pca = pkl.load(f) - - def forward(self, img_ref, img_0, img_1): - _, embed_0, dist_0 = self.perceptual_model(img_ref, img_0) - embed_ref, embed_1, dist_1 = self.perceptual_model(img_ref, img_1) - return embed_ref, embed_0, embed_1, dist_0, dist_1 - - def training_step(self, batch, batch_idx): - img_ref, img_0, img_1, _, idx = batch - embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) - - # with torch.no_grad(): - # target_embed_ref = self.teacher.embed(img_ref) - # target_embed_0 = self.teacher.embed(img_0) - # target_embed_1 = self.teacher.embed(img_1) - # - # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device).float() - # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device).float() - # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device).float() - - target_0 = [self.sims['train'][i.item()][0] for i in idx] - target_dist_0 = torch.tensor(target_0, device=img_ref.device) - - target_1 = [self.sims['train'][i.item()][1] for i in idx] - target_dist_1 = torch.tensor(target_1, device=img_ref.device) - - opt = self.optimizers() - opt.zero_grad() - loss_0 = self.criterion(dist_0, target_dist_0) #/ target_dist_0.shape[0] - # loss_ref = self.criterion(embed_ref, target_embed_ref).mean() - # self.manual_backward(loss_0) - - loss_1 = self.criterion(dist_1, target_dist_1) #/ target_dist_1.shape[0] - # loss_0 = self.criterion(embed_0, target_embed_0).mean().float() - # self.manual_backward(loss_1) - - # loss_1 = self.criterion(embed_1, target_embed_1).mean().float() - # self.manual_backward(loss_1) - loss = (loss_0 + loss_1).mean() - self.manual_backward(loss) - opt.step() - - target = torch.lt(target_dist_1, target_dist_0) - decisions = torch.lt(dist_1, dist_0) - - # self.epoch_loss_train += loss_ref - self.epoch_loss_train += loss_0 - self.epoch_loss_train += loss_1 - self.train_num_correct += ((target >= 0.5) == decisions).sum() - return loss_0 + loss_1 - - def validation_step(self, batch, batch_idx): - img_ref, img_0, img_1, _, idx = batch - embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) - - # with torch.no_grad(): - # target_embed_ref = self.teacher.embed(img_ref) - # target_embed_0 = self.teacher.embed(img_0) - # target_embed_1 = self.teacher.embed(img_1) - # - # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device) - # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device) - # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device) - - target_0 = [self.sims['val'][i.item()][0] for i in idx] - target_1 = [self.sims['val'][i.item()][1] for i in idx] - - target_dist_0 = torch.tensor(target_0, device=img_ref.device) - target_dist_1 = torch.tensor(target_1, device=img_ref.device) - - target = torch.lt(target_dist_1, target_dist_0) - decisions = torch.lt(dist_1, dist_0) - - loss = self.criterion(dist_0, target_dist_0) - loss += self.criterion(dist_1, target_dist_1) - # loss = self.criterion(embed_ref, target_embed_ref).float() - # loss += self.criterion(embed_0, target_embed_0).float() - # loss += self.criterion(embed_1, target_embed_1).float() - loss = loss.mean() - - val_num_correct = ((target >= 0.5) == decisions).sum() - self.val_metrics['loss'].update(loss, target.shape[0]) - self.val_metrics['score'].update(val_num_correct, target.shape[0]) - return loss - - def on_train_epoch_start(self): - self.epoch_loss_train = 0.0 - self.train_num_correct = 0.0 - self.started = True - - def on_train_epoch_end(self): - epoch = self.current_epoch + 1 if self.started else 0 - self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch) - self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch) - if self.use_lora: - self.__save_lora_weights() - - def on_train_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.train() - - def on_validation_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.eval() - - def on_validation_epoch_start(self): - self.__reset_val_metrics() - - def on_validation_epoch_end(self): - epoch = self.current_epoch + 1 if self.started else 0 - score = self.val_metrics['score'].compute() - loss = self.val_metrics['loss'].compute() - - self.log(f'val_acc_ckpt', score, logger=False) - self.log(f'val_loss_ckpt', loss, logger=False) - # log for tensorboard - self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch) - self.logger.experiment.add_scalar(f'val_loss/', loss, epoch) - - return score - - def configure_optimizers(self): - params = list(self.perceptual_model.parameters()) - for extractor in self.perceptual_model.extractor_list: - params += list(extractor.model.parameters()) - for extractor, feat_type in zip(self.perceptual_model.extractor_list, self.perceptual_model.feat_type_list): - if feat_type == 'embedding': - params += [extractor.proj] - optimizer = torch.optim.Adam(params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) - return [optimizer] - - def load_lora_weights(self, checkpoint_root, epoch_load): - for extractor in self.perceptual_model.extractor_list: - load_dir = os.path.join(checkpoint_root, - f'epoch_{epoch_load}_{extractor.model_type}') - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device) - - def __reset_val_metrics(self): - for k, v in self.val_metrics.items(): - v.reset() - - def __prep_lora_model(self): - for extractor in self.perceptual_model.extractor_list: - config = LoraConfig( - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout, - bias='none', - target_modules=['qkv'] - ) - extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()), - config).to(extractor.device) - extractor.model = extractor_model - - def __prep_linear_model(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.requires_grad_(False) - if self.feat_type == "embedding": - extractor.proj.requires_grad_(False) - self.perceptual_model.mlp.requires_grad_(True) - - def __save_lora_weights(self): - for extractor in self.perceptual_model.extractor_list: - save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, - f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') - extractor.model.save_pretrained(save_dir, safe_serialization=False) - adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin')) - new_adapters_weights = dict() - - for k, v in adapters_weights.items(): - new_k = 'base_model.model.' + k - new_adapters_weights[new_k] = v - torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin')) - - -def run(args, device): - tag = args.tag if len(args.tag) > 0 else "" - training_method = "lora" if args.use_lora else "mlp" - exp_dir = os.path.join(args.log_dir, - f'{tag}_{str(args.model_type)}_{str(args.feat_type)}_{str(training_method)}_' + - f'lr_{str(args.lr)}_batchsize_{str(args.batch_size)}_wd_{str(args.weight_decay)}' - f'_hiddensize_{str(args.hidden_size)}_margin_{str(args.margin)}' - ) - if args.use_lora: - exp_dir += f'_lorar_{str(args.lora_r)}_loraalpha_{str(args.lora_alpha)}_loradropout_{str(args.lora_dropout)}' - - seed_everything(args.seed) - g = torch.Generator() - g.manual_seed(args.seed) - - train_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="train", preprocess=get_preprocess(args.model_type)) - val_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="val", preprocess=get_preprocess(args.model_type)) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, - worker_init_fn=seed_worker, generator=g) - val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) - - logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) - trainer = Trainer(devices=1, - accelerator='gpu', - log_every_n_steps=10, - logger=logger, - max_epochs=args.epochs, - default_root_dir=exp_dir, - callbacks=ModelCheckpoint(monitor='val_loss_ckpt', - save_top_k=-1, - save_last=True, - filename='{epoch:02d}', - mode='max'), - num_sanity_val_steps=0, - ) - checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}') - os.makedirs(checkpoint_root, exist_ok=True) - with open(os.path.join(checkpoint_root, 'config.yaml'), 'w') as f: - yaml.dump(args, f) - - logging.basicConfig(filename=os.path.join(checkpoint_root, 'exp.log'), level=logging.INFO, force=True) - logging.info("Arguments: ", vars(args)) - - model = LightningPerceptualModel(device=device, train_data_len=len(train_dataset), **vars(args)) - - logging.info("Validating before training") - trainer.validate(model, val_loader) - logging.info("Training") - trainer.fit(model, train_loader, val_loader) - - print("Done :)") - - -if __name__ == '__main__': - args = parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" - run(args, device) diff --git a/training/train.py b/training/train.py index 7814edb..2658937 100644 --- a/training/train.py +++ b/training/train.py @@ -1,18 +1,18 @@ +import os +import configargparse import logging import yaml +import torch import pytorch_lightning as pl +from torch.utils.data import DataLoader +from peft import get_peft_model, LoraConfig, PeftModel from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint from util.train_utils import Mean, HingeLoss, seed_worker from util.utils import get_preprocess from dataset.dataset import TwoAFCDataset -from torch.utils.data import DataLoader -import torch -from peft import get_peft_model, LoraConfig, PeftModel from dreamsim import PerceptualModel -import os -import configargparse def parse_args(): diff --git a/util/precompute_sim.py b/util/precompute_sim.py deleted file mode 100644 index 9e162eb..0000000 --- a/util/precompute_sim.py +++ /dev/null @@ -1,69 +0,0 @@ -from PIL import Image -from lightning_fabric import seed_everything -from torch.utils.data import DataLoader - -from dataset.dataset import TwoAFCDataset -from dreamsim import dreamsim -from torchvision import transforms -import torch -import os -from tqdm import tqdm - -from util.train_utils import seed_worker -from util.utils import get_preprocess -import numpy as np -import pickle as pkl -from sklearn.decomposition import PCA - -seed = 1234 -dataset_root = './dataset/nights' -model_type = 'dino_vitb16,clip_vitb16,open_clip_vitb16' -num_workers = 8 -batch_size = 32 - -seed_everything(seed) -g = torch.Generator() -g.manual_seed(seed) - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -model, preprocess = dreamsim(pretrained=True, device=device, cache_dir='./models_new') - -train_dataset = TwoAFCDataset(root_dir=dataset_root, split="train", preprocess=get_preprocess(model_type)) -val_dataset = TwoAFCDataset(root_dir=dataset_root, split="val", preprocess=get_preprocess(model_type)) -test_dataset = TwoAFCDataset(root_dir=dataset_root, split="test", preprocess=get_preprocess(model_type)) -train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, - worker_init_fn=seed_worker, generator=g) -val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) -test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) - -data = {'train': {}, 'val': {}, 'test': {}} -all_embeds = [] - -with torch.no_grad(): - for split, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]: - for img_ref, img_left, img_right, p, id in tqdm(loader): - img_ref = img_ref.to(device) - img_left = img_left.to(device) - img_right = img_right.to(device) - - embed_ref, embed_0, d0 = model(img_ref, img_left) - _, embed_1, d1 = model(img_ref, img_right) - # - # if split == 'train': - # all_embeds.append(embed_ref) - # all_embeds.append(embed_0) - # all_embeds.append(embed_1) - - for i in range(len(id)): - curr_id = id[i].item() - data[split][curr_id] = [d0[i].item(), d1[i].item()] - -# all_embeds = torch.cat(all_embeds).cpu() -# principal = PCA(n_components=512) -# principal.fit(all_embeds) - -with open('precomputed_sims.pkl', 'wb') as f: - pkl.dump(data, f) - -# with open('precomputed_embeds.pkl', 'wb') as f: -# pkl.dump(principal, f) \ No newline at end of file From c70f64a324598370a39f287faea7b1f93aeade89 Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Sun, 18 Aug 2024 17:24:26 -0400 Subject: [PATCH 15/15] release prep --- dreamsim/config.py | 1 + .../feature_extraction/load_synclr_as_dino.py | 16 ---------------- setup.py | 2 +- training/download_models.sh | 1 + 4 files changed, 3 insertions(+), 17 deletions(-) delete mode 100644 dreamsim/feature_extraction/load_synclr_as_dino.py diff --git a/dreamsim/config.py b/dreamsim/config.py index ede32e0..c017975 100644 --- a/dreamsim/config.py +++ b/dreamsim/config.py @@ -28,6 +28,7 @@ "img_size": 224 } +# UPDATE dreamsim_weights = { "ensemble": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/dreamsim_checkpoint.zip", "dino_vitb16": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.2/dreamsim_dino_vitb16_checkpoint.zip", diff --git a/dreamsim/feature_extraction/load_synclr_as_dino.py b/dreamsim/feature_extraction/load_synclr_as_dino.py deleted file mode 100644 index 85c6477..0000000 --- a/dreamsim/feature_extraction/load_synclr_as_dino.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -from .vision_transformer import vit_base, VisionTransformer -import os - - -def load_synclr_as_dino(patch_size, load_dir="./models", l14=False): - sd = torch.load(os.path.join(load_dir, f'synclr_vit_b_{patch_size}.pth'))['model'] - dino_vit = vit_base(patch_size=patch_size) - new_sd = dict() - - for k, v in sd.items(): - new_key = k[14:] # strip "module.visual" from key - new_sd[new_key] = v - - dino_vit.load_state_dict(new_sd) - return dino_vit diff --git a/setup.py b/setup.py index 0178b7c..e0fe0e9 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="dreamsim", - version="0.1.3", + version="0.2.0", description="DreamSim similarity metric", long_description=long_description, long_description_content_type="text/markdown", diff --git a/training/download_models.sh b/training/download_models.sh index a0e5d13..f7d951d 100644 --- a/training/download_models.sh +++ b/training/download_models.sh @@ -2,6 +2,7 @@ mkdir -p ./models cd models +## UDPATE wget -O dreamsim_checkpoint.zip https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/dreamsim_checkpoint.zip wget -O clip_vitb32_pretrain.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/clip_vitb32_pretrain.pth.tar wget -O clipl14_as_dino_vitl.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/clip_vitl14_pretrain.pth.tar