From becf40b0cb1df6c2f2cd68e436955f6dcb8265b0 Mon Sep 17 00:00:00 2001 From: ssundaram21 Date: Sun, 18 Aug 2024 15:14:48 -0400 Subject: [PATCH] further cleaning --- evaluation/eval_datasets.py | 20 +- evaluation/eval_percep.py | 20 +- evaluation/eval_util.py | 122 ------------- training/distill.py | 355 ------------------------------------ training/train.py | 10 +- util/precompute_sim.py | 69 ------- 6 files changed, 22 insertions(+), 574 deletions(-) delete mode 100644 evaluation/eval_util.py delete mode 100644 training/distill.py delete mode 100644 util/precompute_sim.py diff --git a/evaluation/eval_datasets.py b/evaluation/eval_datasets.py index db0bd74..1c410a3 100644 --- a/evaluation/eval_datasets.py +++ b/evaluation/eval_datasets.py @@ -1,18 +1,18 @@ +import os +import glob +import numpy as np +from PIL import Image from torch.utils.data import Dataset from util.utils import get_preprocess_fn from torchvision import transforms -import pandas as pd -import numpy as np -from PIL import Image -import os -from typing import Callable -import torch -import glob IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] - class ThingsDataset(Dataset): + """ + txt_file is expected to be the things_valset.txt list of triplets from the THINGS dataset. + root_dir is expected to be a directory of THINGS images. + """ def __init__(self, root_dir: str, txt_file: str, preprocess: str, load_size: int = 224, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): with open(txt_file, "r") as f: @@ -37,10 +37,12 @@ def __getitem__(self, idx): return im_1, im_2, im_3 - class BAPPSDataset(Dataset): def __init__(self, root_dir: str, preprocess: str, load_size: int = 224, interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + """ + root_dir is expected to be the default validation folder of the BAPPS dataset. 
+ """ data_types = ["cnn", "traditional", "color", "deblur", "superres", "frameinterp"] self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py index 349e0b8..fb5f50e 100644 --- a/evaluation/eval_percep.py +++ b/evaluation/eval_percep.py @@ -1,24 +1,16 @@ -from pytorch_lightning import seed_everything -import torch -from dataset.dataset import TwoAFCDataset -from util.utils import get_preprocess -from torch.utils.data import DataLoader import os import yaml import logging import json +import torch +import configargparse +from torch.utils.data import DataLoader +from pytorch_lightning import seed_everything +from dreamsim import PerceptualModel +from dataset.dataset import TwoAFCDataset from training.train import LightningPerceptualModel from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset from evaluation.eval_datasets import ThingsDataset, BAPPSDataset -from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio -from DISTS_pytorch import DISTS -from dreamsim import PerceptualModel -from tqdm import tqdm -import pickle -import configargparse -from dreamsim import dreamsim -import clip -from torchvision import transforms log = logging.getLogger("lightning.pytorch") log.propagate = False diff --git a/evaluation/eval_util.py b/evaluation/eval_util.py deleted file mode 100644 index 2ff4167..0000000 --- a/evaluation/eval_util.py +++ /dev/null @@ -1,122 +0,0 @@ -from torchvision import transforms -import glob -import os -from scripts.util import rescale - -IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] - -norms = { - "dino": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "mae": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - "open_clip": transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - "synclr": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - "resnet": transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), -} - -dreamsim_transform = transforms.Compose([ - transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor() - ]) - -dino_transform = transforms.Compose([ - transforms.Resize(256, interpolation=3), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['dino'], - ]) - -dinov2_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['dino'], - ]) - -mae_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['mae'], - ]) - -simclrv2_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - ]) - -synclr_transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['synclr'], - ]) - -clip_transform = transforms.Compose([ - transforms.Resize(224, 
interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norms['clip'], -]) - -# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html -resnet_transform = transforms.Compose([ - transforms.Resize(232, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(224), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - rescale, - norms['resnet'], -]) - -open_clip_transform = clip_transform - -vanilla_transform = transforms.Compose([ - transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), - transforms.ToTensor() - ]) - -def get_val_transform(model_type): - if "dino" in model_type: - return dino_transform - elif "mae" in model_type: - return mae_transform - elif "clip" in model_type: - return clip_transform - elif "open_clip" in model_type: - return open_clip_transform - else: - return vanilla_transform - - -def get_train_transform(model_type): - if "mae" in model_type: - norm = norms["mae"] - elif "clip" in model_type: - norm = norms["clip"] - elif "open_clip" in model_type: - norm = norms["open_clip"] - else: - norm = norms["dino"] - - return transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - lambda x: x.convert('RGB'), - transforms.ToTensor(), - norm, - ]) - - -def get_paths(path): - all_paths = [] - for ext in IMAGE_EXTENSIONS: - all_paths += glob.glob(os.path.join(path, f"**.{ext}")) - return all_paths diff --git a/training/distill.py b/training/distill.py deleted file mode 100644 index 2d0656a..0000000 --- a/training/distill.py +++ /dev/null @@ -1,355 +0,0 @@ -import logging -import yaml -import pytorch_lightning as pl -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import ModelCheckpoint -from util.train_utils import Mean, HingeLoss, seed_worker -from util.utils import get_preprocess -from dataset.dataset import TwoAFCDataset -from torch.utils.data import DataLoader -import torch -from peft import get_peft_model, LoraConfig, PeftModel -from dreamsim import PerceptualModel -from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig -import os -import configargparse -from tqdm import tqdm -import pickle as pkl -import torch.nn as nn -from sklearn.decomposition import PCA - -torch.autograd.set_detect_anomaly(True) -def parse_args(): - parser = configargparse.ArgumentParser() - parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') - - ## Run options - parser.add_argument('--seed', type=int, default=1234) - parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)') - parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs') - parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints') - - ## Model options - parser.add_argument('--model_type', type=str, default='dino_vitb16', - help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' - 'list of models. 
Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' - 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') - parser.add_argument('--feat_type', type=str, default='cls', - help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' - 'comma-separated list of features (same length as model_type). Accepted feature types: ' - '[cls, embedding, last_layer].') - parser.add_argument('--stride', type=str, default='16', - help='Stride of first convolution layer the model (should match patch size). If finetuning' - 'an ensemble, pass a comma-separated list (same length as model_type).') - parser.add_argument('--use_lora', type=bool, default=False, - help='Whether to train with LoRA finetuning [True] or with an MLP head [False].') - parser.add_argument('--hidden_size', type=int, default=1, help='Size of the MLP hidden layer.') - - ## Dataset options - parser.add_argument('--dataset_root', type=str, default="./dataset/nights", help='path to training dataset.') - parser.add_argument('--num_workers', type=int, default=4) - - ## Training options - parser.add_argument('--lr', type=float, default=0.001, help='Learning rate for training.') - parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay for training.') - parser.add_argument('--batch_size', type=int, default=4, help='Dataset batch size.') - parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs.') - parser.add_argument('--margin', default=0.01, type=float, help='Margin for hinge loss') - - ## LoRA-specific options - parser.add_argument('--lora_r', type=int, default=8, help='LoRA attention dimension') - parser.add_argument('--lora_alpha', type=float, default=0.1, help='Alpha for attention scaling') - parser.add_argument('--lora_dropout', type=float, default=0.1, help='Dropout probability for LoRA layers') - - return parser.parse_args() - - -class LightningPerceptualModel(pl.LightningModule): - def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1, - lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16, - lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1, - load_dir: str = "./models", device: str = "cuda", - **kwargs): - super().__init__() - self.save_hyperparameters() - - self.feat_type = feat_type - self.model_type = model_type - self.stride = stride - self.hidden_size = hidden_size - self.lr = lr - self.use_lora = use_lora - self.margin = margin - self.weight_decay = weight_decay - self.lora_r = lora_r - self.lora_alpha = lora_alpha - self.lora_dropout = lora_dropout - self.train_data_len = train_data_len - - self.started = False - self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)} - self.__reset_val_metrics() - - self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride, - hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir, - device=device) - if self.use_lora: - self.__prep_lora_model() - else: - self.__prep_linear_model() - - pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) - pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) - print(pytorch_total_params) - print(pytorch_total_trainable_params) - - self.criterion = nn.L1Loss() - 
self.teacher = PerceptualModel(feat_type='cls,embedding,embedding', - model_type='dino_vitb16,clip_vitb16,open_clip_vitb16', - stride='16,16,16', - hidden_size=self.hidden_size, lora=False, load_dir=load_dir, - device=device) - - self.epoch_loss_train = 0.0 - self.train_num_correct = 0.0 - - self.automatic_optimization = False - - with open('precomputed_sims.pkl', 'rb') as f: - self.sims = pkl.load(f) - - # with open('precomputed_embeds.pkl', 'rb') as f: - # self.pca = pkl.load(f) - - def forward(self, img_ref, img_0, img_1): - _, embed_0, dist_0 = self.perceptual_model(img_ref, img_0) - embed_ref, embed_1, dist_1 = self.perceptual_model(img_ref, img_1) - return embed_ref, embed_0, embed_1, dist_0, dist_1 - - def training_step(self, batch, batch_idx): - img_ref, img_0, img_1, _, idx = batch - embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) - - # with torch.no_grad(): - # target_embed_ref = self.teacher.embed(img_ref) - # target_embed_0 = self.teacher.embed(img_0) - # target_embed_1 = self.teacher.embed(img_1) - # - # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device).float() - # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device).float() - # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device).float() - - target_0 = [self.sims['train'][i.item()][0] for i in idx] - target_dist_0 = torch.tensor(target_0, device=img_ref.device) - - target_1 = [self.sims['train'][i.item()][1] for i in idx] - target_dist_1 = torch.tensor(target_1, device=img_ref.device) - - opt = self.optimizers() - opt.zero_grad() - loss_0 = self.criterion(dist_0, target_dist_0) #/ target_dist_0.shape[0] - # loss_ref = self.criterion(embed_ref, target_embed_ref).mean() - # self.manual_backward(loss_0) - - loss_1 = self.criterion(dist_1, target_dist_1) #/ target_dist_1.shape[0] - # loss_0 = self.criterion(embed_0, target_embed_0).mean().float() - # self.manual_backward(loss_1) - - # loss_1 = self.criterion(embed_1, target_embed_1).mean().float() - # self.manual_backward(loss_1) - loss = (loss_0 + loss_1).mean() - self.manual_backward(loss) - opt.step() - - target = torch.lt(target_dist_1, target_dist_0) - decisions = torch.lt(dist_1, dist_0) - - # self.epoch_loss_train += loss_ref - self.epoch_loss_train += loss_0 - self.epoch_loss_train += loss_1 - self.train_num_correct += ((target >= 0.5) == decisions).sum() - return loss_0 + loss_1 - - def validation_step(self, batch, batch_idx): - img_ref, img_0, img_1, _, idx = batch - embed_ref, embed_0, embed_1, dist_0, dist_1 = self.forward(img_ref, img_0, img_1) - - # with torch.no_grad(): - # target_embed_ref = self.teacher.embed(img_ref) - # target_embed_0 = self.teacher.embed(img_0) - # target_embed_1 = self.teacher.embed(img_1) - # - # target_embed_ref = torch.tensor(self.pca.transform(target_embed_ref.cpu()), device=img_ref.device) - # target_embed_0 = torch.tensor(self.pca.transform(target_embed_0.cpu()), device=img_ref.device) - # target_embed_1 = torch.tensor(self.pca.transform(target_embed_1.cpu()), device=img_ref.device) - - target_0 = [self.sims['val'][i.item()][0] for i in idx] - target_1 = [self.sims['val'][i.item()][1] for i in idx] - - target_dist_0 = torch.tensor(target_0, device=img_ref.device) - target_dist_1 = torch.tensor(target_1, device=img_ref.device) - - target = torch.lt(target_dist_1, target_dist_0) - decisions = torch.lt(dist_1, dist_0) - - loss = self.criterion(dist_0, target_dist_0) - loss += 
self.criterion(dist_1, target_dist_1) - # loss = self.criterion(embed_ref, target_embed_ref).float() - # loss += self.criterion(embed_0, target_embed_0).float() - # loss += self.criterion(embed_1, target_embed_1).float() - loss = loss.mean() - - val_num_correct = ((target >= 0.5) == decisions).sum() - self.val_metrics['loss'].update(loss, target.shape[0]) - self.val_metrics['score'].update(val_num_correct, target.shape[0]) - return loss - - def on_train_epoch_start(self): - self.epoch_loss_train = 0.0 - self.train_num_correct = 0.0 - self.started = True - - def on_train_epoch_end(self): - epoch = self.current_epoch + 1 if self.started else 0 - self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch) - self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch) - if self.use_lora: - self.__save_lora_weights() - - def on_train_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.train() - - def on_validation_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.eval() - - def on_validation_epoch_start(self): - self.__reset_val_metrics() - - def on_validation_epoch_end(self): - epoch = self.current_epoch + 1 if self.started else 0 - score = self.val_metrics['score'].compute() - loss = self.val_metrics['loss'].compute() - - self.log(f'val_acc_ckpt', score, logger=False) - self.log(f'val_loss_ckpt', loss, logger=False) - # log for tensorboard - self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch) - self.logger.experiment.add_scalar(f'val_loss/', loss, epoch) - - return score - - def configure_optimizers(self): - params = list(self.perceptual_model.parameters()) - for extractor in self.perceptual_model.extractor_list: - params += list(extractor.model.parameters()) - for extractor, feat_type in zip(self.perceptual_model.extractor_list, self.perceptual_model.feat_type_list): - if feat_type == 'embedding': - params += [extractor.proj] - optimizer = torch.optim.Adam(params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) - return [optimizer] - - def load_lora_weights(self, checkpoint_root, epoch_load): - for extractor in self.perceptual_model.extractor_list: - load_dir = os.path.join(checkpoint_root, - f'epoch_{epoch_load}_{extractor.model_type}') - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device) - - def __reset_val_metrics(self): - for k, v in self.val_metrics.items(): - v.reset() - - def __prep_lora_model(self): - for extractor in self.perceptual_model.extractor_list: - config = LoraConfig( - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout, - bias='none', - target_modules=['qkv'] - ) - extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()), - config).to(extractor.device) - extractor.model = extractor_model - - def __prep_linear_model(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.requires_grad_(False) - if self.feat_type == "embedding": - extractor.proj.requires_grad_(False) - self.perceptual_model.mlp.requires_grad_(True) - - def __save_lora_weights(self): - for extractor in self.perceptual_model.extractor_list: - save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, - f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') - extractor.model.save_pretrained(save_dir, safe_serialization=False) - adapters_weights = torch.load(os.path.join(save_dir, 
'adapter_model.bin')) - new_adapters_weights = dict() - - for k, v in adapters_weights.items(): - new_k = 'base_model.model.' + k - new_adapters_weights[new_k] = v - torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin')) - - -def run(args, device): - tag = args.tag if len(args.tag) > 0 else "" - training_method = "lora" if args.use_lora else "mlp" - exp_dir = os.path.join(args.log_dir, - f'{tag}_{str(args.model_type)}_{str(args.feat_type)}_{str(training_method)}_' + - f'lr_{str(args.lr)}_batchsize_{str(args.batch_size)}_wd_{str(args.weight_decay)}' - f'_hiddensize_{str(args.hidden_size)}_margin_{str(args.margin)}' - ) - if args.use_lora: - exp_dir += f'_lorar_{str(args.lora_r)}_loraalpha_{str(args.lora_alpha)}_loradropout_{str(args.lora_dropout)}' - - seed_everything(args.seed) - g = torch.Generator() - g.manual_seed(args.seed) - - train_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="train", preprocess=get_preprocess(args.model_type)) - val_dataset = TwoAFCDataset(root_dir=args.dataset_root, split="val", preprocess=get_preprocess(args.model_type)) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, - worker_init_fn=seed_worker, generator=g) - val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) - - logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) - trainer = Trainer(devices=1, - accelerator='gpu', - log_every_n_steps=10, - logger=logger, - max_epochs=args.epochs, - default_root_dir=exp_dir, - callbacks=ModelCheckpoint(monitor='val_loss_ckpt', - save_top_k=-1, - save_last=True, - filename='{epoch:02d}', - mode='max'), - num_sanity_val_steps=0, - ) - checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}') - os.makedirs(checkpoint_root, exist_ok=True) - with open(os.path.join(checkpoint_root, 'config.yaml'), 'w') as f: - yaml.dump(args, f) - - logging.basicConfig(filename=os.path.join(checkpoint_root, 'exp.log'), level=logging.INFO, force=True) - logging.info("Arguments: ", vars(args)) - - model = LightningPerceptualModel(device=device, train_data_len=len(train_dataset), **vars(args)) - - logging.info("Validating before training") - trainer.validate(model, val_loader) - logging.info("Training") - trainer.fit(model, train_loader, val_loader) - - print("Done :)") - - -if __name__ == '__main__': - args = parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" - run(args, device) diff --git a/training/train.py b/training/train.py index 7814edb..2658937 100644 --- a/training/train.py +++ b/training/train.py @@ -1,18 +1,18 @@ +import os +import configargparse import logging import yaml +import torch import pytorch_lightning as pl +from torch.utils.data import DataLoader +from peft import get_peft_model, LoraConfig, PeftModel from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint from util.train_utils import Mean, HingeLoss, seed_worker from util.utils import get_preprocess from dataset.dataset import TwoAFCDataset -from torch.utils.data import DataLoader -import torch -from peft import get_peft_model, LoraConfig, PeftModel from dreamsim import PerceptualModel -import os -import configargparse def parse_args(): diff --git a/util/precompute_sim.py b/util/precompute_sim.py deleted file mode 100644 index 9e162eb..0000000 --- a/util/precompute_sim.py +++ /dev/null @@ 
-1,69 +0,0 @@ -from PIL import Image -from lightning_fabric import seed_everything -from torch.utils.data import DataLoader - -from dataset.dataset import TwoAFCDataset -from dreamsim import dreamsim -from torchvision import transforms -import torch -import os -from tqdm import tqdm - -from util.train_utils import seed_worker -from util.utils import get_preprocess -import numpy as np -import pickle as pkl -from sklearn.decomposition import PCA - -seed = 1234 -dataset_root = './dataset/nights' -model_type = 'dino_vitb16,clip_vitb16,open_clip_vitb16' -num_workers = 8 -batch_size = 32 - -seed_everything(seed) -g = torch.Generator() -g.manual_seed(seed) - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -model, preprocess = dreamsim(pretrained=True, device=device, cache_dir='./models_new') - -train_dataset = TwoAFCDataset(root_dir=dataset_root, split="train", preprocess=get_preprocess(model_type)) -val_dataset = TwoAFCDataset(root_dir=dataset_root, split="val", preprocess=get_preprocess(model_type)) -test_dataset = TwoAFCDataset(root_dir=dataset_root, split="test", preprocess=get_preprocess(model_type)) -train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, - worker_init_fn=seed_worker, generator=g) -val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) -test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) - -data = {'train': {}, 'val': {}, 'test': {}} -all_embeds = [] - -with torch.no_grad(): - for split, loader in [('train', train_loader), ('val', val_loader), ('test', test_loader)]: - for img_ref, img_left, img_right, p, id in tqdm(loader): - img_ref = img_ref.to(device) - img_left = img_left.to(device) - img_right = img_right.to(device) - - embed_ref, embed_0, d0 = model(img_ref, img_left) - _, embed_1, d1 = model(img_ref, img_right) - # - # if split == 'train': - # all_embeds.append(embed_ref) - # all_embeds.append(embed_0) - # all_embeds.append(embed_1) - - for i in range(len(id)): - curr_id = id[i].item() - data[split][curr_id] = [d0[i].item(), d1[i].item()] - -# all_embeds = torch.cat(all_embeds).cpu() -# principal = PCA(n_components=512) -# principal.fit(all_embeds) - -with open('precomputed_sims.pkl', 'wb') as f: - pkl.dump(data, f) - -# with open('precomputed_embeds.pkl', 'wb') as f: -# pkl.dump(principal, f) \ No newline at end of file
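
The docstrings added in evaluation/eval_datasets.py above describe what ThingsDataset and BAPPSDataset expect as inputs. Below is a minimal usage sketch based only on the __init__ signatures shown in this patch; it is not part of the change itself, and the dataset paths and the preprocess string are placeholder assumptions rather than values taken from the repository.

from torch.utils.data import DataLoader
from evaluation.eval_datasets import ThingsDataset, BAPPSDataset

# Placeholder paths and preprocess key -- adjust to the local dataset layout
# and to whatever keys util.utils.get_preprocess_fn actually accepts.
things = ThingsDataset(root_dir="path/to/things/images",      # directory of THINGS images
                       txt_file="path/to/things_valset.txt",  # triplet list from the THINGS dataset
                       preprocess="DEFAULT",                   # assumed preprocess key, forwarded to get_preprocess_fn
                       load_size=224)

bapps = BAPPSDataset(root_dir="path/to/bapps/val",             # default BAPPS validation folder
                     preprocess="DEFAULT",                      # assumed preprocess key
                     load_size=224)

things_loader = DataLoader(things, batch_size=32, shuffle=False)
bapps_loader = DataLoader(bapps, batch_size=32, shuffle=False)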