"""Compute per-image and average reconstruction error (LPIPS or L2) between a
directory of inference results and the corresponding ground-truth images.

Writes a summary (mean +- std) and a per-image score breakdown to an
`inference_metrics` directory next to the results."""

from argparse import ArgumentParser
import os
import json
import sys

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Make repository-local packages (criteria, datasets) importable when the
# script is run directly.
sys.path.append(".")
sys.path.append("..")

from criteria.lpips.lpips import LPIPS
from datasets.gt_res_dataset import GTResDataset


def parse_args():
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
    parser.add_argument('--data_path', type=str, default='results')
    parser.add_argument('--gt_path', type=str, default='gt_images')
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--is_cars', action='store_true')
    args = parser.parse_args()
    return args


def run(args):
    # Cars are evaluated at 192x256; every other domain uses 256x256.
    resize_dims = (256, 256)
    if args.is_cars:
        resize_dims = (192, 256)
    # Resize and map pixel values from [0, 1] to [-1, 1].
    transform = transforms.Compose([transforms.Resize(resize_dims),
                                    transforms.ToTensor(),
                                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    print('Loading dataset')
    dataset = GTResDataset(root_path=args.data_path,
                           gt_dir=args.gt_path,
                           transform=transform)

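    # Note: drop_last=True silently skips a trailing partial batch, so up to
    # batch_size - 1 images may be excluded from the reported statistics.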
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            drop_last=True)

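    # LPIPS (with an AlexNet backbone, per net_type='alex') measures perceptual
    # similarity; 'l2' falls back to plain mean squared error in pixel space.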
    if args.mode == 'lpips':
        loss_func = LPIPS(net_type='alex')
    elif args.mode == 'l2':
        loss_func = torch.nn.MSELoss()
    else:
        raise ValueError('Not a valid mode: {}'.format(args.mode))
    loss_func.cuda()

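    # Score each (result, ground-truth) pair individually so a per-image
    # breakdown can be saved alongside the aggregate statistics.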
    global_i = 0
    scores_dict = {}
    for result_batch, gt_batch in tqdm(dataloader):
        for i in range(args.batch_size):
            with torch.no_grad():  # no gradients are needed for evaluation
                loss = float(loss_func(result_batch[i:i + 1].cuda(), gt_batch[i:i + 1].cuda()))
            im_path = dataset.pairs[global_i][0]
            scores_dict[os.path.basename(im_path)] = loss
            global_i += 1

    all_scores = list(scores_dict.values())
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
    print('Finished with', args.data_path)
    print(result_str)

    # Write the summary and the per-image scores next to the results directory.
    out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
    os.makedirs(out_path, exist_ok=True)

    with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
        json.dump(scores_dict, f)


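# Example invocation (the script path and data paths below are placeholders;
# adjust them to your repository layout and experiment):
#   python scripts/calc_losses_on_images.py --mode lpips \
#       --data_path experiments/test/inference_results \
#       --gt_path data/gt_images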
if __name__ == '__main__':
    args = parse_args()
    run(args)