Commit 2dc509e

add new dataset download scripts

stephanie-fu committed Jul 16, 2024
1 parent dd35e81
Showing 11 changed files with 74 additions and 47 deletions.
21 changes: 0 additions & 21 deletions configs/distill_lora.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/train_ensemble_model_lora.yaml
@@ -17,5 +17,5 @@ epochs: 6
margin: 0.05

lora_r: 16
-lora_alpha: 0.5
+lora_alpha: 8
lora_dropout: 0.3
6 changes: 3 additions & 3 deletions configs/train_single_model_lora.yaml
@@ -2,7 +2,7 @@ seed: 1234
tag: lora_single
log_dir: ./output/new_backbones

-model_type: 'synclr_vitb16'
+model_type: 'dino_vitb16'
feat_type: 'cls'
stride: '16'
use_lora: True
@@ -17,5 +17,5 @@ epochs: 8
margin: 0.05

lora_r: 16
-lora_alpha: 16
-lora_dropout: 0.1
+lora_alpha: 32
+lora_dropout: 0.2
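Note: in peft's LoRA implementation the low-rank update is scaled by lora_alpha / r, so with lora_r fixed at 16 these changes raise the effective scaling from 0.5/16 ≈ 0.03 to 8/16 = 0.5 in the ensemble config, and from 16/16 = 1.0 to 32/16 = 2.0 in the single-model config.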
14 changes: 14 additions & 0 deletions dataset/download_chunked_dataset.sh
@@ -0,0 +1,14 @@
#!/bin/bash
mkdir -p ./dataset
cd dataset

mkdir -p ref
mkdir -p distort

# Download NIGHTS dataset
for i in $(seq -f "%03g" 0 99); do
wget https://data.csail.mit.edu/nights/nights_chunked/ref/$i.zip
unzip -q $i.zip -d ref/ && rm $i.zip
wget https://data.csail.mit.edu/nights/nights_chunked/distort/$i.zip
unzip -q $i.zip -d distort/ && rm $i.zip
done
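For environments without wget or unzip, here is a minimal Python sketch equivalent to the loop above (same URLs and 000–099 chunk naming; illustrative only, with no retry or resume logic):

import io
import urllib.request
import zipfile
from pathlib import Path

BASE = "https://data.csail.mit.edu/nights/nights_chunked"

def download_split(split: str, out_root: Path, num_chunks: int = 100) -> None:
    """Fetch and extract chunks 000.zip through 099.zip for one split."""
    out_dir = out_root / split
    out_dir.mkdir(parents=True, exist_ok=True)
    for i in range(num_chunks):
        url = f"{BASE}/{split}/{i:03d}.zip"
        with urllib.request.urlopen(url) as resp:
            # Extract in memory, so there is no zip file to clean up afterwards.
            zipfile.ZipFile(io.BytesIO(resp.read())).extractall(out_dir)

for split in ("ref", "distort"):
    download_split(split, Path("./dataset"))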
6 changes: 6 additions & 0 deletions dataset/download_jnd_dataset.sh
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p ./dataset
cd dataset

# Download JND data for NIGHTS dataset
wget https://data.csail.mit.edu/nights/jnd_votes.csv
20 changes: 20 additions & 0 deletions dataset/download_unfiltered_dataset.sh
@@ -0,0 +1,20 @@
#!/bin/bash
mkdir -p ./dataset_unfiltered
cd dataset_unfiltered

mkdir -p ref
mkdir -p distort

# Download the unfiltered NIGHTS dataset; each split ships as four
# 25-chunk archives (see the note after this script)
for i in {0..99..25}
do
wget https://data.csail.mit.edu/nights/nights_unfiltered/ref_${i}_$(($i+24)).zip
unzip -q ref_${i}_$(($i+24)).zip -d ref
rm ref_${i}_$(($i+24)).zip

wget https://data.csail.mit.edu/nights/nights_unfiltered/distort_${i}_$(($i+24)).zip
unzip -q distort_${i}_$(($i+24)).zip -d distort
rm distort_${i}_$(($i+24)).zip
done
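The brace expansion {0..99..25} yields i = 0, 25, 50, 75, so the loop fetches exactly four archives per split: ref_0_24.zip, ref_25_49.zip, ref_50_74.zip, and ref_75_99.zip (and likewise for distort).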
7 changes: 0 additions & 7 deletions dreamsim/config.py
@@ -25,13 +25,6 @@
"lora": True
}
},
"lora_config": {
"r": 16,
"lora_alpha": 0.5,
"lora_dropout": 0.3,
"bias": "none",
"target_modules": ['qkv']
},
"img_size": 224
}

12 changes: 8 additions & 4 deletions dreamsim/feature_extraction/vision_transformer.py
@@ -16,7 +16,7 @@
# This version was taken from https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# On Jan 24th, 2022
# Git hash of last commit: 4b96393c4c877d127cff9f077468e4a1cc2b5e2d

"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
@@ -27,7 +27,7 @@
import torch.nn as nn

trunc_normal_ = lambda *args, **kwargs: None


def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
@@ -43,6 +43,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
@@ -121,6 +122,7 @@ def forward(self, x, return_attention=False):
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
@@ -138,6 +140,7 @@ def forward(self, x):

class VisionTransformer(nn.Module):
""" Vision Transformer """

def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
@@ -239,7 +242,8 @@ def get_intermediate_layers(self, x, n=1):


class DINOHead(nn.Module):
-    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
+                 bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
@@ -273,7 +277,7 @@ def forward(self, x):
x = nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x


def vit_tiny(patch_size=16, **kwargs):
model = VisionTransformer(
27 changes: 19 additions & 8 deletions dreamsim/model.py
@@ -1,3 +1,5 @@
+import json
+
import torch
import torch.nn.functional as F
from torch import nn
@@ -7,11 +9,19 @@
from util.constants import *
from .feature_extraction.extractor import ViTExtractor
import yaml
+import peft
from peft import PeftModel, LoraConfig, get_peft_model
from .config import dreamsim_args, dreamsim_weights
import os
import zipfile
+from packaging import version
+
+peft_version = version.parse(peft.__version__)
+min_version = version.parse('0.2.0')
+
+if peft_version < min_version:
+    raise RuntimeError(f"DreamSim requires peft version {min_version} or greater. "
+                       "Please update peft with 'pip install --upgrade peft'.")

class PerceptualModel(torch.nn.Module):
def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stride: str = '16', hidden_size: int = 1,
@@ -165,9 +175,8 @@ def download_weights(cache_dir, dreamsim_type):
"""

dreamsim_required_ckpts = {
"ensemble": ["dino_vitb16_pretrain.pth", "dino_vitb16_lora",
"open_clip_vitb16_pretrain.pth.tar", "open_clip_vitb16_lora",
"clip_vitb16_pretrain.pth.tar", "clip_vitb16_lora"],
"ensemble": ["dino_vitb16_pretrain.pth", "open_clip_vitb16_pretrain.pth.tar",
"clip_vitb16_pretrain.pth.tar", "ensemble_lora"],
"dino_vitb16": ["dino_vitb16_pretrain.pth", "dino_vitb16_single_lora"],
"open_clip_vitb32": ["open_clip_vitb32_pretrain.pth.tar", "open_clip_vitb32_single_lora"],
"clip_vitb32": ["clip_vitb32_pretrain.pth.tar", "clip_vitb32_single_lora"]
@@ -216,10 +225,14 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma
    ours_model = PerceptualModel(**dreamsim_args['model_config'][dreamsim_type], device=device, load_dir=cache_dir,
                                 normalize_embeds=normalize_embeds)

-    lora_config = LoraConfig(**dreamsim_args['lora_config'])
+    tag = "ensemble_" if dreamsim_type == "ensemble" else f"{model_list[0]}_single_"
+
+    with open(os.path.join(cache_dir, f'{tag}lora', 'adapter_config.json'), 'r') as f:
+        adapter_config = json.load(f)
+    lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules']
+    lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})
    ours_model = get_peft_model(ours_model, lora_config)

-    tag = "" if dreamsim_type == "ensemble" else f"single_{model_list[0]}"
if pretrained:
load_dir = os.path.join(cache_dir, f"{tag}lora")
ours_model = PeftModel.from_pretrained(ours_model.base_model.model, load_dir).to(device)
@@ -251,7 +264,6 @@ def normalize_embedding(embed):
'dino_vits16': {'cls': 384, 'last_layer': 384},
'dino_vitb8': {'cls': 768, 'last_layer': 768},
'dino_vitb16': {'cls': 768, 'last_layer': 768},
-'dinov2_vitb14': {'cls': 768, 'last_layer': 768},
'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512},
'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
@@ -260,6 +272,5 @@
'mae_vith14': {'cls': 1280, 'last_layer': 1280},
'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768},
-'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
-'synclr_vitb16': {'cls': 768, 'last_layer': 768},
+'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}
}
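For context on the loading change above: the LoRA hyperparameters now come from the adapter_config.json that peft writes next to each saved adapter, rather than from the lora_config block deleted from dreamsim/config.py. A minimal sketch of the new path, assuming the ensemble adapter has already been downloaded to ./models/ensemble_lora:

import json
from peft import LoraConfig

# adapter_config.json also carries keys beyond these five (e.g. peft_type,
# base_model_name_or_path), hence the explicit filter below.
with open("./models/ensemble_lora/adapter_config.json") as f:
    adapter_config = json.load(f)

lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules']
lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})
# The training config (configs/train_ensemble_model_lora.yaml) would give
# r=16, lora_alpha=8; the shipped adapter may differ.
print(lora_config.r, lora_config.lora_alpha)

This keeps the published weights and the code from drifting apart: whatever hyperparameters the adapter was trained with are the ones used to rebuild it.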
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,7 +6,7 @@ lpips
numpy
open-clip-torch
pandas
-peft>=0.4.0
+peft>=0.2.0
Pillow
pytorch-lightning
PyYAML
4 changes: 2 additions & 2 deletions training/train.py
@@ -33,7 +33,7 @@ def parse_args():
help='Which ViT model to finetune. To finetune an ensemble of models, pass a comma-separated'
'list of models. Accepted models: [dino_vits8, dino_vits16, dino_vitb8, dino_vitb16, '
'clip_vitb16, clip_vitb32, clip_vitl14, mae_vitb16, mae_vitl16, mae_vith14, '
-'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14, dinov2_vitb14, synclr_vitb16]')
+'open_clip_vitb16, open_clip_vitb32, open_clip_vitl14]')
parser.add_argument('--feat_type', type=str, default='cls',
help='What type of feature to extract from the model. If finetuning an ensemble, pass a '
'comma-separated list of features (same length as model_type). Accepted feature types: '
@@ -183,7 +183,7 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None):
if self.save_mode in {'adapter_only', 'all'}:
if epoch_load is not None:
checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}')
-
+logging.info(f'Loading adapter weights from {checkpoint_root}')
self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device)
else:
