Commit 91a4749: release code updates

ssundaram21 committed Sep 13, 2024
1 parent 99222ad
Showing 24 changed files with 503 additions and 144 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -94,7 +94,7 @@ img2 = preprocess(Image.open("img2_path")).to("cuda")
 distance = model(img1, img2) # The model takes an RGB image from [0, 1], size batch_sizex3x224x224
 ```

-To run on example images, run `demo.py`. The script should produce distances (0.424, 0.34).
+To run on example images, run `demo.py`. The script should produce distances (0.4453, 0.2756).

 ### (new!) Single-branch models
 By default, DreamSim uses an ensemble of CLIP, DINO, and OpenCLIP (all ViT-B/16). If you need a lighter-weight model you can use *single-branch* versions of DreamSim where only a single backbone is finetuned. The available options are OpenCLIP-ViTB/32, DINO-ViTB/16, CLIP-ViTB/32, in order of performance.
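A minimal usage sketch for these single-branch variants, following the ensemble example above (the `dreamsim_type` argument matches the keys of `dreamsim_weights` in `dreamsim/config.py`; image paths are illustrative):

```
import torch
from PIL import Image
from dreamsim import dreamsim

device = "cuda" if torch.cuda.is_available() else "cpu"
# Load a lighter single-branch variant instead of the default ensemble.
model, preprocess = dreamsim(pretrained=True, device=device,
                             dreamsim_type="open_clip_vitb32")

img1 = preprocess(Image.open("img1_path")).to(device)
img2 = preprocess(Image.open("img2_path")).to(device)
distance = model(img1, img2)  # lower distance = more perceptually similar
```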
8 changes: 0 additions & 8 deletions configs/eval_baseline.yaml

This file was deleted.

6 changes: 0 additions & 6 deletions configs/eval_checkpoint.yaml

This file was deleted.

17 changes: 17 additions & 0 deletions configs/eval_ensemble.yaml
@@ -0,0 +1,17 @@
+tag: "open_clip"
+
+eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/"
+eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml"
+load_dir: "./models"
+
+baseline_model: "dino_vitb16,clip_vitb16,open_clip_vitb16"
+baseline_feat_type: "cls,embedding,embedding"
+baseline_stride: "16,16,16"
+
+nights_root: "./data/nights"
+bapps_root: "./data/2afc/val"
+things_root: "./data/things/things_src_images"
+things_file: "./data/things/things_valset.txt"
+
+batch_size: 256
+num_workers: 10
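The `baseline_*` fields hold one comma-separated entry per branch. A sketch of how such a config can be split into per-branch settings, mirroring the `split(',')` handling in `PerceptualModel` (`dreamsim/model.py` below); the config path is illustrative:

```
import yaml

with open("configs/eval_ensemble.yaml") as f:  # illustrative path
    cfg = yaml.safe_load(f)

model_types = cfg["baseline_model"].split(",")      # ['dino_vitb16', 'clip_vitb16', 'open_clip_vitb16']
feat_types = cfg["baseline_feat_type"].split(",")   # ['cls', 'embedding', 'embedding']
strides = [int(s) for s in cfg["baseline_stride"].split(",")]  # [16, 16, 16]

for m, f_type, s in zip(model_types, feat_types, strides):
    print(f"branch={m} feat={f_type} stride={s}")
```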
17 changes: 17 additions & 0 deletions configs/eval_single_clip.yaml
@@ -0,0 +1,17 @@
+tag: "clip"
+
+eval_checkpoint: "/path-to-ckpt/lightning_logs/version_0/checkpoints/epoch-to-eval/"
+eval_checkpoint_cfg: "/path-to-ckpt/lightning_logs/version_0/config.yaml"
+load_dir: "./models"
+
+baseline_model: "clip_vitb32"
+baseline_feat_type: "cls"
+baseline_stride: "32"
+
+nights_root: "./data/nights"
+bapps_root: "./data/2afc/val"
+things_root: "./data/things/things_src_images"
+things_file: "./data/things/things_valset.txt"
+
+batch_size: 256
+num_workers: 10
2 changes: 1 addition & 1 deletion configs/train_ensemble_model_lora.yaml
@@ -17,5 +17,5 @@ epochs: 6
 margin: 0.05

 lora_r: 16
-lora_alpha: 0.5
+lora_alpha: 8
 lora_dropout: 0.3
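For reference, a sketch of the peft `LoraConfig` these hyperparameters feed into; `bias` and `target_modules` are taken from the defaults deleted from `dreamsim/config.py` below (now read from the saved `adapter_config.json` at load time):

```
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                    # rank of the LoRA update matrices
    lora_alpha=8,            # scaling factor; raised from 0.5 in this commit
    lora_dropout=0.3,
    bias="none",
    target_modules=["qkv"],  # LoRA on the attention qkv projections
)
```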
10 changes: 5 additions & 5 deletions configs/train_single_model_lora.yaml
@@ -1,9 +1,9 @@
 seed: 1234
 tag: lora_single
-log_dir: ./output
+log_dir: ./output/new_backbones

-model_type: 'mae_vitb16'
-feat_type: 'last_layer'
+model_type: 'dino_vitb16'
+feat_type: 'cls'
 stride: '16'
 use_lora: True

@@ -17,5 +17,5 @@ epochs: 8
 margin: 0.05

 lora_r: 16
-lora_alpha: 0.5
-lora_dropout: 0.3
+lora_alpha: 32
+lora_dropout: 0.2
4 changes: 2 additions & 2 deletions dataset/dataset.py
@@ -18,8 +18,8 @@ def __init__(self, root_dir: str, split: str = "train", load_size: int = 224,
         self.load_size = load_size
         self.interpolation = interpolation
         self.preprocess_fn = get_preprocess_fn(preprocess, self.load_size, self.interpolation)

-        if self.split == "train" or self.split == "val":
+        if self.split == "train" or self.split == "val" or self.split == "test":
             self.csv = self.csv[self.csv["split"] == split]
         elif split == 'test_imagenet':
             self.csv = self.csv[self.csv['split'] == 'test']
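With `"test"` added, the constructor filters the CSV for all three standard splits. A minimal sketch of that filtering logic in isolation (the CSV path and schema are assumptions based on this constructor):

```
import pandas as pd

csv = pd.read_csv("./data/nights/data.csv")  # assumed path and schema
split = "test"                               # newly supported by this change
if split == "train" or split == "val" or split == "test":
    csv = csv[csv["split"] == split]
elif split == "test_imagenet":
    csv = csv[csv["split"] == "test"]
```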
14 changes: 14 additions & 0 deletions dataset/download_chunked_dataset.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+mkdir -p ./dataset
+cd dataset
+
+mkdir -p ref
+mkdir -p distort
+
+# Download NIGHTS dataset
+for i in $(seq -f "%03g" 0 99); do
+    wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip
+    unzip -q $i.zip -d ref/ && rm $i.zip
+    wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip
+    unzip -q $i.zip -d distort/ && rm $i.zip
+done
4 changes: 2 additions & 2 deletions dataset/download_dataset.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
-mkdir -p ./dataset
-cd dataset
+# mkdir -p ./dataset
+cd /home/fus/data/

 # Download NIGHTS dataset
 wget -O nights.zip https://data.csail.mit.edu/nights/nights.zip
6 changes: 6 additions & 0 deletions dataset/download_jnd_dataset.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+mkdir -p ./dataset
+cd dataset
+
+# Download JND data for NIGHTS dataset
+wget https://data.csail.mit.edu/nights/jnd_votes.csv
20 changes: 20 additions & 0 deletions dataset/download_unfiltered_dataset.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+mkdir -p ./dataset_unfiltered
+cd dataset_unfiltered
+
+mkdir -p ref
+mkdir -p distort
+
+# Download NIGHTS dataset
+
+# store those in a list and loop through wget and unzip and rm
+for i in {0..99..25}
+do
+    wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip
+    unzip -q ref_${i}_$(($i+24)).zip -d ref
+    rm ref_*.zip
+
+    wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip
+    unzip -q distort_${i}_$(($i+24)).zip -d distort
+    rm distort_*.zip
+done
15 changes: 4 additions & 11 deletions dreamsim/config.py
@@ -25,19 +25,12 @@
             "lora": True
         }
    },
-    "lora_config": {
-        "r": 16,
-        "lora_alpha": 0.5,
-        "lora_dropout": 0.3,
-        "bias": "none",
-        "target_modules": ['qkv']
-    },
     "img_size": 224
 }

 dreamsim_weights = {
-    "ensemble": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/dreamsim_checkpoint.zip",
-    "dino_vitb16": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.2/dreamsim_dino_vitb16_checkpoint.zip",
-    "clip_vitb32": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.2/dreamsim_clip_vitb32_checkpoint.zip",
-    "open_clip_vitb32": "https://github.com/ssundaram21/dreamsim/releases/download/v0.1.2/dreamsim_open_clip_vitb32_checkpoint.zip"
+    "ensemble": "https://github.com/ssundaram21/dreamsim/releases/download/v0.2.0/dreamsim_checkpoint.zip",
+    "dino_vitb16": "https://github.com/ssundaram21/dreamsim/releases/download/v0.2.0/dreamsim_dino_vitb16_checkpoint.zip",
+    "clip_vitb32": "https://github.com/ssundaram21/dreamsim/releases/download/v0.2.0/dreamsim_clip_vitb32_checkpoint.zip",
+    "open_clip_vitb32": "https://github.com/ssundaram21/dreamsim/releases/download/v0.2.0/dreamsim_open_clip_vitb32_checkpoint.zip"
 }
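The weight URLs now point at the v0.2.0 release. A minimal sketch of fetching and unpacking one of these zips manually (the `download_weights` helper in `dreamsim/model.py` below automates this; paths are illustrative):

```
import os
import urllib.request
import zipfile

from dreamsim.config import dreamsim_weights

url = dreamsim_weights["ensemble"]  # v0.2.0 release zip
cache_dir = "./models"
os.makedirs(cache_dir, exist_ok=True)

zip_path = os.path.join(cache_dir, "dreamsim_checkpoint.zip")
urllib.request.urlretrieve(url, zip_path)   # download the release asset
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(cache_dir)                # unpack checkpoints into ./models
```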
7 changes: 4 additions & 3 deletions dreamsim/feature_extraction/extractor.py
@@ -15,7 +15,7 @@
 """


-class ViTExtractor:
+class ViTExtractor(nn.Module):
     """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.

     We use the following notation in the documentation of the module's methods:
@@ -27,7 +27,7 @@ class ViTExtractor:
     d - the embedding dimension in the ViT.
     """

-    def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, load_dir: str = "./models",
+    def __init__(self, model_type: str = 'dino_vitb16', stride: int = 4, load_dir: str = "./models",
                  device: str = 'cuda'):
         """
         :param model_type: A string specifying the type of model to extract from.
@@ -37,6 +37,7 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, load_dir: str = "./models",
         :param stride: stride of first convolution layer. small stride -> higher resolution.
         :param load_dir: location of pretrained ViT checkpoints.
         """
+        super(ViTExtractor, self).__init__()
         self.model_type = model_type
         self.device = device
         self.model = ViTExtractor.create_model(model_type, load_dir)
@@ -66,7 +67,7 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module:
         torch.hub.set_dir(load_dir)
         model = torch.hub.load('facebookresearch/dino:main', model_type)
         if model_type == 'dino_vitb16':
-            sd = torch.load(os.path.join(load_dir, 'dino_vitb16_pretrain.pth'), map_location='cpu')
+            sd = torch.load(os.path.join(load_dir, 'dino_vitb16_pretrain.pth'), map_location='cpu', weights_only=True)
             proj = DINOHead(768, 2048)
             proj.mlp[0].weight.data = sd['student']['module.head.mlp.0.weight']
             proj.mlp[0].bias.data = sd['student']['module.head.mlp.0.bias']
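Two changes worth noting here: `ViTExtractor` now subclasses `nn.Module`, so extractors stored in the `nn.ModuleList` in `PerceptualModel` (see `dreamsim/model.py` below) are registered as submodules, and the `torch.load` calls pass `weights_only=True` so checkpoints load without unpickling arbitrary objects. A small self-contained sketch of why the registration matters (illustrative names):

```
import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, layers):
        super().__init__()
        # nn.ModuleList registers each child module, so its parameters appear
        # in wrapper.parameters() and state_dict(); a plain list would not.
        self.extractor_list = nn.ModuleList(layers)

wrapper = Wrapper([nn.Linear(4, 4), nn.Linear(4, 4)])
print(sum(p.numel() for p in wrapper.parameters()))  # 40: two Linear(4, 4) layers
```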
4 changes: 2 additions & 2 deletions dreamsim/feature_extraction/load_clip_as_dino.py
@@ -11,11 +11,11 @@ def forward(self, x: torch.Tensor):


 def load_clip_as_dino(patch_size, load_dir="./models", l14=False):
     if l14:
-        sd = torch.load(os.path.join(load_dir, 'clipl14_as_dino_vitl.pth.tar'), map_location='cpu')
+        sd = torch.load(os.path.join(load_dir, 'clipl14_as_dino_vitl.pth.tar'), map_location='cpu', weights_only=True)
         dino_vit = VisionTransformer(**sd['kwargs'])
         sd = sd['state_dict']
     else:
-        sd = torch.load(os.path.join(load_dir, f'clip_vitb{patch_size}_pretrain.pth.tar'))['state_dict']
+        sd = torch.load(os.path.join(load_dir, f'clip_vitb{patch_size}_pretrain.pth.tar'), weights_only=True)['state_dict']
         dino_vit = vit_base(patch_size=patch_size)

     dino_vit.pos_drop = torch.nn.LayerNorm(dino_vit.embed_dim)
4 changes: 2 additions & 2 deletions dreamsim/feature_extraction/load_open_clip_as_dino.py
@@ -6,12 +6,12 @@


 def load_open_clip_as_dino(patch_size, load_dir="./models", l14=False):
     if l14:
-        sd = torch.load(os.path.join(load_dir, 'open_clipl14_as_dino_vitl.pth.tar'), map_location='cpu')
+        sd = torch.load(os.path.join(load_dir, 'open_clipl14_as_dino_vitl.pth.tar'), map_location='cpu', weights_only=True)
         dino_vit = VisionTransformer(**sd['kwargs'])
         sd = sd['state_dict']
     else:
         dino_vit = vit_base(patch_size=patch_size)
-        sd = torch.load(os.path.join(load_dir, f'open_clip_vitb{patch_size}_pretrain.pth.tar'))['state_dict']
+        sd = torch.load(os.path.join(load_dir, f'open_clip_vitb{patch_size}_pretrain.pth.tar'), weights_only=True)['state_dict']

     dino_vit.pos_drop = torch.nn.LayerNorm(dino_vit.embed_dim)
     proj = sd.pop('proj')
19 changes: 0 additions & 19 deletions dreamsim/feature_extraction/vit_wrapper.py

This file was deleted.

58 changes: 34 additions & 24 deletions dreamsim/model.py
@@ -1,15 +1,27 @@
+import json
+
 import torch
 import torch.nn.functional as F
 from torch import nn
 from torchvision import transforms
+import os
+
+from util.constants import *
 from .feature_extraction.extractor import ViTExtractor
 import yaml
 import peft
 from peft import PeftModel, LoraConfig, get_peft_model
-from .feature_extraction.vit_wrapper import ViTConfig, ViTModel
 from .config import dreamsim_args, dreamsim_weights
-import os
 import zipfile
+from packaging import version
+
+peft_version = version.parse(peft.__version__)
+min_version = version.parse('0.2.0')
+
+if peft_version < min_version:
+    raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. "
+                       "Please update peft with 'pip install --upgrade peft'.")

 class PerceptualModel(torch.nn.Module):
     def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1,
@@ -41,7 +53,7 @@ def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1,
         self.stride_list = [int(x) for x in stride.split(',')]
         self._validate_args()
         self.extract_feats_list = []
-        self.extractor_list = []
+        self.extractor_list = nn.ModuleList()
         self.embed_size = 0
         self.hidden_size = hidden_size
         self.baseline = baseline
@@ -122,23 +134,23 @@ def _preprocess(self, img, model_type):

     def _get_mean(self, model_type):
         if "dino" in model_type:
-            return (0.485, 0.456, 0.406)
+            return IMAGENET_DEFAULT_MEAN
         elif "open_clip" in model_type:
-            return (0.48145466, 0.4578275, 0.40821073)
+            return OPENAI_CLIP_MEAN
         elif "clip" in model_type:
-            return (0.48145466, 0.4578275, 0.40821073)
+            return OPENAI_CLIP_MEAN
         elif "mae" in model_type:
-            return (0.485, 0.456, 0.406)
+            return IMAGENET_DEFAULT_MEAN

     def _get_std(self, model_type):
         if "dino" in model_type:
-            return (0.229, 0.224, 0.225)
+            return IMAGENET_DEFAULT_STD
         elif "open_clip" in model_type:
-            return (0.26862954, 0.26130258, 0.27577711)
+            return OPENAI_CLIP_STD
         elif "clip" in model_type:
-            return (0.26862954, 0.26130258, 0.27577711)
+            return OPENAI_CLIP_STD
         elif "mae" in model_type:
-            return (0.229, 0.224, 0.225)
+            return IMAGENET_DEFAULT_STD


 class MLP(torch.nn.Module):
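The normalization tuples move behind named constants from `util.constants`. A sketch of what those definitions presumably contain, with values copied from the literals removed above (the exact module contents are an assumption):

```
# util/constants.py (assumed): values match the deleted literals
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
```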
@@ -161,11 +173,9 @@ def download_weights(cache_dir, dreamsim_type):
     """
     Downloads and unzips DreamSim weights.
     """
-
     dreamsim_required_ckpts = {
-        "ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora",
-                     "open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora",
-                     "clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"],
+        "ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar",
+                     "clip_vitb16_pretrain.pth.tar", "ensemble_lora"],
         "dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"],
         "open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"],
         "clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"]
@@ -213,17 +223,18 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", normalize_embeds
     model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",")
     ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir,
                                  normalize_embeds=normalize_embeds)
-    for extractor in ours_model.extractor_list:
-        lora_config = LoraConfig(**dreamsim_args['lora_config'])
-        model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config)
-        extractor.model = model

-    tag = "" if dreamsim_type == "ensemble" else "single_"
+    tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_"

+    with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f:
+        adapter_config = json.load(f)
+    lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules']
+    lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})
+    ours_model = get_peft_model(ours_model, lora_config)

     if pretrained:
-        for extractor, name in zip(ours_model.extractor_list, model_list):
-            load_dir = os.path.join(cache_dir, f"{name}_{tag}lora")
-            extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
-            extractor.model.eval().requires_grad_(False)
+        load_dir = os.path.join(cache_dir, f"{tag}lora")
+        ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device)

     ours_model.eval().requires_grad_(False)

@@ -262,4 +273,3 @@
     'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768},
     'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}
 }
-
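After this refactor a single LoRA adapter wraps the whole `PerceptualModel` (one `{tag}lora` directory with its `adapter_config.json`), rather than one adapter per backbone. End-to-end usage is unchanged; a quick smoke-test sketch, assuming the packaged `dreamsim` entry point:

```
import torch
from dreamsim import dreamsim

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device, cache_dir="./models")

# Stand-ins for preprocessed images: RGB in [0, 1], batch x 3 x 224 x 224.
img1 = torch.rand(1, 3, 224, 224).to(device)
img2 = torch.rand(1, 3, 224, 224).to(device)
distance = model(img1, img2)  # ensemble distance from the single wrapped adapter
print(distance)
```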
