diff --git a/.gitignore b/.gitignore index 68bc17f..ed570f8 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ +output/ +models/ +dataset/nights diff --git a/README.md b/README.md index 9944e47..ef3edc1 100644 --- a/README.md +++ b/README.md @@ -17,28 +17,26 @@ Current metrics for perceptual image similarity operate at the level of pixels a DreamSim is a new metric for perceptual image similarity that bridges the gap between "low-level" metrics (e.g. LPIPS, PSNR, SSIM) and "high-level" measures (e.g. CLIP). Our model was trained by concatenating CLIP, OpenCLIP, and DINO embeddings, and then finetuning on human perceptual judgements. We gathered these judgements on a dataset of ~20k image triplets, generated by diffusion models. Our model achieves better alignment with human similarity judgements than existing metrics, and can be used for downstream applications such as image retrieval. ## 🕰️ Coming soon -* JND Dataset release +* ✅ JND Dataset release +* ✅ Compatibility with the most recent version of PEFT * Distilled DreamSim models (i.e. smaller models distilled from the main ensemble) * DreamSim variants trained for higher resolutions -* Compatibility with the most recent version of PEFT -## 🚀 Updates +## 🚀 Newest Updates +**X/XX/24:** Released new versions of the ensemble and single-branch DreamSim models compatible with `peft>=0.2.0`. -**7/14/23:** Released three variants of DreamSim that each only use one finetuned ViT model instead of the full ensemble. These single-branch models provide a ~3x speedup over the full ensemble. - - -Here's how they compare to the full ensemble on NIGHTS (2AFC agreement): -* **Ensemble:** 96.2% -* **OpenCLIP-ViTB/32:** 95.5% -* **DINO-ViTB/16:** 94.6% -* **CLIP-ViTB/32:** 93.9% +Here's how they perform on the NIGHTS validation set: +* **Ensemble:** 96.9% +* **OpenCLIP-ViTB/32:** 95.6% +* **DINO-ViTB/16:** 95.7% +* **CLIP-ViTB/32:** 95.0% ## Table of Contents * [Requirements](#requirements) * [Setup](#setup) * [Usage](#usage) * [Quickstart](#quickstart-perceptual-similarity-metric) - * [Single-branch models](#new-single-branch-models) + * [Single-branch models](#single-branch-models) * [Feature extraction](#feature-extraction) * [Image retrieval](#image-retrieval) * [Perceptual loss function](#perceptual-loss-function) @@ -96,10 +94,10 @@ distance = model(img1, img2) # The model takes an RGB image from [0, 1], size ba To run on example images, run `demo.py`. The script should produce distances (0.424, 0.34). -### (new!) Single-branch models -By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. +### Single-branch models +By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. **The single-branch models provide a ~3x speedup over the ensemble.** -To load a single-branch model, use the `dreamsim_type` argument. For example: +The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. To load a single-branch model, use the `dreamsim_type` argument. For example: ``` dreamsim_dino_model, preprocess = dreamsim(pretrained=True, dreamsim_type="dino_vitb16") ``` @@ -153,7 +151,15 @@ DreamSim is trained by fine-tuning on the NIGHTS dataset. For details on the dat Run `./dataset/download_dataset.sh` to download and unzip the NIGHTS dataset into `./dataset/nights`. The unzipped dataset size is 58 GB. -**(new!) Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo:** +Having trouble with the large file sizes? Run `./dataset/download_chunked_dataset.sh` to download the NIGHTS dataset split into 200 smaller zip files. The output of this script is identical to `download_dataset.sh`. + +### (new!) Download the entire 100k pre-filtered NIGHTS dataset +We only use the 20k unanimous triplets for training and evaluation, but release all 100k triplets (many with few and/or split votes) for research purposes. Run `./dataset/download_unfiltered_dataset.sh` to download and unzip this unfiltered version of the NIGHTS dataset into `./dataset/nights_unfiltered`. The unzipped dataset size is 289 GB. + +### (new!) Download the JND data +Download the just-noticeable difference (JND) votes by running `./dataset/download_jnd_dataset.sh`. The CSV will be downloaded to `./dataset/jnd_votes.csv`. Check out the [Colab](https://colab.research.google.com/drive/1taEOMzFE9g81D9AwH27Uhy2U82tQGAVI?usp=sharing) for an example of loading a JND trial. + +### Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo: [![FiftyOne](https://img.shields.io/badge/FiftyOne-blue.svg?logo=)](https://try.fiftyone.ai/datasets/nights/samples) ## Experiments diff --git a/configs/eval_baseline.yaml b/configs/eval_baseline.yaml deleted file mode 100644 index ed202ba..0000000 --- a/configs/eval_baseline.yaml +++ /dev/null @@ -1,8 +0,0 @@ -seed: 1234 -baseline_model: dreamsim -baseline_feat_type: cls,embedding,embedding -baseline_stride: 16,16,16 -baseline_output_path: "outputs" -nights_root: ./dataset/nights -num_workers: 10 -batch_size: 16 \ No newline at end of file diff --git a/configs/eval_checkpoint.yaml b/configs/eval_checkpoint.yaml deleted file mode 100644 index ed5030a..0000000 --- a/configs/eval_checkpoint.yaml +++ /dev/null @@ -1,6 +0,0 @@ -seed: 1234 -eval_root: "output/experiment_dir/lightning_logs/version_0" -checkpoint_epoch: 7 -nights_root: ./dataset/nights -num_workers: 10 -batch_size: 16 \ No newline at end of file diff --git a/configs/eval_ensemble.yaml b/configs/eval_ensemble.yaml new file mode 100644 index 0000000..beac463 --- /dev/null +++ b/configs/eval_ensemble.yaml @@ -0,0 +1,17 @@ +tag: "open_clip" + +eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/" +eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml" +load_dir: "./models" + +baseline_model: "dino_vitb16,clip_vitb16,open_clip_vitb16" +baseline_feat_type: "cls,embedding,embedding" +baseline_stride: "16,16,16" + +nights_root: "./data/nights" +bapps_root: "./data/2afc/val" +things_root: "./data/things/things_src_images" +things_file: "./data/things/things_valset.txt" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/configs/eval_single_clip.yaml b/configs/eval_single_clip.yaml new file mode 100644 index 0000000..78ae4a1 --- /dev/null +++ b/configs/eval_single_clip.yaml @@ -0,0 +1,17 @@ +tag: "clip" + +eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/" +eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml" +load_dir: "./models" + +baseline_model: "clip_vitb32" +baseline_feat_type: "cls" +baseline_stride: "32" + +nights_root: "./data/nights" +bapps_root: "./data/2afc/val" +things_root: "./data/things/things_src_images" +things_file: "./data/things/things_valset.txt" + +batch_size: 256 +num_workers: 10 \ No newline at end of file diff --git a/configs/train_ensemble_model_lora.yaml b/configs/train_ensemble_model_lora.yaml index f74eea9..aad34be 100644 --- a/configs/train_ensemble_model_lora.yaml +++ b/configs/train_ensemble_model_lora.yaml @@ -17,5 +17,5 @@ epochs: 6 margin: 0.05 lora_r: 16 -lora_alpha: 0.5 +lora_alpha: 8 lora_dropout: 0.3 \ No newline at end of file diff --git a/configs/train_single_model_lora.yaml b/configs/train_single_model_lora.yaml index 5be8cc3..05b0888 100644 --- a/configs/train_single_model_lora.yaml +++ b/configs/train_single_model_lora.yaml @@ -1,9 +1,9 @@ seed: 1234 tag: lora_single -log_dir: ./output +log_dir: ./output/new_backbones -model_type: 'mae_vitb16' -feat_type: 'last_layer' +model_type: 'dino_vitb16' +feat_type: 'cls' stride: '16' use_lora: True @@ -17,5 +17,5 @@ epochs: 8 margin: 0.05 lora_r: 16 -lora_alpha: 0.5 -lora_dropout: 0.3 \ No newline at end of file +lora_alpha: 32 +lora_dropout: 0.2 diff --git a/dataset/dataset.py b/dataset/dataset.py index c257902..efcf857 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -18,8 +18,8 @@ def __init__(self, root_dir: str, split: str = "train", load_size: int = 224, self.load_size = load_size self.interpolation = interpolation self.preprocess_fn = get_preprocess_fn(preprocess, self.load_size, self.interpolation) - - if self.split == "train" or self.split == "val": + + if self.split == "train" or self.split == "val" or self.split == "test": self.csv = self.csv[self.csv["split"] == split] elif split == 'test_imagenet': self.csv = self.csv[self.csv['split'] == 'test'] diff --git a/dataset/download_chunked_dataset.sh b/dataset/download_chunked_dataset.sh new file mode 100644 index 0000000..197feab --- /dev/null +++ b/dataset/download_chunked_dataset.sh @@ -0,0 +1,14 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset +for i in $(seq -f "%03g" 0 99); do + wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip + unzip -q $i.zip -d ref/ && rm $i.zip + wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip + unzip -q $i.zip -d distort/ && rm $i.zip +done diff --git a/dataset/download_dataset.sh b/dataset/download_dataset.sh index 6562f1d..1c8c2db 100644 --- a/dataset/download_dataset.sh +++ b/dataset/download_dataset.sh @@ -1,6 +1,6 @@ #!/bin/bash -mkdir -p ./dataset -cd dataset +# mkdir -p ./dataset +cd /home/fus/data/ # Download NIGHTS dataset wget -O nights.zip https://data.csail.mit.edu/nights/nights.zip diff --git a/dataset/download_jnd_dataset.sh b/dataset/download_jnd_dataset.sh new file mode 100644 index 0000000..04db9cc --- /dev/null +++ b/dataset/download_jnd_dataset.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +# Download JND data for NIGHTS dataset +wget https://data.csail.mit.edu/nights/jnd_votes.csv diff --git a/dataset/download_unfiltered_dataset.sh b/dataset/download_unfiltered_dataset.sh new file mode 100644 index 0000000..8ce87d5 --- /dev/null +++ b/dataset/download_unfiltered_dataset.sh @@ -0,0 +1,20 @@ +#!/bin/bash +mkdir -p ./dataset_unfiltered +cd dataset_unfiltered + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset + +# store those in a list and loop through wget and unzip and rm +for i in {0..99..25} +do + wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip + unzip -q ref_${i}_$(($i+24)).zip -d ref + rm ref_*.zip + + wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip + unzip -q distort_${i}_$(($i+24)).zip -d distort + rm distort_*.zip +done diff --git a/dataset/nights b/dataset/nights new file mode 120000 index 0000000..567f57c --- /dev/null +++ b/dataset/nights @@ -0,0 +1 @@ +/home/fus/data/nights \ No newline at end of file diff --git a/dreamsim/config.py b/dreamsim/config.py index a143a9e..c017975 100644 --- a/dreamsim/config.py +++ b/dreamsim/config.py @@ -25,16 +25,10 @@ "lora": True } }, - "lora_config": { - "r": 16, - "lora_alpha": 0.5, - "lora_dropout": 0.3, - "bias": "none", - "target_modules": ['qkv'] - }, "img_size": 224 } +# UPDATE dreamsim_weights = { "ensemble": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/dreamsim_checkpoint.zip", "dino_vitb16": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.2/dreamsim_dino_vitb16_checkpoint.zip", diff --git a/dreamsim/feature_extraction/extractor.py b/dreamsim/feature_extraction/extractor.py index c096782..15900f9 100644 --- a/dreamsim/feature_extraction/extractor.py +++ b/dreamsim/feature_extraction/extractor.py @@ -7,6 +7,7 @@ import os from .load_clip_as_dino import load_clip_as_dino from .load_open_clip_as_dino import load_open_clip_as_dino +from .load_synclr_as_dino import load_synclr_as_dino from .vision_transformer import DINOHead from .load_mae_as_vit import load_mae_as_vit @@ -15,7 +16,7 @@ """ -class ViTExtractor: +class ViTExtractor(nn.Module): """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT. We use the following notation in the documentation of the module's methods: @@ -37,6 +38,7 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, load_dir: st :param stride: stride of first convolution layer. small stride -> higher resolution. :param load_dir: location of pretrained ViT checkpoints. """ + super(ViTExtractor, self).__init__() self.model_type = model_type self.device = device self.model = ViTExtractor.create_model(model_type, load_dir) @@ -62,7 +64,10 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module: :param load_dir: location of pretrained ViT checkpoints. :return: the model """ - if 'dino' in model_type: + if 'dinov2' in model_type: + torch.hub.set_dir(load_dir) + model = torch.hub.load('facebookresearch/dinov2', model_type) + elif 'dino' in model_type: torch.hub.set_dir(load_dir) model = torch.hub.load('facebookresearch/dino:main', model_type) if model_type == 'dino_vitb16': @@ -96,6 +101,11 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module: raise ValueError(f"Model {model_type} not supported") elif 'mae' in model_type: model = load_mae_as_vit(model_type, load_dir) + elif 'synclr' in model_type: + if model_type == 'synclr_vitb16': + model = load_synclr_as_dino(16, load_dir) + else: + raise ValueError(f"Model {model_type} not supported") else: raise ValueError(f"Model {model_type} not supported") return model diff --git a/dreamsim/feature_extraction/vision_transformer.py b/dreamsim/feature_extraction/vision_transformer.py index 36143e9..b2fe093 100644 --- a/dreamsim/feature_extraction/vision_transformer.py +++ b/dreamsim/feature_extraction/vision_transformer.py @@ -16,7 +16,7 @@ # This version was taken from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # On Jan 24th, 2022 # Git hash of last commit: 4b96393c4c877d127cff9f077468e4a1cc2b5e2d - + """ Mostly copy-paste from timm library. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import torch.nn as nn trunc_normal_ = lambda *args, **kwargs: None - + def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: @@ -43,6 +43,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ + def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob @@ -121,6 +122,7 @@ def forward(self, x, return_attention=False): class PatchEmbed(nn.Module): """ Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) @@ -138,6 +140,7 @@ def forward(self, x): class VisionTransformer(nn.Module): """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): @@ -239,7 +242,8 @@ def get_intermediate_layers(self, x, n=1): class DINOHead(nn.Module): - def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, + bottleneck_dim=256): super().__init__() nlayers = max(nlayers, 1) if nlayers == 1: @@ -273,7 +277,7 @@ def forward(self, x): x = nn.functional.normalize(x, dim=-1, p=2) x = self.last_layer(x) return x - + def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer( diff --git a/dreamsim/feature_extraction/vit_wrapper.py b/dreamsim/feature_extraction/vit_wrapper.py deleted file mode 100644 index 793ebe6..0000000 --- a/dreamsim/feature_extraction/vit_wrapper.py +++ /dev/null @@ -1,19 +0,0 @@ -from transformers import PretrainedConfig -from transformers import PreTrainedModel - - -class ViTConfig(PretrainedConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - -class ViTModel(PreTrainedModel): - config_class = ViTConfig - - def __init__(self, model, config): - super().__init__(config) - self.model = model - self.blocks = model.blocks - - def forward(self, x): - return self.model(x) diff --git a/dreamsim/model.py b/dreamsim/model.py index b84f872..b79b8bb 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -1,15 +1,27 @@ +import json + import torch import torch.nn.functional as F +from torch import nn from torchvision import transforms import os + +from util.constants import * from .feature_extraction.extractor import ViTExtractor import yaml +import peft from peft import PeftModel, LoraConfig, get_peft_model -from .feature_extraction.vit_wrapper import ViTConfig, ViTModel from .config import dreamsim_args, dreamsim_weights import os import zipfile +from packaging import version + +peft_version = version.parse(peft.__version__) +min_version = version.parse('0.2.0') +if peft_version < min_version: + raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. " + "Please update peft with 'pip install --upgrade peft'.") class PerceptualModel(torch.nn.Module): def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1, @@ -41,7 +53,7 @@ def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stri self.stride_list = [int(x) for x in stride.split(',')] self._validate_args() self.extract_feats_list = [] - self.extractor_list = [] + self.extractor_list = nn.ModuleList() self.embed_size = 0 self.hidden_size = hidden_size self.baseline = baseline @@ -122,23 +134,23 @@ def _preprocess(self, img, model_type): def _get_mean(self, model_type): if "dino" in model_type: - return (0.485, 0.456, 0.406) + return IMAGENET_DEFAULT_MEAN elif "open_clip" in model_type: - return (0.48145466, 0.4578275, 0.40821073) + return OPENAI_CLIP_MEAN elif "clip" in model_type: - return (0.48145466, 0.4578275, 0.40821073) + return OPENAI_CLIP_MEAN elif "mae" in model_type: - return (0.485, 0.456, 0.406) + return IMAGENET_DEFAULT_MEAN def _get_std(self, model_type): if "dino" in model_type: - return (0.229, 0.224, 0.225) + return IMAGENET_DEFAULT_STD elif "open_clip" in model_type: - return (0.26862954, 0.26130258, 0.27577711) + return OPENAI_CLIP_STD elif "clip" in model_type: - return (0.26862954, 0.26130258, 0.27577711) + return OPENAI_CLIP_STD elif "mae" in model_type: - return (0.229, 0.224, 0.225) + return IMAGENET_DEFAULT_STD class MLP(torch.nn.Module): @@ -163,9 +175,8 @@ def download_weights(cache_dir, dreamsim_type): """ dreamsim_required_ckpts = { - "ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora", - "open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora", - "clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"], + "ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar", + "clip_vitb16_pretrain.pth.tar", "ensemble_lora"], "dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"], "open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"], "clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"] @@ -213,17 +224,18 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",") ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir, normalize_embeds=normalize_embeds) - for extractor in ours_model.extractor_list: - lora_config = LoraConfig(**dreamsim_args['lora_config']) - model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config) - extractor.model = model - tag = "" if dreamsim_type == "ensemble" else "single_" + tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_" + + with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f: + adapter_config = json.load(f) + lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules'] + lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys}) + ours_model = get_peft_model(ours_model, lora_config) + if pretrained: - for extractor, name in zip(ours_model.extractor_list, model_list): - load_dir = os.path.join(cache_dir, f"{name}_{tag}lora") - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device) - extractor.model.eval().requires_grad_(False) + load_dir = os.path.join(cache_dir, f"{tag}lora") + ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device) ours_model.eval().requires_grad_(False) @@ -262,4 +274,3 @@ def normalize_embedding(embed): 'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768} } - diff --git a/evaluation/eval_datasets.py b/evaluation/eval_datasets.py new file mode 100644 index 0000000..1c410a3 --- /dev/null +++ b/evaluation/eval_datasets.py @@ -0,0 +1,84 @@ +import os +import glob +import numpy as np +from PIL import Image +from torch.utils.data import Dataset +from util.utils import get_preprocess_fn +from torchvision import transforms + +IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"] + +class ThingsDataset(Dataset): + """ + txt_file is expected to be the things_valset.txt list of triplets from the THINGS dataset. + root_dir is expected to be a directory of THINGS images. + """ + def __init__(self, root_dir: str, txt_file: str, preprocess: str, load_size: int = 224, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + with open(txt_file, "r") as f: + self.txt = f.readlines() + self.dataset_root = root_dir + self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) + + def __len__(self): + return len(self.txt) + + def __getitem__(self, idx): + im_1, im_2, im_3 = self.txt[idx].split() + + im_1 = Image.open(os.path.join(self.dataset_root, f"{im_1}.png")) + im_2 = Image.open(os.path.join(self.dataset_root, f"{im_2}.png")) + im_3 = Image.open(os.path.join(self.dataset_root, f"{im_3}.png")) + + im_1 = self.preprocess_fn(im_1) + im_2 = self.preprocess_fn(im_2) + im_3 = self.preprocess_fn(im_3) + + return im_1, im_2, im_3 + + +class BAPPSDataset(Dataset): + def __init__(self, root_dir: str, preprocess: str, load_size: int = 224, + interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC): + """ + root_dir is expected to be the default validation folder of the BAPPS dataset. + """ + data_types = ["cnn", "traditional", "color", "deblur", "superres", "frameinterp"] + + self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation) + self.judge_paths = [] + self.p0_paths = [] + self.p1_paths = [] + self.ref_paths = [] + + for dt in data_types: + list_dir = os.path.join(os.path.join(root_dir, dt), "judge") + for fname in os.scandir(list_dir): + self.judge_paths.append(os.path.join(list_dir, fname.name)) + self.p0_paths.append(os.path.join(os.path.join(os.path.join(root_dir, dt), "p0"), fname.name.split(".")[0] + ".png")) + self.p1_paths.append( + os.path.join(os.path.join(os.path.join(root_dir, dt), "p1"), fname.name.split(".")[0] + ".png")) + self.ref_paths.append( + os.path.join(os.path.join(os.path.join(root_dir, dt), "ref"), fname.name.split(".")[0] + ".png")) + + def __len__(self): + return len(self.judge_paths) + + def __getitem__(self, idx): + judge = np.load(self.judge_paths[idx]) + im_left = self.preprocess_fn(Image.open(self.p0_paths[idx])) + im_right = self.preprocess_fn(Image.open(self.p1_paths[idx])) + im_ref = self.preprocess_fn(Image.open(self.ref_paths[idx])) + return im_ref, im_left, im_right, judge + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + +def get_paths(path): + all_paths = [] + for ext in IMAGE_EXTENSIONS: + all_paths += glob.glob(os.path.join(path, f"**.{ext}")) + return all_paths diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py new file mode 100644 index 0000000..fb5f50e --- /dev/null +++ b/evaluation/eval_percep.py @@ -0,0 +1,144 @@ +import os +import yaml +import logging +import json +import torch +import configargparse +from torch.utils.data import DataLoader +from pytorch_lightning import seed_everything +from dreamsim import PerceptualModel +from dataset.dataset import TwoAFCDataset +from training.train import LightningPerceptualModel +from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset +from evaluation.eval_datasets import ThingsDataset, BAPPSDataset + +log = logging.getLogger("lightning.pytorch") +log.propagate = False +log.setLevel(logging.ERROR) + +def parse_args(): + parser = configargparse.ArgumentParser() + parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') + + ## Run options + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--output', type=str, default="./eval_outputs", help="Dir to save results in.") + parser.add_argument('--tag', type=str, default="", help="Exp name for saving results") + + + ## Checkpoint evaluation options + parser.add_argument('--eval_checkpoint', type=str, help="Path to a checkpoint root.") + parser.add_argument('--eval_checkpoint_cfg', type=str, help="Path to checkpoint config.") + parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints.') + + ## Baseline evaluation options + parser.add_argument('--baseline_model', type=str, default=None) + parser.add_argument('--baseline_feat_type', type=str, default=None) + parser.add_argument('--baseline_stride', type=str, default=None) + + ## Dataset options + parser.add_argument('--nights_root', type=str, default=None, help='path to nights dataset.') + parser.add_argument('--bapps_root', type=str, default=None, help='path to bapps images.') + parser.add_argument('--things_root', type=str, default=None, help='path to things images.') + parser.add_argument('--things_file', type=str, default=None, help='path to things info file.') + parser.add_argument('--df2_root', type=str, default=None, help='path to df2 root.') + parser.add_argument('--df2_gt', type=str, default=None, help='path to df2 ground truth json.') + parser.add_argument('--num_workers', type=int, default=16) + parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.') + + return parser.parse_args() + +def load_dreamsim_model(args, device="cuda"): + with open(os.path.join(args.eval_checkpoint_cfg), "r") as f: + cfg = yaml.load(f, Loader=yaml.Loader) + + model_cfg = vars(cfg) + model_cfg['load_dir'] = args.load_dir + model = LightningPerceptualModel(**model_cfg) + model.load_lora_weights(args.eval_checkpoint) + model = model.perceptual_model.to(device) + preprocess = "DEFAULT" + return model, preprocess + + +def load_baseline_model(args, device="cuda"): + model = PerceptualModel(model_type=args.baseline_model, feat_type=args.baseline_feat_type, stride=args.baseline_stride, baseline=True, load_dir=args.load_dir) + model = model.to(device) + preprocess = "DEFAULT" + return model, preprocess + +def eval_nights(model, preprocess, device, args): + eval_results = {} + val_dataset = TwoAFCDataset(root_dir=args.nights_root, split="val", preprocess=preprocess) + test_dataset_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_imagenet", preprocess=preprocess) + test_dataset_no_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_no_imagenet", preprocess=preprocess) + total_length = len(test_dataset_no_imagenet) + len(test_dataset_imagenet) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + test_imagenet_loader = DataLoader(test_dataset_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + test_no_imagenet_loader = DataLoader(test_dataset_no_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + + val_score = score_nights_dataset(model, val_loader, device) + imagenet_score = score_nights_dataset(model, test_imagenet_loader, device) + no_imagenet_score = score_nights_dataset(model, test_no_imagenet_loader, device) + + eval_results['nights_val'] = val_score.item() + eval_results['nights_imagenet'] = imagenet_score.item() + eval_results['nights_no_imagenet'] = no_imagenet_score.item() + eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) + + no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length + logging.info(f"NIGHTS (validation 2AFC): {str(eval_results['nights_val'])}") + logging.info(f"NIGHTS (imagenet 2AFC): {str(eval_results['nights_imagenet'])}") + logging.info(f"NIGHTS (no-imagenet 2AFC): {str(eval_results['nights_no_imagenet'])}") + logging.info(f"NIGHTS (total 2AFC): {str(eval_results['nights_total'])}") + return eval_results + +def eval_bapps(model, preprocess, device, args): + test_dataset_bapps = BAPPSDataset(root_dir=args.bapps_root, preprocess=preprocess) + test_loader_bapps = DataLoader(test_dataset_bapps, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + bapps_score = score_bapps_dataset(model, test_loader_bapps, device) + logging.info(f"BAPPS (total 2AFC): {str(bapps_score)}") + return {"bapps_total": bapps_score} + +def eval_things(model, preprocess, device, args): + test_dataset_things = ThingsDataset(root_dir=args.things_root, txt_file=args.things_file, preprocess=preprocess) + test_loader_things = DataLoader(test_dataset_things, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) + things_score = score_things_dataset(model, test_loader_things, device) + logging.info(f"THINGS (total 2AFC): {things_score}") + return {"things_total": things_score} + +def full_eval(eval_model, preprocess, device, args): + results = {} + if args.nights_root is not None: + results['ckpt_nights'] = eval_nights(eval_model, preprocess, device, args) + if args.bapps_root is not None: + results['ckpt_bapps'] = bapps_results = eval_bapps(eval_model, preprocess, device, args) + if args.things_root is not None: + results['ckpt_things'] = eval_things(eval_model, preprocess, device, args) + return results + +def run(args, device): + logging.basicConfig(filename=os.path.join(args.eval_checkpoint, 'eval.log'), level=logging.INFO, force=True) + seed_everything(args.seed) + g = torch.Generator() + g.manual_seed(args.seed) + + os.makedirs(args.output, exist_ok=True) + + full_results = {} + if args.eval_checkpoint is not None: + eval_model, preprocess = load_dreamsim_model(args) + full_results['ckpt'] = full_eval(eval_model, preprocess, device, args) + if args.baseline_model is not None: + baseline_model, baseline_preprocess = load_baseline_model(args) + full_results['baseline'] = full_eval(baseline_model, baseline_preprocess, device, args) + + tag = args.tag + "_" if len(args.tag) > 0 else "" + with open(os.path.join(args.output, f"{tag}eval_results.json"), "w") as f: + json.dump(full_results, f) + + +if __name__ == '__main__': + args = parse_args() + device = "cuda" if torch.cuda.is_available() else "cpu" + run(args, device) + \ No newline at end of file diff --git a/evaluation/score.py b/evaluation/score.py new file mode 100644 index 0000000..bc0b8b9 --- /dev/null +++ b/evaluation/score.py @@ -0,0 +1,81 @@ +import torch +import os +from tqdm import tqdm +import logging +import numpy as np +import json +import torch.nn.functional as F + +def score_nights_dataset(model, test_loader, device): + logging.info("Evaluating NIGHTS dataset.") + d0s = [] + d1s = [] + targets = [] + with torch.no_grad(): + for i, (img_ref, img_left, img_right, target, idx) in tqdm(enumerate(test_loader), total=len(test_loader)): + img_ref, img_left, img_right, target = img_ref.to(device), img_left.to(device), \ + img_right.to(device), target.to(device) + + dist_0 = model(img_ref, img_left) + dist_1 = model(img_ref, img_right) + + if len(dist_0.shape) < 1: + dist_0 = dist_0.unsqueeze(0) + dist_1 = dist_1.unsqueeze(0) + dist_0 = dist_0.unsqueeze(1) + dist_1 = dist_1.unsqueeze(1) + target = target.unsqueeze(1) + + d0s.append(dist_0) + d1s.append(dist_1) + targets.append(target) + + d0s = torch.cat(d0s, dim=0) + d1s = torch.cat(d1s, dim=0) + targets = torch.cat(targets, dim=0) + scores = (d0s < d1s) * (1.0 - targets) + (d1s < d0s) * targets + (d1s == d0s) * 0.5 + twoafc_score = torch.mean(scores, dim=0) + print(f"2AFC score: {str(twoafc_score)}") + return twoafc_score + +def score_things_dataset(model, test_loader, device): + logging.info("Evaluating Things dataset.") + count = 0 + total = 0 + with torch.no_grad(): + for i, (img_1, img_2, img_3) in tqdm(enumerate(test_loader), total=len(test_loader)): + img_1, img_2, img_3 = img_1.to(device), img_2.to(device), img_3.to(device) + + dist_1_2 = model(img_1, img_2) + dist_1_3 = model(img_1, img_3) + dist_2_3 = model(img_2, img_3) + + le_1_3 = torch.le(dist_1_2, dist_1_3) + le_2_3 = torch.le(dist_1_2, dist_2_3) + + count += sum(torch.logical_and(le_1_3, le_2_3)) + total += len(torch.logical_and(le_1_3, le_2_3)) + count = count.detach().cpu().numpy() + accs = count / total + return accs + +def score_bapps_dataset(model, test_loader, device): + logging.info("Evaluating BAPPS dataset.") + + d0s = [] + d1s = [] + ps = [] + with torch.no_grad(): + for i, (im_ref, im_left, im_right, p) in tqdm(enumerate(test_loader), total=len(test_loader)): + im_ref, im_left, im_right, p = im_ref.to(device), im_left.to(device), im_right.to(device), p.to(device) + d0 = model(im_ref, im_left) + d1 = model(im_ref, im_right) + d0s.append(d0) + d1s.append(d1) + ps.append(p.squeeze()) + d0s = torch.cat(d0s, dim=0) + d1s = torch.cat(d1s, dim=0) + ps = torch.cat(ps, dim=0) + scores = (d0s < d1s) * (1.0 - ps) + (d1s < d0s) * ps + (d1s == d0s) * 0.5 + final_score = torch.mean(scores, dim=0) + return final_score diff --git a/requirements.txt b/requirements.txt index b3946bc..7fa525a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ lpips numpy open-clip-torch pandas -peft==0.1.0 +peft>=0.2.0 Pillow pytorch-lightning PyYAML diff --git a/setup.py b/setup.py index 8485f1a..e0fe0e9 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setuptools.setup( name="dreamsim", - version="0.1.3", + version="0.2.0", description="DreamSim similarity metric", long_description=long_description, long_description_content_type="text/markdown", @@ -16,7 +16,7 @@ install_requires=[ "numpy", "open-clip-torch", - "peft==0.1.0", + "peft", "Pillow", "torch", "timm", diff --git a/training/download_models.sh b/training/download_models.sh index a0e5d13..f7d951d 100644 --- a/training/download_models.sh +++ b/training/download_models.sh @@ -2,6 +2,7 @@ mkdir -p ./models cd models +## UDPATE wget -O dreamsim_checkpoint.zip https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/dreamsim_checkpoint.zip wget -O clip_vitb32_pretrain.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/clip_vitb32_pretrain.pth.tar wget -O clipl14_as_dino_vitl.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/clip_vitl14_pretrain.pth.tar diff --git a/training/evaluate.py b/training/evaluate.py deleted file mode 100644 index b58502b..0000000 --- a/training/evaluate.py +++ /dev/null @@ -1,211 +0,0 @@ -from pytorch_lightning import seed_everything -import torch -from dataset.dataset import TwoAFCDataset -from util.utils import get_preprocess -from torch.utils.data import DataLoader -import os -import yaml -import logging -from train import LightningPerceptualModel -from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio -from DISTS_pytorch import DISTS -from dreamsim import PerceptualModel -from tqdm import tqdm -import pickle -import configargparse -from dreamsim import dreamsim - -log = logging.getLogger("lightning.pytorch") -log.propagate = False -log.setLevel(logging.ERROR) - - -def parse_args(): - parser = configargparse.ArgumentParser() - parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path') - - ## Run options - parser.add_argument('--seed', type=int, default=1234) - - ## Checkpoint evaluation options - parser.add_argument('--eval_root', type=str, help="Path to experiment directory containing a checkpoint to " - "evaluate and the experiment config.yaml.") - parser.add_argument('--checkpoint_epoch', type=int, help='Epoch number of the checkpoint to evaluate.') - parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints.') - - ## Baseline evaluation options - parser.add_argument('--baseline_model', type=str, - help='Which ViT model to evaluate. To evaluate an ensemble of models, pass a comma-separated' - 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' - 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') - parser.add_argument('--baseline_feat_type', type=str, - help='What type of feature to extract from the model. If evaluating an ensemble, pass a ' - 'comma-separated list of features (same length as model_type). Accepted feature types: ' - '[cls, embedding, last_layer].') - parser.add_argument('--baseline_stride', type=str, - help='Stride of first convolution layer the model (should match patch size). If finetuning' - 'an ensemble, pass a comma-separated list (same length as model_type).') - parser.add_argument('--baseline_output_path', type=str, help='Path to save evaluation results.') - - ## Dataset options - parser.add_argument('--nights_root', type=str, default='./dataset/nights', help='path to nights dataset.') - parser.add_argument('--num_workers', type=int, default=4) - parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.') - - return parser.parse_args() - - -def score_nights_dataset(model, test_loader, device): - logging.info("Evaluating NIGHTS dataset.") - d0s = [] - d1s = [] - targets = [] - with torch.no_grad(): - for i, (img_ref, img_left, img_right, target, idx) in tqdm(enumerate(test_loader), total=len(test_loader)): - img_ref, img_left, img_right, target = img_ref.to(device), img_left.to(device), \ - img_right.to(device), target.to(device) - - dist_0 = model(img_ref, img_left) - dist_1 = model(img_ref, img_right) - - if len(dist_0.shape) < 1: - dist_0 = dist_0.unsqueeze(0) - dist_1 = dist_1.unsqueeze(0) - dist_0 = dist_0.unsqueeze(1) - dist_1 = dist_1.unsqueeze(1) - target = target.unsqueeze(1) - - d0s.append(dist_0) - d1s.append(dist_1) - targets.append(target) - - d0s = torch.cat(d0s, dim=0) - d1s = torch.cat(d1s, dim=0) - targets = torch.cat(targets, dim=0) - scores = (d0s < d1s) * (1.0 - targets) + (d1s < d0s) * targets + (d1s == d0s) * 0.5 - twoafc_score = torch.mean(scores, dim=0) - logging.info(f"2AFC score: {str(twoafc_score)}") - return twoafc_score - - -def get_baseline_model(baseline_model, feat_type: str = "cls", stride: str = "16", - load_dir: str = "./models", device: str = "cuda"): - if baseline_model == 'psnr': - def psnr_func(im1, im2): - return -peak_signal_noise_ratio(im1, im2, data_range=1.0, dim=(1, 2, 3), reduction='none') - return psnr_func - - elif baseline_model == 'ssim': - def ssim_func(im1, im2): - return -structural_similarity_index_measure(im1, im2, data_range=1.0, reduction='none') - return ssim_func - - elif baseline_model == 'dists': - dists_metric = DISTS().to(device) - - def dists_func(im1, im2): - distances = dists_metric(im1, im2) - return distances - return dists_func - - elif baseline_model == 'lpips': - import lpips - lpips_fn = lpips.LPIPS(net='alex').eval() - - def lpips_func(im1, im2): - distances = lpips_fn(im1.to(device), im2.to(device)).reshape(-1) - return distances - return lpips_func - - elif 'clip' in baseline_model or 'dino' in baseline_model or "open_clip" in baseline_model or "mae" in baseline_model: - perceptual_model = PerceptualModel(feat_type=feat_type, model_type=baseline_model, stride=stride, - baseline=True, load_dir=load_dir, device=device) - for extractor in perceptual_model.extractor_list: - extractor.model.eval() - return perceptual_model - - elif baseline_model == "dreamsim": - dreamsim_model, preprocess = dreamsim(pretrained=True, cache_dir=load_dir) - return dreamsim_model - - else: - raise ValueError(f"Model {baseline_model} not supported.") - - -def run(args, device): - seed_everything(args.seed) - g = torch.Generator() - g.manual_seed(args.seed) - - if args.checkpoint_epoch is not None: - if args.baseline_model is not None: - raise ValueError("Cannot run baseline evaluation with a checkpoint.") - args_path = os.path.join(args.eval_root, "config.yaml") - logging.basicConfig(filename=os.path.join(args.eval_root, 'eval.log'), level=logging.INFO, force=True) - with open(args_path) as f: - logging.info(f"Loading checkpoint arguments from {args_path}") - eval_args = yaml.load(f, Loader=yaml.Loader) - - eval_args.load_dir = args.load_dir - model = LightningPerceptualModel(**vars(eval_args), device=device) - logging.info(f"Loading checkpoint from {args.eval_root} using epoch {args.checkpoint_epoch}") - - checkpoint_root = os.path.join(args.eval_root, "checkpoints") - checkpoint_path = os.path.join(checkpoint_root, f"epoch={(args.checkpoint_epoch):02d}.ckpt") - sd = torch.load(checkpoint_path) - model.load_state_dict(sd["state_dict"]) - if eval_args.use_lora: - model.load_lora_weights(checkpoint_root=checkpoint_root, epoch_load=args.checkpoint_epoch) - model = model.perceptual_model - for extractor in model.extractor_list: - extractor.model.eval() - model = model.to(device) - output_path = checkpoint_root - model_type = eval_args.model_type - - elif args.baseline_model is not None: - if not os.path.exists(args.baseline_output_path): - os.mkdir(args.baseline_output_path) - logging.basicConfig(filename=os.path.join(args.baseline_output_path, 'eval.log'), level=logging.INFO, - force=True) - model = get_baseline_model(args.baseline_model, args.baseline_feat_type, args.baseline_stride, args.load_dir, - device) - output_path = args.baseline_output_path - model_type = args.baseline_model - - else: - raise ValueError("Must specify one of checkpoint_path or baseline_model") - - eval_results = {} - - test_dataset_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_imagenet", - preprocess=get_preprocess(model_type)) - test_dataset_no_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_no_imagenet", - preprocess=get_preprocess(model_type)) - total_length = len(test_dataset_no_imagenet) + len(test_dataset_imagenet) - test_imagenet_loader = DataLoader(test_dataset_imagenet, batch_size=args.batch_size, - num_workers=args.num_workers, shuffle=False) - test_no_imagenet_loader = DataLoader(test_dataset_no_imagenet, batch_size=args.batch_size, - num_workers=args.num_workers, shuffle=False) - - imagenet_score = score_nights_dataset(model, test_imagenet_loader, device) - no_imagenet_score = score_nights_dataset(model, test_no_imagenet_loader, device) - - eval_results['nights_imagenet'] = imagenet_score.item() - eval_results['nights_no_imagenet'] = no_imagenet_score.item() - eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) + - no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length - logging.info(f"Combined 2AFC score: {str(eval_results['nights_total'])}") - - logging.info(f"Saving to {os.path.join(output_path, 'eval_results.pkl')}") - with open(os.path.join(output_path, 'eval_results.pkl'), "wb") as f: - pickle.dump(eval_results, f) - - print("Done :)") - - -if __name__ == "__main__": - args = parse_args() - device = "cuda" if torch.cuda.is_available() else "cpu" - run(args, device) \ No newline at end of file diff --git a/training/train.py b/training/train.py index 1686b3b..2658937 100644 --- a/training/train.py +++ b/training/train.py @@ -1,24 +1,18 @@ +import os +import configargparse import logging import yaml +import torch import pytorch_lightning as pl +from torch.utils.data import DataLoader +from peft import get_peft_model, LoraConfig, PeftModel from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint from util.train_utils import Mean, HingeLoss, seed_worker from util.utils import get_preprocess from dataset.dataset import TwoAFCDataset -from torch.utils.data import DataLoader -import torch -from peft import get_peft_model, LoraConfig, PeftModel from dreamsim import PerceptualModel -from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig -import os -import configargparse -from tqdm import tqdm - -log = logging.getLogger("lightning.pytorch") -log.propagate = False -log.setLevel(logging.INFO) def parse_args(): @@ -30,6 +24,9 @@ def parse_args(): parser.add_argument('--tag', type=str, default='', help='tag for experiments (ex. experiment name)') parser.add_argument('--log_dir', type=str, default="./logs", help='path to save model checkpoints and logs') parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints') + parser.add_argument('--save_mode', type=str, default="all", help='whether to save only LoRA adapter weights, ' + 'entire model, or both. Accepted ' + 'options: [adapter_only, entire_model, all]') ## Model options parser.add_argument('--model_type', type=str, default='dino_vitb16', @@ -68,10 +65,11 @@ def parse_args(): class LightningPerceptualModel(pl.LightningModule): - def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", hidden_size: int = 1, + def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stride: str = "16", + hidden_size: int = 1, lr: float = 0.0003, use_lora: bool = False, margin: float = 0.05, lora_r: int = 16, lora_alpha: float = 0.5, lora_dropout: float = 0.3, weight_decay: float = 0.0, train_data_len: int = 1, - load_dir: str = "./models", device: str = "cuda", + load_dir: str = "./models", device: str = "cuda", save_mode: str = "all", **kwargs): super().__init__() self.save_hyperparameters() @@ -88,12 +86,16 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri self.lora_alpha = lora_alpha self.lora_dropout = lora_dropout self.train_data_len = train_data_len + self.save_mode = save_mode + + self.__validate_save_mode() self.started = False self.val_metrics = {'loss': Mean().to(device), 'score': Mean().to(device)} self.__reset_val_metrics() - self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, stride=self.stride, + self.perceptual_model = PerceptualModel(feat_type=self.feat_type, model_type=self.model_type, + stride=self.stride, hidden_size=self.hidden_size, lora=self.use_lora, load_dir=load_dir, device=device) if self.use_lora: @@ -101,6 +103,11 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri else: self.__prep_linear_model() + pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters()) + pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad) + print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} ' + f'| % Trainable: {pytorch_total_trainable_params / pytorch_total_params * 100}') + self.criterion = HingeLoss(margin=self.margin, device=device) self.epoch_loss_train = 0.0 @@ -140,15 +147,12 @@ def on_train_epoch_start(self): def on_train_epoch_end(self): epoch = self.current_epoch + 1 if self.started else 0 - self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, epoch) + self.logger.experiment.add_scalar(f'train_loss/', self.epoch_loss_train / self.trainer.num_training_batches, + epoch) self.logger.experiment.add_scalar(f'train_2afc_acc/', self.train_num_correct / self.train_data_len, epoch) if self.use_lora: self.__save_lora_weights() - def on_validation_start(self): - for extractor in self.perceptual_model.extractor_list: - extractor.model.eval() - def on_validation_epoch_start(self): self.__reset_val_metrics() @@ -159,7 +163,7 @@ def on_validation_epoch_end(self): self.log(f'val_acc_ckpt', score, logger=False) self.log(f'val_loss_ckpt', loss, logger=False) - # log for tensorboard + self.logger.experiment.add_scalar(f'val_2afc_acc/', score, epoch) self.logger.experiment.add_scalar(f'val_loss/', loss, epoch) @@ -175,28 +179,31 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(params, lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay) return [optimizer] - def load_lora_weights(self, checkpoint_root, epoch_load): - for extractor in self.perceptual_model.extractor_list: - load_dir = os.path.join(checkpoint_root, - f'epoch_{epoch_load}_{extractor.model_type}') - extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(extractor.device) + def load_lora_weights(self, checkpoint_root, epoch_load=None): + if self.save_mode in {'adapter_only', 'all'}: + if epoch_load is not None: + checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}') + + logging.info(f'Loading adapter weights from {checkpoint_root}') + self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device) + else: + logging.info(f'Loading entire model from {checkpoint_root}') + sd = torch.load(os.path.join(checkpoint_root, f'epoch={epoch_load:02d}.ckpt'))['state_dict'] + self.load_state_dict(sd, strict=True) def __reset_val_metrics(self): for k, v in self.val_metrics.items(): v.reset() def __prep_lora_model(self): - for extractor in self.perceptual_model.extractor_list: - config = LoraConfig( - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout, - bias='none', - target_modules=['qkv'] - ) - extractor_model = get_peft_model(ViTModel(extractor.model, ViTConfig()), - config).to(extractor.device) - extractor.model = extractor_model + config = LoraConfig( + r=self.lora_r, + lora_alpha=self.lora_alpha, + lora_dropout=self.lora_dropout, + bias='none', + target_modules=['qkv'] + ) + self.perceptual_model = get_peft_model(self.perceptual_model, config) def __prep_linear_model(self): for extractor in self.perceptual_model.extractor_list: @@ -206,17 +213,14 @@ def __prep_linear_model(self): self.perceptual_model.mlp.requires_grad_(True) def __save_lora_weights(self): - for extractor in self.perceptual_model.extractor_list: - save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, - f'epoch_{self.trainer.current_epoch}_{extractor.model_type}') - extractor.model.save_pretrained(save_dir) - adapters_weights = torch.load(os.path.join(save_dir, 'adapter_model.bin')) - new_adapters_weights = dict() + if self.save_mode != 'entire_model': + save_dir = os.path.join(self.trainer.callbacks[-1].dirpath, f'epoch_{self.trainer.current_epoch}') + self.perceptual_model.save_pretrained(save_dir) - for k, v in adapters_weights.items(): - new_k = 'base_model.model.' + k - new_adapters_weights[new_k] = v - torch.save(new_adapters_weights, os.path.join(save_dir, 'adapter_model.bin')) + def __validate_save_mode(self): + save_options = {'adapter_only', 'entire_model', 'all'} + assert self.save_mode in save_options, f'save_mode must be one of {save_options}, got {self.save_mode}' + logging.info(f'Using save mode: {self.save_mode}') def run(args, device): @@ -241,17 +245,18 @@ def run(args, device): val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) logger = TensorBoardLogger(save_dir=exp_dir, default_hp_metric=False) + checkpointer = ModelCheckpoint(monitor='val_loss_ckpt', + save_top_k=-1, + save_last=True, + filename='{epoch:02d}', + mode='min') if args.save_mode != 'adapter_only' else None trainer = Trainer(devices=1, accelerator='gpu', log_every_n_steps=10, logger=logger, max_epochs=args.epochs, default_root_dir=exp_dir, - callbacks=ModelCheckpoint(monitor='val_loss_ckpt', - save_top_k=-1, - save_last=True, - filename='{epoch:02d}', - mode='max'), + callbacks=checkpointer, num_sanity_val_steps=0, ) checkpoint_root = os.path.join(exp_dir, 'lightning_logs', f'version_{trainer.logger.version}') @@ -276,9 +281,3 @@ def run(args, device): args = parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" run(args, device) - - - - - - diff --git a/util/constants.py b/util/constants.py new file mode 100644 index 0000000..a9c6872 --- /dev/null +++ b/util/constants.py @@ -0,0 +1,5 @@ +# use timm names from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/constants.py +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)