Spaces:
Running
Running
import skimage | |
import torch | |
import numpy as np | |
from pytorch_msssim import ssim | |
import math | |
def calc_metrics(harmonized, real, mask_batch):
    """Compute per-sample harmonization metrics for a batch.

    Each sample of ``harmonized`` and ``real`` is multiplied by 255 and
    clamped to ``[0, 255]`` before scoring, so inputs are presumably in
    ``[0, 1]`` (TODO confirm with callers).

    Args:
        harmonized: Predicted images; a 4-D tensor unpacked as (n, c, h, w).
        real: Ground-truth images, indexed per sample like ``harmonized``.
        mask_batch: Per-sample masks, indexed like ``harmonized``. Presumably
            single-channel (1, h, w) so they broadcast over colour channels —
            verify against the data loader.

    Returns:
        Four lists ``(mse, fmse, psnr, ssim)``, one entry per sample, holding
        the results of ``MSE``, ``fMSE``, ``PSNR`` and ``SSIM`` respectively.
    """
    # Unpacking into four names also enforces that the batch is 4-D.
    n, _, _, _ = harmonized.shape
    mse_scores = []
    fmse_scores = []
    psnr_scores = []
    ssim_scores = []  # not named `ssim`: that would shadow pytorch_msssim.ssim
    for idx in range(n):
        # Convert to the 0-255 range once per sample instead of once per metric.
        pred = torch.clamp(harmonized[idx] * 255, 0, 255)
        gt = torch.clamp(real[idx] * 255, 0, 255)
        mask = mask_batch[idx]
        mse_scores.append(MSE(pred, gt, mask))
        fmse_scores.append(fMSE(pred, gt, mask))
        psnr_scores.append(PSNR(pred, gt, mask))
        ssim_scores.append(SSIM(pred, gt, mask))
    return mse_scores, fmse_scores, psnr_scores, ssim_scores
def SSIM(pred, target_image, mask):
    """Return the SSIM of ``pred`` against ``target_image`` after compositing.

    The prediction is blended onto the target with ``mask`` as the weight:
    where the mask is 1 the prediction is kept, and where it is 0 the target
    is used. Both images then get a leading batch axis and are passed to
    ``pytorch_msssim.ssim`` with its default settings.

    The result of ``ssim`` is returned unchanged. Unlike ``MSE``, ``fMSE`` and
    ``PSNR``, it is not converted with ``.item()``.
    """
    composite = pred * mask + (target_image) * (1 - mask)
    batched_composite = composite[None]
    batched_target = target_image[None]
    return ssim(batched_composite, batched_target)
def MSE(pred, target_image, mask):
    """Return the mask-weighted mean squared error as a Python float.

    The squared difference is multiplied by ``mask`` and averaged over every
    element of the (broadcast) product. Masked-out elements therefore count
    as zeros in the mean rather than being excluded from it.
    """
    squared_error = (pred - target_image) ** 2
    weighted = mask * squared_error
    return weighted.mean().item()
def fMSE(pred, target_image, mask):
    """Return the foreground MSE: masked squared error over masked elements.

    The summed squared error inside ``mask`` is divided by
    ``size(0) * mask.sum()``, where ``size(0)`` is the first axis of the
    masked error tensor.

    For a (C, H, W) image with a single-channel mask broadcast over C, this
    is the mean over foreground elements. That layout is presumably the
    intended one; confirm against callers.

    The ``1e-6`` term prevents division by zero for an empty mask; the
    result is then 0.0.
    """
    masked_sq = mask * ((pred - target_image) ** 2)
    channels = masked_sq.size(0)
    denominator = channels * mask.sum() + 1e-6
    return (masked_sq.sum() / denominator).item()
def PSNR(pred, target_image, mask, data_range=None):
    """Return the peak signal-to-noise ratio (dB) of ``pred`` vs ``target_image``.

    The squared error is weighted by ``mask`` and averaged over every element
    of the (broadcast) product, the same way ``MSE`` does it.

    Args:
        pred: Predicted image tensor.
        target_image: Reference tensor with the same shape as ``pred``.
        mask: Weight tensor broadcastable against ``pred``.
        data_range: Peak signal value. ``None`` (the default) keeps the
            original behaviour of using ``target_image.max()``. Pass e.g.
            ``255`` for conventional fixed-range PSNR on 0-255 images.

    Returns:
        PSNR as a Python float. A ``1e-6`` epsilon keeps it finite when the
        error is zero.

    Raises:
        ValueError: If the peak value is zero (for example an all-black
            target with ``data_range=None``), where PSNR is undefined.
    """
    mse = (mask * (pred - target_image) ** 2).mean().item()
    peak = target_image.max().item() if data_range is None else float(data_range)
    squared_max = peak ** 2
    if squared_max == 0:
        # math.log10(0) would otherwise fail with an opaque "math domain error".
        raise ValueError(
            "PSNR is undefined for a zero peak value; pass a non-zero data_range"
        )
    return 10 * math.log10(squared_max / (mse + 1e-6))