diff --git a/configs/distill_lora.yaml b/configs/distill_lora.yaml deleted file mode 100644 index d8e145d..0000000 --- a/configs/distill_lora.yaml +++ /dev/null @@ -1,21 +0,0 @@ -seed: 1234 -tag: lora_distill -log_dir: /home/fus/repos/dreamsim-dev/output/dev - -model_type: 'clip_vitb32' -feat_type: 'embedding' -stride: '32' -use_lora: True - -dataset_root: ./dataset/nights -num_workers: 4 - -lr: 0.0003 -weight_decay: 0.0 -batch_size: 32 -epochs: 15 -margin: 0.05 - -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0 \ No newline at end of file diff --git a/configs/train_ensemble_model_lora.yaml b/configs/train_ensemble_model_lora.yaml index f74eea9..aad34be 100644 --- a/configs/train_ensemble_model_lora.yaml +++ b/configs/train_ensemble_model_lora.yaml @@ -17,5 +17,5 @@ epochs: 6 margin: 0.05 lora_r: 16 -lora_alpha: 0.5 +lora_alpha: 8 lora_dropout: 0.3 \ No newline at end of file diff --git a/configs/train_single_model_lora.yaml b/configs/train_single_model_lora.yaml index a44955c..05b0888 100644 --- a/configs/train_single_model_lora.yaml +++ b/configs/train_single_model_lora.yaml @@ -2,7 +2,7 @@ seed: 1234 tag: lora_single log_dir: ./output/new_backbones -model_type: 'synclr_vitb16' +model_type: 'dino_vitb16' feat_type: 'cls' stride: '16' use_lora: True @@ -17,5 +17,5 @@ epochs: 8 margin: 0.05 lora_r: 16 -lora_alpha: 16 -lora_dropout: 0.1 +lora_alpha: 32 +lora_dropout: 0.2 diff --git a/dataset/download_chunked_dataset.sh b/dataset/download_chunked_dataset.sh new file mode 100644 index 0000000..197feab --- /dev/null +++ b/dataset/download_chunked_dataset.sh @@ -0,0 +1,14 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset +for i in $(seq -f "%03g" 0 99); do + wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip + unzip -q $i.zip -d ref/ && rm $i.zip + wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip + unzip -q $i.zip -d distort/ && rm $i.zip +done diff --git a/dataset/download_jnd_dataset.sh b/dataset/download_jnd_dataset.sh new file mode 100644 index 0000000..04db9cc --- /dev/null +++ b/dataset/download_jnd_dataset.sh @@ -0,0 +1,6 @@ +#!/bin/bash +mkdir -p ./dataset +cd dataset + +# Download JND data for NIGHTS dataset +wget https://data.csail.mit.edu/nights/jnd_votes.csv diff --git a/dataset/download_unfiltered_dataset.sh b/dataset/download_unfiltered_dataset.sh new file mode 100644 index 0000000..8ce87d5 --- /dev/null +++ b/dataset/download_unfiltered_dataset.sh @@ -0,0 +1,20 @@ +#!/bin/bash +mkdir -p ./dataset_unfiltered +cd dataset_unfiltered + +mkdir -p ref +mkdir -p distort + +# Download NIGHTS dataset + +# store those in a list and loop through wget and unzip and rm +for i in {0..99..25} +do + wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip + unzip -q ref_${i}_$(($i+24)).zip -d ref + rm ref_*.zip + + wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip + unzip -q distort_${i}_$(($i+24)).zip -d distort + rm distort_*.zip +done diff --git a/dreamsim/config.py b/dreamsim/config.py index a143a9e..ede32e0 100644 --- a/dreamsim/config.py +++ b/dreamsim/config.py @@ -25,13 +25,6 @@ "lora": True } }, - "lora_config": { - "r": 16, - "lora_alpha": 0.5, - "lora_dropout": 0.3, - "bias": "none", - "target_modules": ['qkv'] - }, "img_size": 224 } diff --git a/dreamsim/feature_extraction/vision_transformer.py b/dreamsim/feature_extraction/vision_transformer.py index 36143e9..b2fe093 100644 --- a/dreamsim/feature_extraction/vision_transformer.py +++ b/dreamsim/feature_extraction/vision_transformer.py @@ -16,7 +16,7 @@ # This version was taken from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py # On Jan 24th, 2022 # Git hash of last commit: 4b96393c4c877d127cff9f077468e4a1cc2b5e2d - + """ Mostly copy-paste from timm library. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import torch.nn as nn trunc_normal_ = lambda *args, **kwargs: None - + def drop_path(x, drop_prob: float = 0., training: bool = False): if drop_prob == 0. or not training: @@ -43,6 +43,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ + def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob @@ -121,6 +122,7 @@ def forward(self, x, return_attention=False): class PatchEmbed(nn.Module): """ Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) @@ -138,6 +140,7 @@ def forward(self, x): class VisionTransformer(nn.Module): """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): @@ -239,7 +242,8 @@ def get_intermediate_layers(self, x, n=1): class DINOHead(nn.Module): - def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, + bottleneck_dim=256): super().__init__() nlayers = max(nlayers, 1) if nlayers == 1: @@ -273,7 +277,7 @@ def forward(self, x): x = nn.functional.normalize(x, dim=-1, p=2) x = self.last_layer(x) return x - + def vit_tiny(patch_size=16, **kwargs): model = VisionTransformer( diff --git a/dreamsim/model.py b/dreamsim/model.py index e7d3364..b79b8bb 100644 --- a/dreamsim/model.py +++ b/dreamsim/model.py @@ -1,3 +1,5 @@ +import json + import torch import torch.nn.functional as F from torch import nn @@ -7,11 +9,19 @@ from util.constants import * from .feature_extraction.extractor import ViTExtractor import yaml +import peft from peft import PeftModel, LoraConfig, get_peft_model from .config import dreamsim_args, dreamsim_weights import os import zipfile +from packaging import version + +peft_version = version.parse(peft.__version__) +min_version = version.parse('0.2.0') +if peft_version < min_version: + raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. " + "Please update peft with 'pip install --upgrade peft'.") class PerceptualModel(torch.nn.Module): def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1, @@ -165,9 +175,8 @@ def download_weights(cache_dir, dreamsim_type): """ dreamsim_required_ckpts = { - "ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora", - "open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora", - "clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"], + "ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar", + "clip_vitb16_pretrain.pth.tar", "ensemble_lora"], "dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"], "open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"], "clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"] @@ -216,10 +225,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir, normalize_embeds=normalize_embeds) - lora_config = LoraConfig(**dreamsim_args['lora_config']) + tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_" + + with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f: + adapter_config = json.load(f) + lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules'] + lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys}) ours_model = get_peft_model(ours_model, lora_config) - tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}" if pretrained: load_dir = os.path.join(cache_dir, f"{tag}lora") ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device) @@ -251,7 +264,6 @@ def normalize_embedding(embed): 'dino_vits16': {'cls': 384, 'last_layer': 384}, 'dino_vitb8': {'cls': 768, 'last_layer': 768}, 'dino_vitb16': {'cls': 768, 'last_layer': 768}, - 'dinov2_vitb14': {'cls': 768, 'last_layer': 768}, 'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512}, 'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, @@ -260,6 +272,5 @@ def normalize_embedding(embed): 'mae_vith14': {'cls': 1280, 'last_layer': 1280}, 'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768}, 'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768}, - 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}, - 'synclr_vitb16': {'cls': 768, 'last_layer': 768}, + 'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768} } diff --git a/requirements.txt b/requirements.txt index 01e366b..7fa525a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ lpips numpy open-clip-torch pandas -peft>=0.4.0 +peft>=0.2.0 Pillow pytorch-lightning PyYAML diff --git a/training/train.py b/training/train.py index 9dc7767..7814edb 100644 --- a/training/train.py +++ b/training/train.py @@ -33,7 +33,7 @@ def parse_args(): help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated' 'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, ' 'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, ' - 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]') + 'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]') parser.add_argument('--feat_type', type=str, default='cls', help='What type of feature to extract from the model. If finetuning an ensemble, pass a ' 'comma-separated list of features (same length as model_type). Accepted feature types: ' @@ -183,7 +183,7 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None): if self.save_mode in {'adapter_only', 'all'}: if epoch_load is not None: checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}') - + logging.info(f'Loading adapter weights from {checkpoint_root}') self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device) else: