Skip to content

Commit

Permalink
add distillation and new backbones
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanie-fu committed Feb 18, 2024
1 parent cbcffe9 commit 7cfa2a5
Show file tree
Hide file tree
Showing 12 changed files with 503 additions and 18 deletions.
21 changes: 21 additions & 0 deletions configs/distill_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Run bookkeeping: random seed, run tag and log location.
seed: 1234
tag: lora_distill
# NOTE(review): absolute, user-specific path; it will not exist on other
# machines. Confirm or override before reusing this config.
log_dir: /home/fus/repos/dreamsim-dev/output/dev

# Backbone / feature selection and LoRA toggle.
model_type: 'clip_vitb32'
feat_type: 'embedding'
# NOTE(review): stride is a quoted string here, not an integer; presumably the
# consumer parses it. Verify against the code that reads this key.
stride: '32'
use_lora: True

# Dataset location and loader workers.
dataset_root: ./dataset/nights
num_workers: 4

# Optimisation hyperparameters.
lr: 0.0003
weight_decay: 0.0
batch_size: 32
epochs: 15
margin: 0.05

# LoRA hyperparameters (only relevant while use_lora is True — TODO confirm).
lora_r: 8
lora_alpha: 16
lora_dropout: 0
10 changes: 5 additions & 5 deletions configs/train_single_model_lora.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
seed: 1234
tag: lora_single
log_dir: ./output
log_dir: ./output/new_backbones

model_type: 'mae_vitb16'
feat_type: 'last_layer'
model_type: 'synclr_vitb16'
feat_type: 'cls'
stride: '16'
use_lora: True

Expand All @@ -17,5 +17,5 @@ epochs: 8
margin: 0.05

lora_r: 16
lora_alpha: 0.5
lora_dropout: 0.3
lora_alpha: 16
lora_dropout: 0.1
4 changes: 2 additions & 2 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(self, root_dir: str, split: str = "train", load_size: int = 224,
self.load_size = load_size
self.interpolation = interpolation
self.preprocess_fn = get_preprocess_fn(preprocess, self.load_size, self.interpolation)

if self.split == "train" or self.split == "val":
if self.split == "train" or self.split == "val" or self.split == "test":
self.csv = self.csv[self.csv["split"] == split]
elif split == 'test_imagenet':
self.csv = self.csv[self.csv['split'] == 'test']
Expand Down
4 changes: 2 additions & 2 deletions dataset/download_dataset.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
mkdir -p ./dataset
cd dataset
# mkdir -p ./dataset
cd /home/fus/data/

# Download NIGHTS dataset
# Fetch straight from the CSAIL host over HTTPS rather than through an
# intermediary proxy domain, so the archive comes from the canonical source.
wget -O nights.zip https://data.csail.mit.edu/nights/nights.zip
Expand Down
1 change: 1 addition & 0 deletions dataset/nights
11 changes: 10 additions & 1 deletion dreamsim/feature_extraction/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from .load_clip_as_dino import load_clip_as_dino
from .load_open_clip_as_dino import load_open_clip_as_dino
from .load_synclr_as_dino import load_synclr_as_dino
from .vision_transformer import DINOHead
from .load_mae_as_vit import load_mae_as_vit

Expand Down Expand Up @@ -62,7 +63,10 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module:
:param load_dir: location of pretrained ViT checkpoints.
:return: the model
"""
if 'dino' in model_type:
if 'dinov2' in model_type:
torch.hub.set_dir(load_dir)
model = torch.hub.load('facebookresearch/dinov2', model_type)
elif 'dino' in model_type:
torch.hub.set_dir(load_dir)
model = torch.hub.load('facebookresearch/dino:main', model_type)
if model_type == 'dino_vitb16':
Expand Down Expand Up @@ -96,6 +100,11 @@ def create_model(model_type: str, load_dir: str = "./models") -> nn.Module:
raise ValueError(f"Model {model_type} not supported")
elif 'mae' in model_type:
model = load_mae_as_vit(model_type, load_dir)
elif 'synclr' in model_type:
if model_type == 'synclr_vitb16':
model = load_synclr_as_dino(16, load_dir)
else:
raise ValueError(f"Model {model_type} not supported")
else:
raise ValueError(f"Model {model_type} not supported")
return model
Expand Down
16 changes: 16 additions & 0 deletions dreamsim/feature_extraction/load_synclr_as_dino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from .vision_transformer import vit_base, VisionTransformer
import os


def load_synclr_as_dino(patch_size, load_dir="./models", l14=False):
    """
    Instantiate ``vit_base`` and load SynCLR checkpoint weights into it.

    The checkpoint is read from ``<load_dir>/synclr_vit_b_<patch_size>.pth``; its
    ``'model'`` entry is used as the state dict after the ``module.visual.``
    prefix is removed from the parameter names.

    :param patch_size: patch size passed to ``vit_base``; also part of the checkpoint filename.
    :param load_dir: directory containing the checkpoint file.
    :param l14: unused; kept only so existing callers that pass it keep working.
    :return: the ``vit_base`` model with the checkpoint weights loaded.
    """
    prefix = "module.visual."
    ckpt_path = os.path.join(load_dir, f'synclr_vit_b_{patch_size}.pth')
    # map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only hosts;
    # load_state_dict copies the values into the model's own parameters afterwards.
    sd = torch.load(ckpt_path, map_location="cpu")['model']
    dino_vit = vit_base(patch_size=patch_size)

    # Strip the prefix only where it is actually present. Any other key is kept
    # under its real name so the strict load below reports it verbatim instead
    # of a name with its first 14 characters blindly chopped off.
    new_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }

    dino_vit.load_state_dict(new_sd)
    return dino_vit
9 changes: 7 additions & 2 deletions dreamsim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def _get_mean(self, model_type):
return (0.48145466, 0.4578275, 0.40821073)
elif "mae" in model_type:
return (0.485, 0.456, 0.406)
elif "synclr" in model_type:
return (0.485, 0.456, 0.406)

def _get_std(self, model_type):
if "dino" in model_type:
Expand All @@ -139,6 +141,8 @@ def _get_std(self, model_type):
return (0.26862954, 0.26130258, 0.27577711)
elif "mae" in model_type:
return (0.229, 0.224, 0.225)
elif "synclr" in model_type:
return (0.229, 0.224, 0.225)


class MLP(torch.nn.Module):
Expand Down Expand Up @@ -252,6 +256,7 @@ def normalize_embedding(embed):
'dino_vits16': {'cls': 384, 'last_layer': 384},
'dino_vitb8': {'cls': 768, 'last_layer': 768},
'dino_vitb16': {'cls': 768, 'last_layer': 768},
'dinov2_vitb14': {'cls': 768, 'last_layer': 768},
'clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 512},
'clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
Expand All @@ -260,6 +265,6 @@ def normalize_embedding(embed):
'mae_vith14': {'cls': 1280, 'last_layer': 1280},
'open_clip_vitb16': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'open_clip_vitb32': {'cls': 768, 'embedding': 512, 'last_layer': 768},
'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768}
'open_clip_vitl14': {'cls': 1024, 'embedding': 768, 'last_layer': 768},
'synclr_vitb16': {'cls': 768, 'last_layer': 768},
}

Loading

0 comments on commit 7cfa2a5

Please sign in to comment.