
eval pipeline
ssundaram21 committed Jun 28, 2024
1 parent 5e6b4a8 commit dd35e81
Showing 6 changed files with 564 additions and 4 deletions.
17 changes: 17 additions & 0 deletions configs/eval.yaml
@@ -0,0 +1,17 @@
eval_checkpoint: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/checkpoints/clip_vitb32_lora/"
eval_checkpoint_cfg: "/vision-nfs/isola/projects/shobhita/code/dreamsim/dreamsim_steph/new_checkpoints/lora_single_clip_vitb32_embedding_lora_lr_0.0003_batchsize_32_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_8.0_loradropout_0.3/lightning_logs/version_0/config.yaml"
load_dir: "/vision-nfs/isola/projects/shobhita/code/dreamsim/models"

baseline_model: "clip_vitb32"
baseline_feat_type: "cls"
baseline_stride: "32"

nights_root: "/vision-nfs/isola/projects/shobhita/data/nights"
bapps_root: "/vision-nfs/isola/projects/shobhita/data/2afc/val"
things_root: "/vision-nfs/isola/projects/shobhita/data/things/things_src_images"
things_file: "/vision-nfs/isola/projects/shobhita/data/things/things_valset.txt"
df2_root: "/data/vision/phillipi/perception/data/df2_org3/"
df2_gt: "/data/vision/phillipi/perception/code/repalignment/configs/df2_gt.json"

batch_size: 256
num_workers: 10
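
With this config in place, the pipeline added below can be launched from the repository root, e.g. (the -m form keeps the evaluation.* imports resolvable; the checkpoint and dataset paths above are cluster-specific and would be swapped for local ones):

    python -m evaluation.eval_percep --config configs/eval.yaml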
134 changes: 134 additions & 0 deletions evaluation/eval_datasets.py
@@ -0,0 +1,134 @@
from torch.utils.data import Dataset
from util.utils import get_preprocess_fn
from torchvision import transforms
import pandas as pd
import numpy as np
from PIL import Image
import os
from typing import Callable
import torch
import glob

IMAGE_EXTENSIONS = ["jpg", "png", "JPEG", "jpeg"]


class ThingsDataset(Dataset):
def __init__(self, root_dir: str, txt_file: str, preprocess: str, load_size: int = 224,
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC):
with open(txt_file, "r") as f:
self.txt = f.readlines()
self.dataset_root = root_dir
self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation)

def __len__(self):
return len(self.txt)

def __getitem__(self, idx):
im_1, im_2, im_3 = self.txt[idx].split()

im_1 = Image.open(os.path.join(self.dataset_root, f"{im_1}.png"))
im_2 = Image.open(os.path.join(self.dataset_root, f"{im_2}.png"))
im_3 = Image.open(os.path.join(self.dataset_root, f"{im_3}.png"))

im_1 = self.preprocess_fn(im_1)
im_2 = self.preprocess_fn(im_2)
im_3 = self.preprocess_fn(im_3)

return im_1, im_2, im_3
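
Each line of the triplet file is expected to carry three whitespace-separated image stems; the loader appends the .png extension itself. A hypothetical line (stem names illustrative, not from the actual THINGS val set):

    aardvark_01b bamboo_02s candle_13x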



class BAPPSDataset(Dataset):
def __init__(self, root_dir: str, preprocess: str, load_size: int = 224,
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC):
data_types = ["cnn", "traditional", "color", "deblur", "superres", "frameinterp"]

self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation)
self.judge_paths = []
self.p0_paths = []
self.p1_paths = []
self.ref_paths = []

for dt in data_types:
    list_dir = os.path.join(root_dir, dt, "judge")
    for fname in os.scandir(list_dir):
        stem = fname.name.split(".")[0]
        self.judge_paths.append(os.path.join(list_dir, fname.name))
        self.p0_paths.append(os.path.join(root_dir, dt, "p0", stem + ".png"))
        self.p1_paths.append(os.path.join(root_dir, dt, "p1", stem + ".png"))
        self.ref_paths.append(os.path.join(root_dir, dt, "ref", stem + ".png"))

def __len__(self):
return len(self.judge_paths)

def __getitem__(self, idx):
judge = np.load(self.judge_paths[idx])
im_left = self.preprocess_fn(Image.open(self.p0_paths[idx]))
im_right = self.preprocess_fn(Image.open(self.p1_paths[idx]))
im_ref = self.preprocess_fn(Image.open(self.ref_paths[idx]))
return im_ref, im_left, im_right, judge
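
For context, this loader assumes the standard BAPPS 2AFC layout: one directory per distortion type, each holding judge/ (.npy human judgments) alongside p0/, p1/, and ref/ (.png images) with matching stems. A hypothetical entry:

    cnn/judge/000000.npy
    cnn/p0/000000.png
    cnn/p1/000000.png
    cnn/ref/000000.png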

class DF2Dataset(Dataset):
def __init__(self, root_dir, split: str, preprocess: str, load_size: int = 224,
interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC):

self.preprocess_fn = get_preprocess_fn(preprocess, load_size, interpolation)
self.paths = get_paths(os.path.join(root_dir, split))

def __len__(self):
return len(self.paths)

def __getitem__(self, idx):
im_path = self.paths[idx]
img = Image.open(im_path)
img = self.preprocess_fn(img)
return img, im_path

def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')

def get_paths(path):
    # Non-recursive: collects images sitting directly under `path`.
    all_paths = []
    for ext in IMAGE_EXTENSIONS:
        all_paths += glob.glob(os.path.join(path, f"*.{ext}"))
    return all_paths

# class ImageDataset(torch.utils.data.Dataset):
# def __init__(self, root, class_to_idx, transform=None, ret_path=False):
# """
# :param root: Dataset root. Should follow the structure class1/0.jpg...n.jpg, class2/0.jpg...n.jpg
# :param class_to_idx: dictionary mapping the classnames to integers.
# :param transform:
# :param ret_path: boolean indicating whether to return the image path or not (useful for KNN for plotting nearest neighbors)
# """

# self.transform = transform
# self.label_to_idx = class_to_idx

# self.paths = []
# self.labels = []
# for cls in class_to_idx:
# cls_paths = get_paths(os.path.join(root, cls))
# self.paths += cls_paths
# self.labels += [self.label_to_idx[cls] for _ in cls_paths]

# self.ret_path = ret_path

# def __len__(self):
# return len(self.paths)

# def __getitem__(self, idx):
# im_path, label = self.paths[idx], self.labels[idx]
# img = pil_loader(im_path)

# if self.transform is not None:
# img = self.transform(img)
# if not self.ret_path:
# return img, label
# else:
# return img, label, im_path
149 changes: 149 additions & 0 deletions evaluation/eval_percep.py
@@ -0,0 +1,149 @@
from pytorch_lightning import seed_everything
import torch
from dataset.dataset import TwoAFCDataset
from util.utils import get_preprocess
from torch.utils.data import DataLoader
import os
import yaml
import logging
from training.train import LightningPerceptualModel
from evaluation.score import score_nights_dataset, score_things_dataset, score_bapps_dataset, score_df2_dataset
from evaluation.eval_datasets import ThingsDataset, BAPPSDataset, DF2Dataset
from torchmetrics.functional import structural_similarity_index_measure, peak_signal_noise_ratio
from DISTS_pytorch import DISTS
from dreamsim import PerceptualModel
from tqdm import tqdm
import pickle
import configargparse
from dreamsim import dreamsim
import clip
from torchvision import transforms

log = logging.getLogger("lightning.pytorch")
log.propagate = False
log.setLevel(logging.ERROR)

def parse_args():
parser = configargparse.ArgumentParser()
parser.add_argument('-c', '--config', required=False, is_config_file=True, help='config file path')

## Run options
parser.add_argument('--seed', type=int, default=1234)

## Checkpoint evaluation options
parser.add_argument('--eval_checkpoint', type=str, help="Path to a checkpoint root.")
parser.add_argument('--eval_checkpoint_cfg', type=str, help="Path to checkpoint config.")
parser.add_argument('--load_dir', type=str, default="./models", help='path to pretrained ViT checkpoints.')

## Baseline evaluation options
parser.add_argument('--baseline_model', type=str, default=None)
parser.add_argument('--baseline_feat_type', type=str, default=None)
parser.add_argument('--baseline_stride', type=str, default=None)

## Dataset options
parser.add_argument('--nights_root', type=str, default='./dataset/nights', help='path to nights dataset.')
parser.add_argument('--bapps_root', type=str, default='./dataset/bapps', help='path to bapps images.')
parser.add_argument('--things_root', type=str, default='./dataset/things/things_imgs', help='path to things images.')
parser.add_argument('--things_file', type=str, default='./dataset/things/things_trainset.txt', help='path to things info file.')
parser.add_argument('--df2_root', type=str, default='./dataset/df2', help='path to df2 root.')
parser.add_argument('--df2_gt', type=str, default='./dataset/df2/df2_gt.json', help='path to df2 ground truth json.')
parser.add_argument('--num_workers', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=4, help='dataset batch size.')

return parser.parse_args()

def load_dreamsim_model(args, device="cuda"):
with open(args.eval_checkpoint_cfg, "r") as f:
cfg = yaml.load(f, Loader=yaml.Loader)

model_cfg = vars(cfg)
model_cfg['load_dir'] = args.load_dir
model = LightningPerceptualModel(**model_cfg)
model.load_lora_weights(args.eval_checkpoint)
model = model.perceptual_model.to(device)
preprocess = "DEFAULT"
return model, preprocess


def load_baseline_model(args, device="cuda"):
model = PerceptualModel(model_type=args.baseline_model, feat_type=args.baseline_feat_type, stride=args.baseline_stride, baseline=True, load_dir=args.load_dir)
model = model.to(device)
preprocess = "DEFAULT"
return model, preprocess
# clip_transform = transforms.Compose([
# transforms.Resize((224,224), interpolation=transforms.InterpolationMode.BICUBIC),
# # transforms.CenterCrop(224),
# lambda x: x.convert('RGB'),
# transforms.ToTensor(),
# transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
# ])
# model, preprocess = clip.load("ViT-B/32", device=device)
# model.visual.ln_post = torch.nn.Identity()
# return model, clip_transform
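
Either loader can be sanity-checked on a single image pair before a full run. A minimal sketch, assuming (per the dreamsim README) that the model's forward takes two preprocessed image batches and returns a perceptual distance; file names here are hypothetical:

    from PIL import Image
    from torchvision import transforms
    from util.utils import get_preprocess_fn

    t = get_preprocess_fn("DEFAULT", 224, transforms.InterpolationMode.BICUBIC)
    img_a = t(Image.open("a.png").convert("RGB")).unsqueeze(0).to("cuda")
    img_b = t(Image.open("b.png").convert("RGB")).unsqueeze(0).to("cuda")
    distance = model(img_a, img_b)  # lower distance = more perceptually similar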

def eval_nights(model, preprocess, device, args):
eval_results = {}
val_dataset = TwoAFCDataset(root_dir=args.nights_root, split="val", preprocess=preprocess)
test_dataset_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_imagenet", preprocess=preprocess)
test_dataset_no_imagenet = TwoAFCDataset(root_dir=args.nights_root, split="test_no_imagenet", preprocess=preprocess)
total_length = len(test_dataset_no_imagenet) + len(test_dataset_imagenet)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
test_imagenet_loader = DataLoader(test_dataset_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
test_no_imagenet_loader = DataLoader(test_dataset_no_imagenet, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)

val_score = score_nights_dataset(model, val_loader, device)
imagenet_score = score_nights_dataset(model, test_imagenet_loader, device)
no_imagenet_score = score_nights_dataset(model, test_no_imagenet_loader, device)

eval_results['nights_val'] = val_score.item()
eval_results['nights_imagenet'] = imagenet_score.item()
eval_results['nights_no_imagenet'] = no_imagenet_score.item()
eval_results['nights_total'] = (imagenet_score.item() * len(test_dataset_imagenet) +
no_imagenet_score.item() * len(test_dataset_no_imagenet)) / total_length
logging.info(f"Combined 2AFC score: {str(eval_results['nights_total'])}")
return eval_results

def eval_bapps(model, preprocess, device, args):
test_dataset_bapps = BAPPSDataset(root_dir=args.bapps_root, preprocess=preprocess)
test_loader_bapps = DataLoader(test_dataset_bapps, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
bapps_score = score_bapps_dataset(model, test_loader_bapps, device)
return {"bapps_total": bapps_score}

def eval_things(model, preprocess, device, args):
test_dataset_things = ThingsDataset(root_dir=args.things_root, txt_file=args.things_file, preprocess=preprocess)
test_loader_things = DataLoader(test_dataset_things, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
things_score = score_things_dataset(model, test_loader_things, device)
return {"things_total": things_score}

def eval_df2(model, preprocess, device, args):
train_dataset = DF2Dataset(root_dir=args.df2_root, split="gallery", preprocess=preprocess)
test_dataset = DF2Dataset(root_dir=args.df2_root, split="customer", preprocess=preprocess)
train_loader_df2 = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
test_loader_df2 = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
df2_score = score_df2_dataset(model, train_loader_df2, test_loader_df2, args.df2_gt, device)
return {"df2_total": df2_score}

def run(args, device):
logging.basicConfig(filename=os.path.join(args.eval_checkpoint, 'eval.log'), level=logging.INFO, force=True)
seed_everything(args.seed)
g = torch.Generator()
g.manual_seed(args.seed)

eval_model, preprocess = load_dreamsim_model(args)
checkpoint_results = {}
checkpoint_results.update(eval_nights(eval_model, preprocess, device, args))
checkpoint_results.update(eval_bapps(eval_model, preprocess, device, args))
checkpoint_results.update(eval_things(eval_model, preprocess, device, args))
checkpoint_results.update(eval_df2(eval_model, preprocess, device, args))
logging.info(f"Checkpoint results: {checkpoint_results}")

# Note: `"baseline_model" in args` is always true since the flag has a default,
# so check the value instead; keep baseline results separate so they don't
# overwrite the checkpoint scores.
if args.baseline_model is not None:
    baseline_model, baseline_preprocess = load_baseline_model(args)
    baseline_results = {}
    baseline_results.update(eval_nights(baseline_model, baseline_preprocess, device, args))
    baseline_results.update(eval_bapps(baseline_model, baseline_preprocess, device, args))
    baseline_results.update(eval_things(baseline_model, baseline_preprocess, device, args))
    baseline_results.update(eval_df2(baseline_model, baseline_preprocess, device, args))
    logging.info(f"Baseline results: {baseline_results}")

if __name__ == '__main__':
args = parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
run(args, device)
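
When baseline_model is set (as in configs/eval.yaml above), the same four benchmarks are additionally run on the off-the-shelf backbone for comparison; the baseline settings can also be supplied as flags, e.g.:

    python -m evaluation.eval_percep --config configs/eval.yaml --baseline_model clip_vitb32 --baseline_feat_type cls --baseline_stride 32

Scores are written to eval.log inside the eval_checkpoint directory.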
