NPRC24 / IIR-Lab /utils.py
Artyom
IIRLab
6721043 verified
raw
history blame
8.51 kB
import torch
import numpy as np
import torch.nn.functional as F
import cv2
def padding_4x(seq_noise, multiple=16):
    """Reflect-pad the last two axes of a tensor up to a multiple of ``multiple``.

    Padding is added only at the bottom (second-to-last axis) and at the right
    (last axis), so the original content stays in the top-left corner and can
    be recovered by cropping the returned amounts off the end of each axis.

    Args:
        seq_noise: Tensor whose last two axes are padded.
        multiple: Positive integer the padded sizes must be divisible by.
            Defaults to 16, the value this function has always used (the
            ``4x`` in the name is kept for backward compatibility).

    Returns:
        Tuple ``(padded, expanded_h, expanded_w)`` where ``expanded_h`` and
        ``expanded_w`` are the number of rows / columns that were appended
        (0 when the size was already a multiple).

    Raises:
        ValueError: If ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f'multiple must be a positive integer, got {multiple}')
    height, width = seq_noise.shape[-2], seq_noise.shape[-1]
    # (-n) % m is the distance from n up to the next multiple of m (0 if aligned).
    expanded_h = (-height) % multiple
    expanded_w = (-width) % multiple
    # F.pad orders the pad tuple from the last axis backwards:
    # (left, right, top, bottom).
    # NOTE(review): PyTorch's reflect mode needs each pad amount to be smaller
    # than the size of the axis being padded - confirm inputs are large enough.
    padexp = (0, expanded_w, 0, expanded_h)
    seq_noise = F.pad(input=seq_noise, pad=padexp, mode='reflect')
    return seq_noise, expanded_h, expanded_w
def depadding(seq_denoise, expanded_h, expanded_w):
    """Crop trailing padding off the last two axes of a tensor.

    Args:
        seq_denoise: Tensor with at least two axes; the second-to-last axis is
            treated as height and the last axis as width. Any number of
            leading axes is accepted.
        expanded_h: Number of rows to drop from the end of the height axis
            (0 leaves the axis untouched).
        expanded_w: Number of columns to drop from the end of the width axis
            (0 leaves the axis untouched).

    Returns:
        The cropped tensor (a view of the input; the input itself when both
        amounts are 0).
    """
    # Ellipsis indexing keeps the crop on the last two axes whatever the rank.
    # The zero checks matter: a slice of ``:-0`` would yield an empty tensor.
    if expanded_h:
        seq_denoise = seq_denoise[..., :-expanded_h, :]
    if expanded_w:
        seq_denoise = seq_denoise[..., :-expanded_w]
    return seq_denoise
def _ramp_blend(previous, current, step, length):
    """Cross-fade two overlapping slices at position ``step`` of a seam.

    The weight of ``current`` rises linearly from 0 at ``step == 0`` to 1 at
    ``step == length - 1``. A seam of length 1 keeps ``previous`` unchanged,
    which avoids the 0/0 division the ramp formula would otherwise perform.
    """
    if length <= 1:
        return previous
    return previous * (length - 1 - step) / (length - 1) + current * step / (length - 1)


def _chunk_strip(net, strip_input, strip_shape, patch_w, patch_w_overlap):
    """Run ``net`` tile by tile over one horizontal strip and stitch along the width."""
    W = strip_input.shape[-1]
    strip_result = torch.zeros(strip_shape, device='cpu')
    stride = patch_w - patch_w_overlap
    w_begin = 0
    # Right edge of the last regular tile; stays 0 when one tile spans the strip.
    w_end = 0
    while w_begin + patch_w < W:
        w_end = w_begin + patch_w
        with torch.no_grad():
            tile = net(strip_input[..., w_begin:w_end]).detach().cpu()
        if w_begin == 0:
            strip_result[..., w_begin:w_end] = tile
        else:
            # Cross-fade the columns shared with the previous tile, then copy the rest.
            for i in range(patch_w_overlap):
                strip_result[..., w_begin + i] = _ramp_blend(
                    strip_result[..., w_begin + i], tile[..., i], i, patch_w_overlap)
            strip_result[..., w_begin + patch_w_overlap:w_end] = tile[..., patch_w_overlap:]
        w_begin += stride

    # The final tile is right-aligned so it has full size; everything it shares
    # with the columns already written is cross-faded.
    with torch.no_grad():
        tile = net(strip_input[..., W - patch_w:]).detach().cpu()
    seam = w_end - (W - patch_w)
    for i in range(seam):
        col = W - patch_w + i
        strip_result[..., col] = _ramp_blend(strip_result[..., col], tile[..., i], i, seam)
    strip_result[..., w_end:] = tile[..., seam:]
    return strip_result


def chunkV3(net, input_data, option, patch_h = 516, patch_w = 516, patch_h_overlap = 16, patch_w_overlap = 16):
    """Apply ``net`` to a large input patch by patch and stitch the outputs on the CPU.

    The last two axes of ``input_data`` are tiled with overlapping patches.
    Each patch is passed through ``net`` under ``torch.no_grad()``, moved to
    the CPU, and written into the result; columns / rows shared by neighbouring
    patches are cross-faded linearly. The last patch of every row and the last
    row of patches are aligned to the right / bottom border so that they keep
    the full patch size.

    Args:
        net: Callable mapping a patch to an output with the same size on the
            last two axes and the leading axes described under ``option``.
        input_data: Input tensor; its layout depends on ``option``.
        option: One of
            ``'image'`` - input ``(B, C, H, W)``, output has the input's shape;
            ``'RViDeformer'`` - input ``(B, F, C, H, W)``, output has the
            input's shape;
            ``'three2one'`` - input ``(B, F*C, H, W)``, output ``(B, 4, H, W)``.
        patch_h: Patch height. Clamped to ``H`` when the input is smaller.
        patch_w: Patch width. Clamped to ``W`` when the input is smaller.
        patch_h_overlap: Rows shared by vertically adjacent patches.
        patch_w_overlap: Columns shared by horizontally adjacent patches.

    Returns:
        CPU tensor holding the stitched network output.

    Raises:
        ValueError: If ``option`` is not one of the supported values, if an
            overlap is negative, or if an overlap is not smaller than its
            patch size while tiling along that axis is required.
    """
    shape_list = input_data.shape
    if option == 'image':
        H, W = shape_list[2], shape_list[3]
        lead_dims = (shape_list[0], shape_list[1])
        test_result = torch.zeros_like(input_data, device='cpu')  # same shape as the input
    elif option == 'RViDeformer':
        H, W = shape_list[3], shape_list[4]
        lead_dims = (shape_list[0], shape_list[1], shape_list[2])
        test_result = torch.zeros_like(input_data, device='cpu')  # same shape as the input
    elif option == 'three2one':
        H, W = shape_list[2], shape_list[3]
        lead_dims = (shape_list[0], 4)  # this mode produces a 4-channel output
        test_result = torch.zeros(lead_dims + (H, W), device='cpu')
    else:
        raise ValueError(
            f"Unsupported option {option!r}. Supported options are "
            "'image', 'RViDeformer' and 'three2one'")

    # An input smaller than the patch is handled as a single patch on that axis.
    patch_h = min(patch_h, H)
    patch_w = min(patch_w, W)
    for patch, overlap, extent in ((patch_h, patch_h_overlap, H), (patch_w, patch_w_overlap, W)):
        # A non-positive stride would make the tiling loops run forever.
        if overlap < 0 or (patch < extent and overlap >= patch):
            raise ValueError(
                f'Patch overlap {overlap} must be non-negative and smaller than the patch size {patch}')

    strip_shape = tuple(lead_dims) + (patch_h, W)  # like the result, but only patch_h rows
    stride_h = patch_h - patch_h_overlap
    h_begin = 0
    # Bottom edge of the last regular strip; stays 0 when one strip spans the input.
    h_end = 0
    while h_begin + patch_h < H:
        h_end = h_begin + patch_h
        strip = _chunk_strip(net, input_data[..., h_begin:h_end, :], strip_shape, patch_w, patch_w_overlap)
        if h_begin == 0:
            test_result[..., h_begin:h_end, :] = strip
        else:
            # Cross-fade the rows shared with the previous strip, then copy the rest.
            for i in range(patch_h_overlap):
                test_result[..., h_begin + i, :] = _ramp_blend(
                    test_result[..., h_begin + i, :], strip[..., i, :], i, patch_h_overlap)
            test_result[..., h_begin + patch_h_overlap:h_end, :] = strip[..., patch_h_overlap:, :]
        h_begin += stride_h

    # The final strip is bottom-aligned; blend it over the rows already written.
    strip = _chunk_strip(net, input_data[..., H - patch_h:, :], strip_shape, patch_w, patch_w_overlap)
    seam = h_end - (H - patch_h)
    for i in range(seam):
        row = H - patch_h + i  # offset by the patch *height*, not the width
        test_result[..., row, :] = _ramp_blend(test_result[..., row, :], strip[..., i, :], i, seam)
    test_result[..., h_end:, :] = strip[..., seam:, :]
    return test_result
def calculate_psnr(img, img2, input_order='HWC'):
    """Compute the PSNR between two arrays, using a peak signal value of 1.

    The mean squared error is taken over every element, so the result does
    not depend on the axis layout; the arrays are therefore not transposed,
    and inputs of any rank (e.g. 2-D single-channel images) are accepted.

    Args:
        img (ndarray): First image.
        img2 (ndarray): Second image, same shape as ``img``.
        input_order (str): ``'HWC'`` or ``'CHW'``. Only validated; the value
            has no effect on the result.

    Returns:
        float: PSNR in dB, or ``inf`` when the arrays are identical.

    Raises:
        AssertionError: If the shapes differ.
        ValueError: If ``input_order`` is not ``'HWC'`` or ``'CHW'``.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img - img2)**2)
    if mse == 0:
        return float('inf')
    # Peak value is 1, so the numerator of the PSNR ratio is 1 * 1.
    return 10. * np.log10(1. * 1. / mse)
def calculate_ssim(img, img2, input_order='HWC'):
    """Compute the mean per-channel SSIM between two 3-D arrays.

    Both arrays are permuted with ``transpose(1, 2, 0)`` (axis 0 moves to the
    end), converted to float64, and ``_ssim`` is evaluated on every slice
    along the new last axis; the mean of those values is returned.

    NOTE(review): ``input_order`` is validated but otherwise unused - the
    transpose is applied for both accepted values, which presumably means
    callers pass channel-first data. Confirm against the call sites before
    changing this.

    Args:
        img (ndarray): First image (3-D).
        img2 (ndarray): Second image, same shape as ``img``.
        input_order (str): ``'HWC'`` or ``'CHW'``; only checked for validity.

    Returns:
        float: Mean SSIM over the slices of the last axis after the transpose.

    Raises:
        AssertionError: If the shapes differ.
        ValueError: If ``input_order`` is not ``'HWC'`` or ``'CHW'``.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ('HWC', 'CHW'):
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    first = img.transpose(1, 2, 0).astype(np.float64)
    second = img2.transpose(1, 2, 0).astype(np.float64)
    per_channel = [
        _ssim(first[..., channel], second[..., channel])
        for channel in range(first.shape[2])
    ]
    return np.array(per_channel).mean()
def _ssim(img, img2):
    """Calculate SSIM (structural similarity) for a pair of single-channel images.

    It is called by func:`calculate_ssim`.

    Local statistics are computed with an 11x11 Gaussian window (sigma 1.5);
    a 5-pixel border is cropped from every filtered map so only positions
    where the window fits entirely inside the image contribute.

    Args:
        img (ndarray): 2-D single-channel image.
        img2 (ndarray): 2-D single-channel image, same shape as ``img``.
            The stabilising constants are built with a dynamic range of 1
            (``(0.01 * 1)**2`` and ``(0.03 * 1)**2``), so inputs are
            presumably expected in [0, 1] - confirm against callers.

    Returns:
        float: Mean of the SSIM map.
    """
    c1 = (0.01 * 1)**2
    c2 = (0.03 * 1)**2

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def windowed_mean(values):
        # Gaussian-weighted local mean, cropped to "valid" mode for window size 11.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = windowed_mean(img)
    mu2 = windowed_mean(img2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances / covariance via E[xy] - E[x]E[y].
    sigma1_sq = windowed_mean(img**2) - mu1_sq
    sigma2_sq = windowed_mean(img2**2) - mu2_sq
    sigma12 = windowed_mean(img * img2) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = numerator / denominator
    return ssim_map.mean()