Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
from torch.autograd import Variable | |
import numpy as np | |
import cv2 | |
# from vgg19_loss import VGG19Loss | |
# import pytorch_ssim | |
# from .vgg19_loss import VGG19Loss | |
from . import pytorch_ssim | |
from abc import ABC, abstractmethod | |
from collections import OrderedDict | |
class abs_loss(ABC): | |
def loss(self, gt_img, pred_img): | |
pass | |
class norm_loss(abs_loss): | |
def __init__(self, norm=1): | |
self.norm = norm | |
def loss(self, gt_img, pred_img): | |
""" M * (I-I') """ | |
b, c, h, w = gt_img.shape | |
return torch.norm(gt_img-pred_img, self.norm)/(h * w * b) | |
class ssim_loss(abs_loss): | |
def __init__(self, window_size=11, channel=1): | |
""" Let's try mean ssim! | |
""" | |
self.channel = channel | |
self.window_size = window_size | |
self.window = self.create_mean_window(window_size, channel) | |
def loss(self, gt_img, pred_img): | |
b, c, h, w = gt_img.shape | |
if c != self.channel: | |
self.channel = c | |
self.window = self.create_mean_window(self.window_size, self.channel) | |
self.window = self.window.to(gt_img).type_as(gt_img) | |
l = 1.0 - self.ssim_compute(gt_img, pred_img) | |
return l | |
def create_mean_window(self, window_size, channel): | |
window = Variable(torch.ones(channel, 1, window_size, window_size).float()) | |
window = window/(window_size * window_size) | |
return window | |
def ssim_compute(self, gt_img, pred_img): | |
window = self.window | |
window_size = self.window_size | |
channel = self.channel | |
mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel) | |
mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel) | |
mu1_sq = mu1.pow(2) | |
mu2_sq = mu2.pow(2) | |
mu1_mu2 = mu1*mu2 | |
sigma1_sq = F.conv2d(gt_img*gt_img, window, padding = window_size//2, groups = channel) - mu1_sq | |
sigma2_sq = F.conv2d(pred_img*pred_img, window, padding = window_size//2, groups = channel) - mu2_sq | |
sigma12 = F.conv2d(gt_img*pred_img, window, padding = window_size//2, groups = channel) - mu1_mu2 | |
C1 = 0.01**2 | |
C2 = 0.03**2 | |
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) | |
return ssim_map.mean() | |
class hierarchical_ssim_loss(abs_loss): | |
def __init__(self, patch_list: list): | |
self.ssim_loss_list = [pytorch_ssim.SSIM(window_size=ws) for ws in patch_list] | |
def loss(self, gt_img, pred_img): | |
b, c, h, w = gt_img.shape | |
total_loss = 0.0 | |
for loss_func in self.ssim_loss_list: | |
total_loss += (1.0-loss_func(gt_img, pred_img)) | |
return total_loss/b | |
# class vgg_loss(abs_loss): | |
# def __init__(self): | |
# self.vgg19_ = VGG19Loss() | |
# def loss(self, gt_img, pred_img): | |
# b, c, h, w = gt_img.shape | |
# v = self.vgg19_(gt_img, pred_img, pred_img.device) | |
# return v/b | |
class grad_loss(abs_loss): | |
def __init__(self, k=4): | |
self.k = 4 | |
def loss(self, disp_img, rgb_img=None): | |
""" Note, gradient loss should be weighted by an edge-aware weight | |
""" | |
b, c, h, w = disp_img.shape | |
grad_loss = 0.0 | |
for i in range(self.k): | |
div_factor = 2 ** i | |
cur_transform = T.Resize([h // div_factor, ]) | |
# cur_diff = cur_transform(diff) | |
# cur_diff_dx, cur_diff_dy = self.img_grad(cur_diff) | |
cur_disp = cur_transform(disp_img) | |
cur_disp_dx, cur_disp_dy = self.img_grad(cur_disp) | |
if rgb_img is not None: | |
cur_rgb = cur_transform(rgb_img) | |
cur_rgb_dx, cur_rgb_dy = self.img_grad(cur_rgb) | |
cur_rgb_dx = torch.exp(-torch.mean(torch.abs(cur_rgb_dx), dim=1, keepdims=True)) | |
cur_rgb_dy = torch.exp(-torch.mean(torch.abs(cur_rgb_dy), dim=1, keepdims=True)) | |
grad_loss += (torch.sum(torch.abs(cur_disp_dx) * cur_rgb_dx) + torch.sum(torch.abs(cur_disp_dy) * cur_rgb_dy)) / (h * w * self.k) | |
else: | |
grad_loss += (torch.sum(torch.abs(cur_disp_dx)) + torch.sum(torch.abs(cur_disp_dy))) / (h * w * self.k) | |
return grad_loss/b | |
def gloss(self, gt, pred): | |
""" Loss on the gradient domain | |
""" | |
b, c, h, w = gt.shape | |
gt_dx, gt_dy = self.img_grad(gt) | |
pred_dx, pred_dy = self.img_grad(pred) | |
loss = (gt_dx-pred_dx) ** 2 + (gt_dy - pred_dy) ** 2 | |
return loss.sum()/(b * h * w) | |
def laploss(self, pred): | |
b, c, h, w = pred.shape | |
lap = self.img_laplacian(pred) | |
return torch.abs(lap).sum()/(b * h * w) | |
def img_laplacian(self, img): | |
b, c, h, w = img.shape | |
laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]]) | |
laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img) | |
lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1) | |
return lap | |
def img_grad(self, img): | |
""" Comptue image gradient by sobel filtering | |
img: B x C x H x W | |
""" | |
b, c, h, w = img.shape | |
ysobel = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) | |
xsobel = ysobel.transpose(0,1) | |
xsobel_kernel = xsobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img) | |
ysobel_kernel = ysobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img) | |
dx = F.conv2d(img, xsobel_kernel, padding=1, stride=1) | |
dy = F.conv2d(img, ysobel_kernel, padding=1, stride=1) | |
return dx, dy | |
class sharp_loss(abs_loss): | |
""" Sharpness term | |
1. laplacian | |
2. image contrast | |
3. image variance | |
""" | |
def __init__(self, window_size=11, channel=1): | |
self.window_size = window_size | |
self.channel = channel | |
self.window = self.create_mean_window(window_size, self.channel) | |
def loss(self, gt_img, pred_img): | |
""" Note, gradient loss should be weighted by an edge-aware weight | |
""" | |
b, c, h, w = gt_img.shape | |
if c != self.channel: | |
self.channel = c | |
self.window = self.create_mean_window(self.window_size, self.channel) | |
self.window = self.window.to(gt_img).type_as(gt_img) | |
channel = self.channel | |
window = self.window | |
window_size = self.window_size | |
mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel) + 1e-6 | |
mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel) + 1e-6 | |
constrast1 = torch.absolute((gt_img - mu1)/mu1) | |
constrast2 = torch.absolute((pred_img - mu2)/mu2) | |
variance1 = (gt_img-mu1) ** 2 | |
variance2 = (pred_img-mu2) ** 2 | |
laplacian1 = self.img_laplacian(gt_img) | |
laplacian2 = self.img_laplacian(pred_img) | |
S1 = -laplacian1 - constrast1 - variance1 | |
S2 = -laplacian2 - constrast2 - variance2 | |
# import pdb; pdb.set_trace() | |
total = torch.absolute(S1-S2).mean() | |
return total | |
def img_laplacian(self, img): | |
b, c, h, w = img.shape | |
laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]]) | |
laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img) | |
lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1) | |
return lap | |
def create_mean_window(self, window_size, channel): | |
window = Variable(torch.ones(channel, 1, window_size, window_size).float()) | |
window = window/(window_size * window_size) | |
return window | |
if __name__ == '__main__': | |
a = torch.rand(3,3,128,128) | |
b = torch.rand(3,3,128,128) | |
ssim = ssim_loss() | |
loss = ssim.loss(a, b) | |
print(loss.shape, loss) | |
loss = ssim.loss(a, a) | |
print(loss.shape, loss) | |
loss = ssim.loss(b, b) | |
print(loss.shape, loss) | |
grad = grad_loss() | |
loss = grad.loss(a, [b, b]) | |
print(loss.shape, loss) | |
sharp = sharp_loss() | |
loss = sharp.loss(a, b) | |
print(loss.shape, loss) | |