Commit 4174cfc

fix zero trainable parameter bug
stephanie-fu committed May 29, 2024
1 parent bdbfeb0 commit 4174cfc
Showing 5 changed files with 25 additions and 21 deletions.
3 changes: 2 additions & 1 deletion dreamsim/feature_extraction/extractor.py
@@ -16,7 +16,7 @@
     """
 
 
-class ViTExtractor:
+class ViTExtractor(nn.Module):
     """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.
     We use the following notation in the documentation of the module's methods:
@@ -38,6 +38,7 @@ def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, load_dir: st
         :param stride: stride of first convolution layer. small stride -> higher resolution.
         :param load_dir: location of pretrained ViT checkpoints.
         """
+        super(ViTExtractor, self).__init__()
         self.model_type = model_type
         self.device = device
         self.model = ViTExtractor.create_model(model_type, load_dir)
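Why this fixes the zero-trainable-parameter bug: PyTorch discovers parameters by recursing through attributes that are themselves nn.Module instances, so a holder class that does not subclass nn.Module (or that skips the super().__init__() call) contributes nothing to parameters(). A minimal standalone sketch of the registration behavior (toy classes, not DreamSim code):

import torch.nn as nn

class PlainExtractor:  # not an nn.Module: its Linear is never registered
    def __init__(self):
        self.model = nn.Linear(4, 4)

class Wrapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.extractor = PlainExtractor()  # silently invisible to parameters()

print(sum(p.numel() for p in Wrapped().parameters()))  # 0 -- the bug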
25 changes: 12 additions & 13 deletions dreamsim/model.py
@@ -1,7 +1,10 @@
 import torch
 import torch.nn.functional as F
+from torch import nn
 from torchvision import transforms
 import os
+
+from util.constants import *
 from .feature_extraction.extractor import ViTExtractor
 import yaml
 from peft import PeftModel, LoraConfig, get_peft_model
@@ -41,7 +44,7 @@ def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stri
         self.stride_list = [int(x) for x in stride.split(',')]
         self._validate_args()
         self.extract_feats_list = []
-        self.extractor_list = []
+        self.extractor_list = nn.ModuleList()
         self.embed_size = 0
         self.hidden_size = hidden_size
         self.baseline = baseline
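The companion fix: even with ViTExtractor now an nn.Module, a plain Python list would still hide the extractors from the parent module, because only nn.ModuleList (or nn.ModuleDict/nn.Sequential) registers its elements for parameters() and state_dict(). A minimal sketch with toy layers standing in for the extractors:

import torch.nn as nn

class Ensemble(nn.Module):
    def __init__(self, use_module_list: bool):
        super().__init__()
        layers = [nn.Linear(8, 8) for _ in range(3)]
        # plain list -> parameters unregistered; nn.ModuleList -> registered
        self.extractor_list = nn.ModuleList(layers) if use_module_list else layers

print(sum(p.numel() for p in Ensemble(False).parameters()))  # 0
print(sum(p.numel() for p in Ensemble(True).parameters()))   # 216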
@@ -122,27 +125,23 @@ def _preprocess(self, img, model_type):
 
     def _get_mean(self, model_type):
         if "dino" in model_type:
-            return (0.485, 0.456, 0.406)
+            return IMAGENET_DEFAULT_MEAN
         elif "open_clip" in model_type:
-            return (0.48145466, 0.4578275, 0.40821073)
+            return OPENAI_CLIP_MEAN
         elif "clip" in model_type:
-            return (0.48145466, 0.4578275, 0.40821073)
+            return OPENAI_CLIP_MEAN
         elif "mae" in model_type:
-            return (0.485, 0.456, 0.406)
-        elif "synclr" in model_type:
-            return (0.485, 0.456, 0.406)
+            return IMAGENET_DEFAULT_MEAN
 
     def _get_std(self, model_type):
         if "dino" in model_type:
-            return (0.229, 0.224, 0.225)
+            return IMAGENET_DEFAULT_STD
         elif "open_clip" in model_type:
-            return (0.26862954, 0.26130258, 0.27577711)
+            return OPENAI_CLIP_STD
         elif "clip" in model_type:
-            return (0.26862954, 0.26130258, 0.27577711)
+            return OPENAI_CLIP_STD
         elif "mae" in model_type:
-            return (0.229, 0.224, 0.225)
-        elif "synclr" in model_type:
-            return (0.229, 0.224, 0.225)
+            return IMAGENET_DEFAULT_STD
 
 
 class MLP(torch.nn.Module):
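A subtlety these branches rely on: the checks are substring matches, and "clip" is a substring of "open_clip", so the open_clip branch must be tested before the clip branch (here both return the same OpenAI CLIP constants, so the ordering is defensive). A quick illustration:

model_type = "open_clip_vitb32"
print("clip" in model_type)       # True -- a bare "clip" check would match open_clip models too
print("open_clip" in model_type)  # True -- so this branch is checked first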
9 changes: 2 additions & 7 deletions training/train.py
@@ -14,11 +14,6 @@
 from dreamsim.feature_extraction.vit_wrapper import ViTModel, ViTConfig
 import os
 import configargparse
-from tqdm import tqdm
-
-# log = logging.getLogger("lightning.pytorch")
-# log.propagate = False
-# log.setLevel(logging.INFO)
 
 
 def parse_args():
@@ -103,8 +98,8 @@ def __init__(self, feat_type: str = "cls", model_type: str = "dino_vitb16", stri
 
         pytorch_total_params = sum(p.numel() for p in self.perceptual_model.parameters())
         pytorch_total_trainable_params = sum(p.numel() for p in self.perceptual_model.parameters() if p.requires_grad)
-        print(pytorch_total_params)
-        print(pytorch_total_trainable_params)
+        print(f'Total params: {pytorch_total_params} | Trainable params: {pytorch_total_trainable_params} '
+              f'| % Trainable: {pytorch_total_trainable_params/pytorch_total_params}')
 
         self.criterion = HingeLoss(margin=self.margin, device=device)
 
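With registration fixed, this printout makes regressions obvious at a glance: under LoRA fine-tuning the trainable share should be small but strictly positive, and 0% reproduces the original bug. A sketch of the same check as a reusable helper (the function name and assert are illustrative, not part of the commit):

import torch.nn as nn

def trainable_fraction(model: nn.Module) -> float:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    assert trainable > 0, "zero trainable parameters -- submodules likely unregistered"
    return trainable / total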
4 changes: 4 additions & 0 deletions training/train.sh
@@ -0,0 +1,4 @@
+python -m training.train --config configs/train_single_model_lora.yaml --model_type dino_vitb16 --feat_type 'cls' --stride '16' &
+CUDA_VISIBLE_DEVICES=1 python -m training.train --config configs/train_single_model_lora.yaml --model_type clip_vitb32 --feat_type 'embedding' --stride '32' &
+CUDA_VISIBLE_DEVICES=2 python -m training.train --config configs/train_single_model_lora.yaml --model_type open_clip_vitb32 --feat_type 'embedding' --stride '32' &
+CUDA_VISIBLE_DEVICES=3 python -m training.train --config configs/train_ensemble_model_lora.yaml &
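Each line backgrounds one training run with a trailing &, and CUDA_VISIBLE_DEVICES=n pins that run to a single GPU; the first run, with no override set, sees all GPUs and will typically allocate on GPU 0. The script exits immediately after launching; appending a wait would make it block until all four runs finish.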
5 changes: 5 additions & 0 deletions util/constants.py
@@ -0,0 +1,5 @@
+# use timm names from https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/constants.py
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
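These values match timm's published constants at the URL above. A minimal usage sketch (a hypothetical pipeline mirroring how _get_mean/_get_std feed preprocessing in dreamsim/model.py):

from torchvision import transforms
from util.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])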
