Commit 4c1a5ce: minor fixes

ssundaram21 committed Oct 14, 2024
1 parent 7ef0e90 commit 4c1a5ce

Showing 3 changed files with 39 additions and 17 deletions.
demo.py (8 changes: 4 additions & 4 deletions)
@@ -11,13 +11,13 @@
 ])
 
 
-def preprocess(path):
-    pil_img = Image.open(path).convert('RGB')
-    return t(pil_img).unsqueeze(0)
+def preprocess(img):
+    img = img.convert('RGB')
+    return t(img).unsqueeze(0)
 
 # Load model
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model, preprocess = dreamsim(pretrained=True, device=device)
+model, _ = dreamsim(pretrained=True, device=device)
 
 # Load images
 img_ref = preprocess(Image.open('images/ref_1.png')).to(device)
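For orientation, a minimal usage sketch consistent with the updated demo: preprocess now takes an already-opened PIL image rather than a file path, and the preprocess returned by dreamsim() (per the docstring in model.py below) is used the same way. The second image path here is hypothetical.

import torch
from PIL import Image
from dreamsim import dreamsim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, preprocess = dreamsim(pretrained=True, device=device)

# Open the images first; preprocess expects a PIL image, not a path.
img_ref = preprocess(Image.open('images/ref_1.png')).to(device)
img_other = preprocess(Image.open('images/img_a_1.png')).to(device)  # hypothetical path
distance = model(img_ref, img_other)  # higher means more different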
dreamsim/config.py (12 changes: 12 additions & 0 deletions)
@@ -17,6 +17,18 @@
         "model_type": "dinov2_vitb14",
         "stride": "14",
         "lora": True
     },
+    "dino_vitb16_patch": {
+        "feat_type": 'cls_patch',
+        "model_type": "dino_vitb16",
+        "stride": "16",
+        "lora": True
+    },
+    "dinov2_vitb14_patch": {
+        "feat_type": 'cls_patch',
+        "model_type": "dinov2_vitb14",
+        "stride": "14",
+        "lora": True
+    },
     "clip_vitb32": {
         "feat_type": 'embedding',
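The two new entries mirror the existing dino_vitb16 and dinov2_vitb14 configs but request cls_patch features. A sketch of how they are selected, based on the use_patch_model handling in model.py further down, and assuming dreamsim_args is the settings dict defined in this module:

from dreamsim.config import dreamsim_args

dreamsim_type = 'dinov2_vitb14'
use_patch_model = True  # hypothetical caller flag
if use_patch_model:
    dreamsim_type += '_patch'  # -> 'dinov2_vitb14_patch'

cfg = dreamsim_args['model_config'][dreamsim_type]
print(cfg['feat_type'], cfg['model_type'], cfg['stride'])  # cls_patch dinov2_vitb14 14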
dreamsim/model.py (36 changes: 23 additions & 13 deletions)
@@ -52,6 +52,7 @@ def __init__(self, model_type: str = "dino_vitb16", feat_type: str = "cls", stri
         self.model_list = model_type.split(',')
         self.feat_type_list = feat_type.split(',')
         self.stride_list = [int(x) for x in stride.split(',')]
+        self.is_patch = "cls_patch" in self.feat_type_list
         self._validate_args()
         self.extract_feats_list = []
         self.extractor_list = nn.ModuleList()
@@ -79,7 +80,6 @@ def forward(self, img_a, img_b):
         """
         :param img_a: An RGB image passed as a (1, 3, 224, 224) tensor with values [0-1].
         :param img_b: Same as img_a.
-        :param return_patch: If True, returns the distance for cat(CLS, patch)
         :return: A distance score for img_a and img_b. Higher means further/more different.
         """
         embed_a = self.embed(img_a)
@@ -94,7 +94,11 @@
             n = patch_a.shape[0]
             s = int(patch_a.shape[1] ** 0.5)
             patch_a_pooled = F.adaptive_avg_pool2d(patch_a.reshape(n, s, s, -1).permute(0, 3, 1, 2), (1, 1)).squeeze()
+            if len(patch_a_pooled.shape) == 1:
+                patch_a_pooled = patch_a_pooled.unsqueeze(0)
             patch_b_pooled = F.adaptive_avg_pool2d(patch_b.reshape(n, s, s, -1).permute(0, 3, 1, 2), (1, 1)).squeeze()
+            if len(patch_b_pooled.shape) == 1:
+                patch_b_pooled = patch_b_pooled.unsqueeze(0)
 
             embed_a = torch.cat((cls_a, patch_a_pooled), dim=-1)
             embed_b = torch.cat((cls_b, patch_b_pooled), dim=-1)
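The new guards exist because squeeze() after adaptive_avg_pool2d drops the batch dimension when the batch size is 1; unsqueeze(0) restores it so the pooled patch features can be concatenated with the CLS token. A self-contained sketch with illustrative shapes:

import torch
import torch.nn.functional as F

patch = torch.randn(1, 256, 768)  # (batch, tokens, dim); sizes are illustrative
n, s = patch.shape[0], int(patch.shape[1] ** 0.5)

pooled = F.adaptive_avg_pool2d(patch.reshape(n, s, s, -1).permute(0, 3, 1, 2), (1, 1)).squeeze()
print(pooled.shape)  # torch.Size([768]); squeeze() removed the batch dim because n == 1

if len(pooled.shape) == 1:       # the guard added in this commit
    pooled = pooled.unsqueeze(0)
print(pooled.shape)  # torch.Size([1, 768])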
@@ -110,15 +114,26 @@ def embed(self, img):
             feats = (self.extract_feats_list[i](img, extractor_index=i)).squeeze()
             full_feats = torch.cat((full_feats, feats), dim=-1)
         embed = self.mlp(full_feats)
 
+        if len(embed.shape) <= 1:
+            embed = embed.unsqueeze(0)
+        if len(embed.shape) <= 2 and self.is_patch:
+            embed = embed.unsqueeze(0)
+
         if self.normalize_embeds:
-            embed = normalize_embedding(embed)
+            embed = normalize_embedding_patch(embed) if self.is_patch else normalize_embedding(embed)
 
         return embed
 
     def _validate_args(self):
         assert len(self.model_list) == len(self.feat_type_list) == len(self.stride_list)
 
+        for model_type, feat_type, stride in zip(self.model_list, self.feat_type_list, self.stride_list):
+            if feat_type == "embedding" and ("dino" in model_type or "mae" in model_type):
+                raise ValueError(f"{feat_type} not supported for {model_type}")
+            if self.is_patch and feat_type != "cls_patch":
+                # If cls_patch is specified for one model, it has to be specified for all.
+                raise ValueError(f"Cannot extract {feat_type} for {model_type}; cls_patch specified elsewhere.")
 
     def _get_extract_fn(self, model_type, feat_type):
         num_feats = 1
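The extended _validate_args also enforces that if any branch uses cls_patch features, every branch must. A minimal sketch of that behavior, assuming PerceptualModel can be imported from dreamsim.model and that the constructor reaches _validate_args() with its remaining arguments left at their defaults:

from dreamsim.model import PerceptualModel

try:
    # Mixing cls_patch with plain cls features now fails fast.
    PerceptualModel(model_type="dino_vitb16,dinov2_vitb14",
                    feat_type="cls_patch,cls",
                    stride="16,14")
except ValueError as err:
    print(err)  # Cannot extract cls for dinov2_vitb14; cls_patch specified elsewhere.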
@@ -128,8 +143,6 @@ def _get_extract_fn(self, model_type, feat_type):
             extract_fn = self._extract_embedding
         elif feat_type == "last_layer":
             extract_fn = self._extract_last_layer
-        elif feat_type == "patch":
-            extract_fn = self._extract_patch
         elif feat_type == "cls_patch":
             extract_fn = self._extract_cls_and_patch
             num_feats = 2
@@ -147,10 +160,6 @@ def _extract_cls_and_patch(self, img, extractor_index=0):
         out = self.extractor_list[extractor_index].extract_descriptors(img, layer)
         return out
 
-    def _extract_patch(self, img, extractor_index=0):
-        layer = 11
-        return self._extract_cls_and_patch(img, extractor_index)[:, :, :, 1:, :]
-
     def _extract_cls(self, img, extractor_index=0):
         layer = 11
         return self._extract_cls_and_patch(img, extractor_index)[:, :, :, 0, :]
@@ -254,11 +263,10 @@ def dreamsim(pretrained: bool = True, device="cuda", cache_dir="./models", norma
         - PerceptualModel with DreamSim settings and weights.
         - Preprocessing function that converts a PIL image to a (1, 3, 224, 224) tensor with values [0-1].
     """
-    download_key = dreamsim_type
     if use_patch_model:
-        download_key += '_patch'
+        dreamsim_type += '_patch'
     # Get model settings and weights
-    download_weights(cache_dir=cache_dir, dreamsim_type=download_key)
+    download_weights(cache_dir=cache_dir, dreamsim_type=dreamsim_type)
 
     # initialize PerceptualModel and load weights
     model_list = dreamsim_args['model_config'][dreamsim_type]['model_type'].split(",")
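With this change the '_patch' suffix is applied to dreamsim_type itself, so both the weight download above and the config lookup below resolve to the new *_patch entries in config.py. A hedged usage sketch; dreamsim_type and use_patch_model are assumed to be keyword arguments of dreamsim() (both are referenced in the body, but the signature is truncated in this view):

import torch
from dreamsim import dreamsim

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True,
                             device=device,
                             dreamsim_type="dinov2_vitb14",
                             use_patch_model=True)  # loads the dinov2_vitb14_patch config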
@@ -299,11 +307,13 @@ def preprocess(pil_img):
 
 
 def normalize_embedding(embed):
-    if len(embed.shape) <= 1:
-        embed = embed.unsqueeze(0)
     embed = (embed.T / torch.norm(embed, dim=1)).T
     return (embed.T - torch.mean(embed, dim=1)).T
 
+def normalize_embedding_patch(embed):
+    embed = (embed.mT / torch.norm(embed, dim=2)).mT
+    return (embed.mT - torch.mean(embed, dim=2)).mT
+
 EMBED_DIMS = {
     'dino_vits8': {'cls': 384, 'last_layer': 384},
     'dino_vits16': {'cls': 384, 'last_layer': 384},
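normalize_embedding_patch mirrors normalize_embedding but works on a 3-D (batch, tokens, dim) embedding: each token vector is L2-normalized and then mean-centered along the feature dimension. A small numerical sketch (shapes are illustrative; batch size 1 matches what embed() produces after the unsqueeze guards above):

import torch

def normalize_embedding_patch(embed):
    embed = (embed.mT / torch.norm(embed, dim=2)).mT
    return (embed.mT - torch.mean(embed, dim=2)).mT

x = torch.randn(1, 5, 4)            # (batch, tokens, dim)
out = normalize_embedding_patch(x)
print(out.shape)                     # torch.Size([1, 5, 4])
print(out.mean(dim=2))               # ~0 for every token after centering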
