File size: 2,189 Bytes
033bd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import skimage

import torch
import numpy as np
from pytorch_msssim import ssim
import math


def calc_metrics(harmonized, real, mask_batch):
    """Compute per-sample MSE, fMSE, PSNR and SSIM for a batch of images.

    Each sample of ``harmonized`` and ``real`` is multiplied by 255 and
    clamped to [0, 255] once, then scored by the module's ``MSE``, ``fMSE``,
    ``PSNR`` and ``SSIM`` helpers together with that sample's mask.
    The x255 scaling suggests inputs are expected in [0, 1] -- TODO confirm
    with callers.

    Args:
        harmonized: predicted images; ``harmonized.shape`` must unpack into
            four values ``(n, c, h, w)`` and samples are taken along dim 0.
        real: ground-truth images, indexed per sample like ``harmonized``.
        mask_batch: per-sample masks; ``mask_batch[i]`` is passed unchanged
            to every metric helper together with sample ``i``.

    Returns:
        Tuple ``(mse, fmse, psnr, ssim)`` of lists, each holding one score
        per sample in batch order.
    """
    # Unpacking (rather than reading shape[0]) keeps the implicit check that
    # the input is 4-D; only the batch size is actually used.
    n, _, _, _ = harmonized.shape

    mse_scores = []
    fmse_scores = []
    psnr_scores = []
    # Not named ``ssim``: that would shadow the imported pytorch_msssim.ssim.
    ssim_scores = []
    for i in range(n):
        # Rescale/clamp once per sample instead of once per metric call.
        pred = torch.clamp(harmonized[i] * 255, 0, 255)
        gt = torch.clamp(real[i] * 255, 0, 255)
        mask = mask_batch[i]

        mse_scores.append(MSE(pred, gt, mask))
        fmse_scores.append(fMSE(pred, gt, mask))
        psnr_scores.append(PSNR(pred, gt, mask))
        ssim_scores.append(SSIM(pred, gt, mask))

    return mse_scores, fmse_scores, psnr_scores, ssim_scores


def SSIM(pred, target_image, mask):
    """Return SSIM between the mask-composited prediction and the target.

    Outside the mask the prediction is replaced by ``target_image`` (for a
    fractional mask the two are blended linearly), so differences there do
    not affect the score. Both images get a leading batch dimension before
    being passed to ``pytorch_msssim.ssim`` with its default arguments.

    The value is returned exactly as ``ssim`` produces it; unlike ``MSE``,
    ``fMSE`` and ``PSNR`` it is not converted with ``.item()``, so it is
    presumably a tensor rather than a Python float -- verify before mixing
    it with the other scores.
    """
    background = target_image * (1 - mask)
    composite = pred * mask + background
    return ssim(composite[None], target_image[None])


def MSE(pred, target_image, mask):
    """Return the mask-weighted mean squared error as a Python float.

    Squared errors are multiplied elementwise by ``mask`` and then averaged
    over every element of the broadcast result, so positions where ``mask``
    is 0 add nothing to the sum but still count in the denominator.
    """
    squared_error = (pred - target_image) ** 2
    weighted = mask * squared_error
    return weighted.mean().item()


def fMSE(pred, target_image, mask):
    """Return the foreground MSE: masked squared error per masked element.

    The mask-weighted squared error is summed and divided by
    ``size(0) * mask.sum()`` (plus 1e-6 to avoid division by zero). For
    channel-first tensors, as ``calc_metrics`` passes them from an
    ``(n, c, h, w)`` batch, ``size(0)`` is the channel count, so this
    averages over channels and foreground pixels -- assumes ``mask`` has a
    single channel; confirm for other layouts.
    """
    masked_sq_err = mask * ((pred - target_image) ** 2)
    num_channels = masked_sq_err.size(0)
    denominator = num_channels * mask.sum() + 1e-6
    return (masked_sq_err.sum() / denominator).item()


def PSNR(pred, target_image, mask):
    """Return the PSNR in dB computed from the mask-weighted MSE.

    The peak value is the maximum of ``target_image`` itself (not a fixed
    255), and 1e-6 is added to the MSE so identical images give a large
    finite value. If the target's maximum is 0, ``math.log10`` receives 0
    and raises ``ValueError``.
    """
    peak = target_image.max().item()
    masked_mse = (mask * (pred - target_image) ** 2).mean().item()
    return 10 * math.log10(peak ** 2 / (masked_mse + 1e-6))