|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import cv2 |
|
def padding_4x(seq_noise, multiple=16):
    """Reflect-pad the last two dims of ``seq_noise`` up to a multiple of ``multiple``.

    Only the bottom and right edges are padded, so ``depadding`` can undo the
    operation by slicing trailing rows/columns off again.

    Args:
        seq_noise (torch.Tensor): tensor of rank >= 2 whose last two dims are
            treated as (H, W).
        multiple (int): required divisor of the padded H and W. Defaults to 16,
            the value this function has always used, despite its name.

    Returns:
        tuple: ``(padded, expanded_h, expanded_w)`` where ``expanded_h`` rows
        were added at the bottom and ``expanded_w`` columns at the right.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f'multiple must be >= 1, got {multiple}')
    height, width = seq_noise.shape[-2], seq_noise.shape[-1]
    # Python's modulo on a negative int gives the distance to the next multiple
    # (0 when already aligned).
    expanded_h = -height % multiple
    expanded_w = -width % multiple
    lead = seq_noise.shape[:-2]
    # Reflect padding of the last two dims is independent per (H, W) plane.
    # Folding every leading dim into the batch axis of a 4-D tensor lets any
    # rank use the 2-D reflect kernel. A plain F.pad rejects 2-D inputs and
    # 5-D inputs with a 4-element pad in reflect mode.
    flat = seq_noise.reshape(-1, 1, height, width)
    padded = F.pad(input=flat, pad=(0, expanded_w, 0, expanded_h), mode='reflect')
    seq_noise = padded.reshape(*lead, height + expanded_h, width + expanded_w)
    return seq_noise, expanded_h, expanded_w
|
|
|
def depadding(seq_denoise, expanded_h, expanded_w):
    """Strip the trailing rows/columns that ``padding_4x`` added.

    Args:
        seq_denoise (torch.Tensor): tensor whose last two dims are (H, W).
        expanded_h (int): number of rows to drop from the bottom (0 = none).
        expanded_w (int): number of columns to drop from the right (0 = none).

    Returns:
        torch.Tensor: sliced view of ``seq_denoise``, or the input object
        itself when both amounts are 0.
    """
    # Index from the end with ``...`` so tensors of any rank are trimmed on
    # their last two dims. A fixed ``[:, :, :-h, :]`` would hit dim 2 of a
    # 5-D tensor instead of H.
    if expanded_h:
        seq_denoise = seq_denoise[..., :-expanded_h, :]
    if expanded_w:
        seq_denoise = seq_denoise[..., :-expanded_w]
    return seq_denoise
|
def _chunk_infer(net, patch):
    """Run ``net`` on one tile without autograd and return the result on the CPU."""
    with torch.no_grad():
        return net(patch).detach().cpu()


def _chunk_index(dim, sl):
    """Return an index tuple that applies slice ``sl`` to negative axis ``dim``."""
    return (Ellipsis, sl) + (slice(None),) * (-dim - 1)


def _chunk_blend(old, new, dim):
    """Linearly cross-fade from ``old`` to ``new`` along negative axis ``dim``.

    Position ``i`` of ``n`` gets weight ``(n-1-i)/(n-1)`` on ``old`` and
    ``i/(n-1)`` on ``new``. An overlap shorter than 2 keeps ``old`` rather
    than evaluating 0/0 (which produced NaN).
    """
    n = old.shape[dim]
    if n < 2:
        return old.clone()
    ramp_shape = (n,) + (1,) * (-dim - 1)

    def ramp(t):
        # Build the ramp in each operand's own dtype so the arithmetic matches
        # multiplying by Python int scalars position by position.
        return torch.arange(n, dtype=t.dtype, device=t.device).reshape(ramp_shape)

    return old * (n - 1 - ramp(old)) / (n - 1) + new * ramp(new) / (n - 1)


def _chunk_stitch(out, size, patch, overlap, dim, tile_fn):
    """Fill ``out`` along ``dim`` from overlapping tiles of length ``patch``.

    Tiles start at 0, patch-overlap, 2*(patch-overlap), ... for as long as they
    end strictly before ``size``. One final tile is aligned with the far edge
    (input slice ``[-patch:]``). Every tile after the first is cross-faded
    into ``out`` over the pixels it shares with earlier tiles and copied
    verbatim elsewhere. When ``size <= patch`` a single tile covers the axis.

    ``tile_fn(sl)`` must return the processed tile for input slice ``sl`` along
    ``dim``, with the same extent along ``dim`` as that slice.
    """
    def at(sl):
        return _chunk_index(dim, sl)

    stride = patch - overlap
    begin, prev_end = 0, None
    while begin + patch < size:
        end = begin + patch
        tile = tile_fn(slice(begin, end))
        if prev_end is None:
            out[at(slice(begin, end))] = tile
        else:
            head = at(slice(begin, begin + overlap))
            out[head] = _chunk_blend(out[head], tile[at(slice(None, overlap))], dim)
            out[at(slice(begin + overlap, end))] = tile[at(slice(overlap, None))]
        prev_end = end
        begin += stride

    tile = tile_fn(slice(-patch, None))
    if prev_end is None:
        # size <= patch: the single edge-aligned tile already spans the axis.
        out[...] = tile
        return
    # The edge tile starts at size - patch (for THIS axis) and shares
    # prev_end - (size - patch) pixels with what has been written so far.
    shared = prev_end - (size - patch)
    head = at(slice(size - patch, prev_end))
    out[head] = _chunk_blend(out[head], tile[at(slice(None, shared))], dim)
    out[at(slice(prev_end, None))] = tile[at(slice(shared, None))]


def chunkV3(net, input_data, option, patch_h = 516, patch_w = 516, patch_h_overlap = 16, patch_w_overlap = 16):
    """Apply ``net`` to ``input_data`` tile by tile and stitch the outputs.

    The last two dims of ``input_data`` are treated as (H, W). Tiles of
    ``patch_h x patch_w`` advance by ``patch - overlap``. Neighbouring tiles
    are cross-faded linearly over their shared pixels, and the last row/column
    of tiles is aligned with the bottom/right edge. Each full-width row of
    tiles is assembled first; rows are then blended vertically. Inference
    runs under ``torch.no_grad()`` and each tile result is moved to the CPU.

    Args:
        net: callable applied to every tile. Its output must have the tile's
            height and width, because it is assigned into equally sized slices.
        input_data (torch.Tensor): (B, C, H, W) for 'image' and 'three2one',
            (B, F, C, H, W) for 'RViDeformer'.
        option (str): 'image' or 'RViDeformer' give an output with the input's
            shape and dtype. 'three2one' gives a (B, 4, H, W) output in the
            default dtype.
        patch_h, patch_w (int): tile height and width.
        patch_h_overlap, patch_w_overlap (int): overlap between neighbouring
            tiles; must satisfy ``0 <= overlap < patch``.

    Returns:
        torch.Tensor: the stitched result, on the CPU.

    Raises:
        ValueError: for an unknown ``option`` or an overlap outside ``[0, patch)``.
    """
    for axis, patch, overlap in (('h', patch_h, patch_h_overlap), ('w', patch_w, patch_w_overlap)):
        # overlap >= patch would give a non-positive stride (an endless loop).
        if not 0 <= overlap < patch:
            raise ValueError(
                f'patch_{axis}_overlap must satisfy 0 <= overlap < patch_{axis}; '
                f'got overlap={overlap}, patch={patch}')

    shape = input_data.shape
    H, W = shape[-2], shape[-1]
    if option in ('image', 'RViDeformer'):
        test_result = torch.zeros_like(input_data, device='cpu')
    elif option == 'three2one':
        test_result = torch.zeros((shape[0], 4, H, W), device='cpu')
    else:
        raise ValueError(
            f"Unknown option {option!r}; expected 'image', 'RViDeformer' or 'three2one'")
    # Non-spatial dims of the output: (B, C), (B, F, C) or (B, 4).
    lead = test_result.shape[:-2]

    def run_strip(rows):
        """Stitch one full-width horizontal strip of tile results."""
        strip = input_data[..., rows, :]
        strip_result = torch.zeros((*lead, strip.shape[-2], W), device='cpu')
        _chunk_stitch(strip_result, W, patch_w, patch_w_overlap, -1,
                      lambda cols: _chunk_infer(net, strip[..., cols]))
        return strip_result

    _chunk_stitch(test_result, H, patch_h, patch_h_overlap, -2, run_strip)
    return test_result
|
|
|
|
|
def calculate_psnr(img, img2, input_order='HWC', max_value=1.):
    """Return the peak signal-to-noise ratio (dB) between two arrays.

    Args:
        img (ndarray): first image.
        img2 (ndarray): second image, same shape as ``img``.
        input_order (str): 'HWC' or 'CHW'. It is validated only: PSNR is a
            mean over every element, so axis order cannot affect the result.
        max_value (float): peak signal value. The default of 1.0 assumes
            data scaled to [0, 1].

    Returns:
        float: ``10 * log10(max_value**2 / MSE)``, or ``inf`` when the
        arrays are identical.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    # No transpose here. Reordering axes cannot change a mean, and an
    # unconditional transpose(1, 2, 0) only made 2-D (grayscale) arrays raise.
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(max_value * max_value / mse)
|
|
|
|
|
def calculate_ssim(img, img2, input_order='HWC'):
    """Return the mean SSIM of two images, averaged over channels.

    Args:
        img (ndarray): first image. A 3-D array is always transposed with
            ``transpose(1, 2, 0)`` before the per-channel loop, i.e. it is read
            as channel-first (C, H, W) whatever ``input_order`` says. A 2-D
            array is treated as a single channel.
        img2 (ndarray): second image, same shape as ``img``.
        input_order (str): must be 'HWC' or 'CHW'. It is validated but does
            not change how the array is interpreted.
            NOTE(review): the 'HWC' default is misleading, because a genuine
            HWC array would be split into (W, C) planes. It is left as-is
            because callers presumably pass CHW arrays with the default.
            Confirm before making the transpose depend on this argument.

    Returns:
        float: mean of ``_ssim`` over the channels.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')

    if img.ndim == 2:
        # Single channel: transpose(1, 2, 0) would raise on a 2-D array.
        return _ssim(img.astype(np.float64), img2.astype(np.float64))

    # (C, H, W) -> (H, W, C) so that ``[..., i]`` selects channel i.
    img = img.transpose(1, 2, 0).astype(np.float64)
    img2 = img2.transpose(1, 2, 0).astype(np.float64)
    ssims = [_ssim(img[..., i], img2[..., i]) for i in range(img.shape[2])]
    return np.array(ssims).mean()
|
|
|
def _ssim(img, img2):
    """Compute SSIM (structural similarity) between two single-channel images.

    Helper for ``calculate_ssim``, which passes one 2-D channel slice at a
    time. Local statistics come from an 11x11 Gaussian window (sigma 1.5)
    applied with ``cv2.filter2D``, and a 5-pixel border is cropped from every
    filtered map. The stabilising constants ``c1 = (0.01 * 1)**2`` and
    ``c2 = (0.03 * 1)**2`` correspond to a data range of 1, so inputs are
    presumably scaled to [0, 1] rather than [0, 255].

    Args:
        img (ndarray): 2-D image (one channel).
        img2 (ndarray): 2-D image with the same shape as ``img``.

    Returns:
        float: mean of the SSIM map.
    """
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss)
    border = slice(5, -5)

    def local_mean(arr):
        # Gaussian-weighted local average, with the filter border trimmed off.
        return cv2.filter2D(arr, -1, window)[border, border]

    mean_x = local_mean(img)
    mean_y = local_mean(img2)
    mean_x_sq = mean_x ** 2
    mean_y_sq = mean_y ** 2
    mean_xy = mean_x * mean_y
    var_x = local_mean(img ** 2) - mean_x_sq
    var_y = local_mean(img2 ** 2) - mean_y_sq
    cov_xy = local_mean(img * img2) - mean_xy

    numerator = (2 * mean_xy + c1) * (2 * cov_xy + c2)
    denominator = (mean_x_sq + mean_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()