"""Evaluate generated images against ground truth.

Always reports FID and KID (via ``cleanfid``). With ``--paired`` it also
reports SSIM and LPIPS over image pairs matched by the ID embedded in their
file names.
"""
import os

import torch
from cleanfid import fid as FID
from PIL import Image
from prettytable import PrettyTable
from torch.utils.data import Dataset
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision import transforms
from tqdm import tqdm

from utils import scan_files_in_dir


class EvalDataset(Dataset):
    """Dataset of (ground-truth, prediction) image tensor pairs.

    Pairs are formed by matching the ID extracted from each file name (see
    ``extract_id_from_filename``). Every image is resized to ``height``
    pixels tall, keeping its aspect ratio, before conversion to a tensor.
    """

    def __init__(self, gt_folder, pred_folder, height=1024):
        """Index the two folders.

        Args:
            gt_folder: Folder holding the ground-truth images.
            pred_folder: Folder holding the predicted images.
            height: Target height in pixels for every returned image.
        """
        self.gt_folder = gt_folder
        self.pred_folder = pred_folder
        self.height = height
        self.data = self.prepare_data()
        self.to_tensor = transforms.ToTensor()

    def extract_id_from_filename(self, filename):
        """Return up to 8 characters of ``filename`` starting at its first digit.

        The slice is taken blindly, so it may contain non-digit characters
        when the numeric run is shorter than 8 characters.

        Raises:
            AssertionError: If ``filename`` contains no digit.
        """
        for index, char in enumerate(filename):
            if char.isdigit():
                return filename[index:index + 8]
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        raise AssertionError(f"Cannot find number in filename {filename}")

    def prepare_data(self):
        """Build the list of ``(gt_path, pred_path)`` tuples.

        Predictions with no ground-truth file of the same ID are reported on
        stdout and skipped. If several ground-truth files share an ID, the
        last one scanned wins.
        """
        # NOTE(review): scan_files_in_dir is assumed to yield entries exposing
        # ``.name`` and ``.path`` (os.DirEntry-like) - confirm in utils.
        gt_files = scan_files_in_dir(self.gt_folder, postfix={'.jpg', '.png'})
        gt_dict = {self.extract_id_from_filename(file.name): file for file in gt_files}
        pred_files = scan_files_in_dir(self.pred_folder, postfix={'.jpg', '.png'})

        tuples = []
        for pred_file in pred_files:
            pred_id = self.extract_id_from_filename(pred_file.name)
            if pred_id not in gt_dict:
                print(f"Cannot find gt file for {pred_file}")
            else:
                tuples.append((gt_dict[pred_id].path, pred_file.path))
        return tuples

    def resize(self, img):
        """Return ``img`` resized to ``self.height``, preserving aspect ratio."""
        w, h = img.size
        new_w = int(w * self.height / h)
        return img.resize((new_w, self.height), Image.LANCZOS)

    def _load(self, path):
        """Open ``path``, normalise it to RGB, and resize it to the target height."""
        # The context manager closes the file handle; convert() returns an
        # independent in-memory copy, so the result stays valid afterwards.
        with Image.open(path) as img:
            # RGB keeps channel counts identical across JPEG/PNG inputs
            # (RGBA or palette PNGs would otherwise break batching/metrics).
            return self.resize(img.convert("RGB"))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(gt, pred)`` tensor pair at ``idx``, values in [0, 1]."""
        gt_path, pred_path = self.data[idx]
        gt = self.to_tensor(self._load(gt_path))
        pred = self.to_tensor(self._load(pred_path))
        return gt, pred


def copy_resize_gt(gt_folder, height):
    """Create a resized copy of ``gt_folder`` and return its path.

    The copy lives next to the original as ``<gt_folder>_<height>``. Files
    already present in the copy are left untouched, so interrupted runs can
    resume.

    Args:
        gt_folder: Folder holding the ground-truth images.
        height: Target height in pixels; width follows the aspect ratio.

    Returns:
        Path of the folder containing the resized images.
    """
    # normpath strips a trailing separator; without it "gt/" would produce
    # "gt/_1024" *inside* the source folder and then be listed as an input.
    new_folder = f"{os.path.normpath(gt_folder)}_{height}"
    os.makedirs(new_folder, exist_ok=True)

    for file in tqdm(os.listdir(gt_folder)):
        src = os.path.join(gt_folder, file)
        dst = os.path.join(new_folder, file)
        if os.path.exists(dst) or not os.path.isfile(src):
            continue
        with Image.open(src) as img:
            w, h = img.size
            new_w = int(w * height / h)
            resized = img.resize((new_w, height), Image.LANCZOS)
        resized.save(dst)
    return new_folder


@torch.no_grad()
def ssim(dataloader, device="cpu"):
    """Return the mean SSIM over every pair produced by ``dataloader``.

    Args:
        dataloader: Yields ``(gt, pred)`` batches with values in [0, 1].
        device: Torch device the metric is computed on.

    Returns:
        Mean SSIM, weighted by batch size so a short last batch does not
        skew the average.
    """
    metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    total = 0
    for gt, pred in tqdm(dataloader, desc="Calculating SSIM"):
        batch_size = gt.size(0)
        gt, pred = gt.to(device), pred.to(device)
        total += metric(pred, gt) * batch_size
    return total / len(dataloader.dataset)


@torch.no_grad()
def lpips(dataloader, device="cpu"):
    """Return the mean LPIPS (SqueezeNet backbone) over every pair.

    Args:
        dataloader: Yields ``(gt, pred)`` batches with values in [0, 1].
        device: Torch device the metric is computed on.

    Returns:
        Mean LPIPS, weighted by batch size.
    """
    metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze').to(device)
    total = 0
    for gt, pred in tqdm(dataloader, desc="Calculating LPIPS"):
        batch_size = gt.size(0)
        pred = pred.to(device)
        gt = gt.to(device)
        # LPIPS needs the images to be in the [-1, 1] range.
        gt = (gt * 2) - 1
        pred = (pred * 2) - 1
        total += metric(gt, pred) * batch_size
    return total / len(dataloader.dataset)


def _sample_image_height(folder):
    """Return the height of the first readable image file in ``folder``.

    Raises:
        FileNotFoundError: If the folder holds no file PIL can open.
    """
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if not os.path.isfile(path):
            continue
        try:
            with Image.open(path) as img:
                return img.height
        except OSError:
            # Not an image (PIL's UnidentifiedImageError is an OSError).
            continue
    raise FileNotFoundError(f"No readable image found in {folder}")


def eval(args):
    """Compute the metrics for ``args`` and print them as a table.

    Args:
        args: Namespace with ``gt_folder``, ``pred_folder``, ``paired``,
            ``batch_size``, ``num_workers`` and, optionally, ``device``
            (defaults to ``"cpu"``). ``args.gt_folder`` is replaced by the
            resized copy when the ground truth had to be resized.
    """
    device = getattr(args, "device", "cpu")

    # Check gt_folder has images with target height, resize if not
    pred_height = _sample_image_height(args.pred_folder)
    gt_height = _sample_image_height(args.gt_folder)
    if pred_height != gt_height:
        title = "--" * 30 + f"Resizing GT Images to height {pred_height}" + "--" * 30
        print(title)
        args.gt_folder = copy_resize_gt(args.gt_folder, pred_height)
        print("-" * len(title))

    # Form dataset
    dataset = EvalDataset(args.gt_folder, args.pred_folder, pred_height)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
    )

    # Calculate Metrics
    header = ["FID", "KID"]
    fid_ = FID.compute_fid(args.gt_folder, args.pred_folder)
    kid_ = FID.compute_kid(args.gt_folder, args.pred_folder) * 1000
    row = [fid_, kid_]
    if args.paired:
        if len(dataset) == 0:
            # Fail with a readable message instead of a bare ZeroDivisionError.
            raise ValueError(
                "No matching GT/prediction pairs found in "
                f"{args.gt_folder} and {args.pred_folder}"
            )
        header += ["SSIM", "LPIPS"]
        ssim_ = ssim(dataloader, device).item()
        lpips_ = lpips(dataloader, device).item()
        row += [ssim_, lpips_]

    # Print Results
    print("GT Folder  : ", args.gt_folder)
    print("Pred Folder: ", args.pred_folder)
    table = PrettyTable()
    table.field_names = header
    table.add_row(row)
    print(table)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_folder", type=str, required=True)
    parser.add_argument("--pred_folder", type=str, required=True)
    parser.add_argument("--paired", action="store_true")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    eval(args)