Commit

merge
ssundaram21 committed Jul 30, 2024
2 parents 1d9790e + 9fd48bf commit 38b8f37
Showing 12 changed files with 96 additions and 63 deletions.
38 changes: 22 additions & 16 deletions README.md
@@ -17,28 +17,26 @@ Current metrics for perceptual image similarity operate at the level of pixels a
DreamSim is a new metric for perceptual image similarity that bridges the gap between "low-level" metrics (e.g. LPIPS, PSNR, SSIM) and "high-level" measures (e.g. CLIP). Our model was trained by concatenating CLIP, OpenCLIP, and DINO embeddings, and then finetuning on human perceptual judgements. We gathered these judgements on a dataset of ~20k image triplets, generated by diffusion models. Our model achieves better alignment with human similarity judgements than existing metrics, and can be used for downstream applications such as image retrieval.

## 🕰️ Coming soon
* JND Dataset release
* ✅ JND Dataset release
* ✅ Compatibility with the most recent version of PEFT
* Distilled DreamSim models (i.e. smaller models distilled from the main ensemble)
* DreamSim variants trained for higher resolutions
* Compatibility with the most recent version of PEFT

## 🚀 Updates
## 🚀 Newest Updates
**X/XX/24:** Released new versions of the ensemble and single-branch DreamSim models compatible with `peft>=0.2.0`.

**7/14/23:** Released three variants of DreamSim that each only use one finetuned ViT model instead of the full ensemble. These single-branch models provide a ~3x speedup over the full ensemble.


Here's how they compare to the full ensemble on NIGHTS (2AFC agreement):
* **Ensemble:** 96.2%
* **OpenCLIP-ViTB/32:** 95.5%
* **DINO-ViTB/16:** 94.6%
* **CLIP-ViTB/32:** 93.9%
Here's how they perform on the NIGHTS validation set:
* **Ensemble:** 96.9%
* **OpenCLIP-ViTB/32:** 95.6%
* **DINO-ViTB/16:** 95.7%
* **CLIP-ViTB/32:** 95.0%

## Table of Contents
* [Requirements](#requirements)
* [Setup](#setup)
* [Usage](#usage)
* [Quickstart](#quickstart-perceptual-similarity-metric)
* [Single-branch models](#new-single-branch-models)
* [Single-branch models](#single-branch-models)
* [Feature extraction](#feature-extraction)
* [Image retrieval](#image-retrieval)
* [Perceptual loss function](#perceptual-loss-function)
@@ -96,10 +94,10 @@ distance = model(img1, img2) # The model takes an RGB image from [0, 1], size ba

To run on example images, run `demo.py`. The script should produce distances (0.424, 0.34).

### (new!) Single-branch models
By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance.
### Single-branch models
By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. **The single-branch models provide a ~3x speedup over the ensemble.**

To load a single-branch model, use the `dreamsim_type` argument. For example:
The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance. To load a single-branch model, use the `dreamsim_type` argument. For example:
```
dreamsim_dino_model, preprocess = dreamsim(pretrained=True, dreamsim_type="dino_vitb16")
```
@@ -153,7 +151,15 @@ DreamSim is trained by fine-tuning on the NIGHTS dataset. For details on the dat

Run `./dataset/download_dataset.sh` to download and unzip the NIGHTS dataset into `./dataset/nights`. The unzipped dataset size is 58 GB.

**(new!) Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo:**
Having trouble with the large file sizes? Run `./dataset/download_chunked_dataset.sh` to download the NIGHTS dataset split into 200 smaller zip files. The output of this script is identical to `download_dataset.sh`.

### (new!) Download the entire 100k pre-filtered NIGHTS dataset
We only use the 20k unanimous triplets for training and evaluation, but release all 100k triplets (many with few and/or split votes) for research purposes. Run `./dataset/download_unfiltered_dataset.sh` to download and unzip this unfiltered version of the NIGHTS dataset into `./dataset/nights_unfiltered`. The unzipped dataset size is 289 GB.

### (new!) Download the JND data
Download the just-noticeable difference (JND) votes by running `./dataset/download_jnd_dataset.sh`. The CSV will be downloaded to `./dataset/jnd_votes.csv`. Check out the [Colab](https://colab.research.google.com/drive/1taEOMzFE9g81D9AwH27Uhy2U82tQGAVI?usp=sharing) for an example of loading a JND trial.

### Visualize NIGHTS and embeddings with the [Voxel51](https://github.com/voxel51/fiftyone) demo:
[![FiftyOne](https://img.shields.io/badge/FiftyOne-blue.svg?logo=)](https://try.fiftyone.ai/datasets/nights/samples)

## Experiments
21 changes: 0 additions & 21 deletions configs/distill_lora.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/train_ensemble_model_lora.yaml
@@ -17,5 +17,5 @@ epochs: 6
margin: 0.05

lora_r: 16
lora_alpha: 0.5
lora_alpha: 8
lora_dropout: 0.3
6 changes: 3 additions & 3 deletions configs/train_single_model_lora.yaml
@@ -2,7 +2,7 @@ seed: 1234
tag: lora_single
log_dir: ./output/new_backbones

model_type: 'synclr_vitb16'
model_type: 'dino_vitb16'
feat_type: 'cls'
stride: '16'
use_lora: True
@@ -17,5 +17,5 @@ epochs: 8
margin: 0.05

lora_r: 16
lora_alpha: 16
lora_dropout: 0.1
lora_alpha: 32
lora_dropout: 0.2
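
For context on the `lora_alpha` changes in the two training configs above: in standard LoRA, and in PEFT's default (non-rank-stabilized) setting, the low-rank update is scaled by `lora_alpha / lora_r`, so an alpha value is only meaningful relative to the rank. The sketch below works through that arithmetic. It assumes the first `lora_alpha` value in each hunk is the removed line and the second is the added line (the +/- markers are not visible here), and `lora_scale` is a hypothetical helper, not part of DreamSim or PEFT.

```python
def lora_scale(lora_alpha: float, lora_r: int) -> float:
    """Scaling factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / lora_r


# (old, new) lora_alpha values as read from the two config hunks; lora_r is 16 in both.
configs = {
    "train_ensemble_model_lora.yaml": (0.5, 8),
    "train_single_model_lora.yaml": (16, 32),
}

scales = {
    name: (lora_scale(old, 16), lora_scale(new, 16))
    for name, (old, new) in configs.items()
}

for name, (old_scale, new_scale) in scales.items():
    print(f"{name}: scale {old_scale} -> {new_scale}")
```

Under that reading, the ensemble config moves from a scale of 0.03125 to 0.5 and the single-model config from 1.0 to 2.0.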
14 changes: 14 additions & 0 deletions dataset/download_chunked_dataset.sh
@@ -0,0 +1,14 @@
#!/bin/bash
mkdir -p ./dataset
cd dataset

mkdir -p ref
mkdir -p distort

# Download NIGHTS dataset
for i in $(seq -f "%03g" 0 99); do
wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip
unzip -q $i.zip -d ref/ && rm $i.zip
wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip
unzip -q $i.zip -d distort/ && rm $i.zip
done
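
As a quick sanity check on the loop above, the following Python sketch lists the archives it would fetch, without downloading anything. The base URL and the zero-padded chunk names are taken from the script; `chunk_urls` and `BASE_URL` are hypothetical names used only for illustration.

```python
BASE_URL = "https://data.csail.mit.edu/nights/nights_chunked"


def chunk_urls(base_url: str = BASE_URL, n_chunks: int = 100) -> list[str]:
    """List the zip URLs in the same order the shell loop fetches them."""
    urls = []
    for i in range(n_chunks):
        chunk = f"{i:03d}"  # mirrors `seq -f "%03g" 0 99`, i.e. 000 ... 099
        urls.append(f"{base_url}/ref/{chunk}.zip")
        urls.append(f"{base_url}/distort/{chunk}.zip")
    return urls


urls = chunk_urls()
print(len(urls), urls[0], urls[-1])
```

The count comes out to 200 archives (100 `ref` plus 100 `distort`), which matches the "200 smaller zip files" mentioned in the README change.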
6 changes: 6 additions & 0 deletions dataset/download_jnd_dataset.sh
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p ./dataset
cd dataset

# Download JND data for NIGHTS dataset
wget https://data.csail.mit.edu/nights/jnd_votes.csv
20 changes: 20 additions & 0 deletions dataset/download_unfiltered_dataset.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/bin/bash
mkdir -p ./dataset_unfiltered
cd dataset_unfiltered

mkdir -p ref
mkdir -p distort

# Download NIGHTS dataset

# store those in a list and loop through wget and unzip and rm
for i in {0..99..25}
do
wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip
unzip -q ref_${i}_$(($i+24)).zip -d ref
rm ref_*.zip

wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip
unzip -q distort_${i}_$(($i+24)).zip -d distort
rm distort_*.zip
done
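
The brace expansion `{0..99..25}` in the loop above steps through 0, 25, 50 and 75, so each iteration covers a block of 25 indices. A small Python sketch of the archive names this produces, with `unfiltered_archives` as a hypothetical helper name:

```python
def unfiltered_archives(start: int = 0, stop: int = 99, step: int = 25) -> list[str]:
    """Archive names in the order the shell loop downloads them."""
    names = []
    for i in range(start, stop + 1, step):  # bash {0..99..25} -> 0, 25, 50, 75
        last = i + step - 1                 # the script computes $(($i+24))
        names.append(f"ref_{i}_{last}.zip")
        names.append(f"distort_{i}_{last}.zip")
    return names


archives = unfiltered_archives()
print(archives)
```

That is eight archives in total, four for `ref` and four for `distort`.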
7 changes: 0 additions & 7 deletions dreamsim/config.py
@@ -25,13 +25,6 @@
"lora": True
}
},
"lora_config": {
"r": 16,
"lora_alpha": 0.5,
"lora_dropout": 0.3,
"bias": "none",
"target_modules": ['qkv']
},
"img_size": 224
}

12 changes: 8 additions & 4 deletions dreamsim/feature_extraction/vision_transformer.py
@@ -16,7 +16,7 @@
# This version was taken from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# On Jan 24th, 2022
# Git hash of last commit: 4b96393c4c877d127cff9f077468e4a1cc2b5e2d

"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
@@ -27,7 +27,7 @@
import torch.nn as nn

trunc_normal_ = lambda *args, **kwargs: None


def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
@@ -43,6 +43,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
@@ -121,6 +122,7 @@ def forward(self, x, return_attention=False):
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
@@ -138,6 +140,7 @@ def forward(self, x):

class VisionTransformer(nn.Module):
""" Vision Transformer """

def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
@@ -239,7 +242,8 @@ def get_intermediate_layers(self, x, n=1):


class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
@@ -273,7 +277,7 @@ def forward(self, x):
x = nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x


def vit_tiny(patch_size=16, **kwargs):
model = VisionTransformer(
27 changes: 19 additions & 8 deletions dreamsim/model.py
@@ -1,3 +1,5 @@
import json

import torch
import torch.nn.functional as F
from torch import nn
@@ -7,11 +9,19 @@
from util.constants import *
from .feature_extraction.extractor import ViTExtractor
import yaml
import peft
from peft import PeftModel, LoraConfig, get_peft_model
from .config import dreamsim_args, dreamsim_weights
import os
import zipfile
from packaging import version

peft_version = version.parse(peft.__version__)
min_version = version.parse('0.2.0')

if peft_version < min_version:
    raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. "
                       "Please update peft with 'pip install --upgrade peft'.")

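The version guard above parses both versions before comparing them, which matters because plain string comparison orders versions lexicographically and would rank `"0.10.0"` below `"0.2.0"`. The sketch below illustrates the difference with a standard-library-only stand-in; `release_tuple` and `meets_minimum` are hypothetical helpers that handle only plain `X.Y.Z` strings, whereas the code in this commit uses `packaging.version.parse`, which also understands pre-release and other suffixes.

```python
def release_tuple(version_string: str) -> tuple[int, ...]:
    """Simplified stand-in for packaging.version.parse: plain 'X.Y.Z' strings only."""
    return tuple(int(part) for part in version_string.split("."))


MIN_VERSION = "0.2.0"


def meets_minimum(installed: str, minimum: str = MIN_VERSION) -> bool:
    """Compare versions numerically, component by component."""
    return release_tuple(installed) >= release_tuple(minimum)


print("0.10.0" >= "0.2.0")      # plain string comparison
print(meets_minimum("0.10.0"))  # numeric comparison
```

The string comparison reports that 0.10.0 is too old, while the numeric comparison accepts it.
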
class PerceptualModel(torch.nn.Module):
def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1,
@@ -165,9 +175,8 @@ def download_weights(cache_dir, dreamsim_type):
"""

dreamsim_required_ckpts = {
"ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora",
"open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora",
"clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"],
"ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar",
"clip_vitb16_pretrain.pth.tar", "ensemble_lora"],
"dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"],
"open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"],
"clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"]
@@ -216,10 +225,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma
ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir,
normalize_embeds=normalize_embeds)

lora_config = LoraConfig(**dreamsim_args['lora_config'])
tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_"

with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f:
    adapter_config = json.load(f)
lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules']
lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})
ours_model = get_peft_model(ours_model, lora_config)

tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}"
if pretrained:
load_dir = os.path.join(cache_dir, f"{tag}lora")
ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device)
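
The hunk above replaces the hard-coded `lora_config` with values read from the checkpoint's own `adapter_config.json`, keeping only five keys. The sketch below isolates that read-and-filter step without importing PEFT. The example file contents are made up for illustration, since the shipped checkpoints' actual values are not shown in this commit, and `read_lora_kwargs` is a hypothetical helper name.

```python
import json
import os
import tempfile

LORA_KEYS = ["r", "lora_alpha", "lora_dropout", "bias", "target_modules"]


def read_lora_kwargs(adapter_dir: str) -> dict:
    """Load adapter_config.json and keep only the keys listed in LORA_KEYS."""
    with open(os.path.join(adapter_dir, "adapter_config.json"), "r") as f:
        adapter_config = json.load(f)
    return {k: adapter_config[k] for k in LORA_KEYS}


# Made-up config for illustration; a real adapter_config.json carries more fields.
example_config = {
    "peft_type": "LORA",
    "base_model_name_or_path": None,
    "r": 16,
    "lora_alpha": 8,
    "lora_dropout": 0.3,
    "bias": "none",
    "target_modules": ["qkv"],
}

with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "adapter_config.json"), "w") as f:
        json.dump(example_config, f)
    lora_kwargs = read_lora_kwargs(tmp)

print(lora_kwargs)
```

Passing only these keys on to `LoraConfig` presumably avoids handing it fields that a given PEFT version does not accept, though the commit itself does not state the motivation.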
@@ -251,7 +264,6 @@ def normalize_embedding(embed):
'dino_vits16': {'cls': 384, 'last_layer': 384},
'dino_vitb8': {'cls': 768, 'last_layer': 768},
'dino_vitb16': {'cls': 768, 'last_layer': 768},
'dinov2_vitb14': {'cls': 768, 'last_layer': 768},
'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512},
'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
@@ -260,6 +272,5 @@ def normalize_embedding(embed):
'mae_vith14': {'cls': 1280, 'last_layer': 1280},
'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
'synclr_vitb16': {'cls': 768, 'last_layer': 768},
'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}
}
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ lpips
numpy
open-clip-torch
pandas
peft>=0.4.0
peft>=0.2.0
Pillow
pytorch-lightning
PyYAML
4 changes: 2 additions & 2 deletions training/train.py
@@ -33,7 +33,7 @@ def parse_args():
help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated'
'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, '
'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, '
'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]')
'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]')
parser.add_argument('--feat_type', type=str, default='cls',
help='What type of feature to extract from the model. If finetuning an ensemble, pass a '
'comma-separated list of features (same length as model_type). Accepted feature types: '
@@ -183,7 +183,7 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None):
if self.save_mode in {'adapter_only', 'all'}:
if epoch_load is not None:
checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}')

logging.info(f'Loading adapter weights from {checkpoint_root}')
self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device)
else: