-
Notifications
You must be signed in to change notification settings - Fork 22
/
demo.py
38 lines (28 loc) · 938 Bytes
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from PIL import Image
from dreamsim import dreamsim
from torchvision import transforms
import torch
import os
img_size = 224

# Shared image pipeline: bicubic resize to a square img_size x img_size
# image, followed by PIL -> tensor conversion with ToTensor.
t = transforms.Compose(
    [
        transforms.Resize(
            size=(img_size, img_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)
def preprocess(path):
    """Load the image stored at *path* and return it as a batched tensor.

    The image is converted to RGB and passed through the module-level
    transform ``t``. ``unsqueeze(0)`` then prepends a batch dimension of
    size 1.

    NOTE(review): the script later rebinds the module-level name
    ``preprocess`` to the callable returned by ``dreamsim(...)``. The
    script's own image-loading calls therefore do not reach this helper.
    """
    # PIL opens files lazily and keeps the handle open. The context manager
    # closes it once the tensor has been built from the pixel data.
    with Image.open(path) as pil_img:
        return t(pil_img.convert('RGB')).unsqueeze(0)
# Load model (on the GPU when one is available, otherwise on the CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: this rebinds `preprocess`, shadowing the path-based helper defined
# above. From here on, `preprocess` is the callable returned by dreamsim(),
# and it is called with an opened PIL image below.
model, preprocess = dreamsim(pretrained=True, device=device)


def _load_image(path):
    """Open *path*, apply the model's ``preprocess``, and move it to ``device``."""
    # The context manager closes the image file once the tensor is built,
    # instead of leaving the handle open until garbage collection.
    with Image.open(path) as pil_img:
        return preprocess(pil_img).to(device)


# Load images
img_ref = _load_image('images/ref_1.png')
img_0 = _load_image('images/img_a_1.png')
img_1 = _load_image('images/img_b_1.png')

# Get distance from the reference image to each candidate. This is
# inference only, so skip building the autograd graph.
with torch.no_grad():
    d0 = model(img_ref, img_0)
    d1 = model(img_ref, img_1)
print(d0, d1)

# # Get embeddings
# embed_ref = model.embed(img_ref)
# embed_0 = model.embed(img_0)
# embed_1 = model.embed(img_1)