# File size: 8,510 Bytes
# 6721043
import torch
import numpy as np
import torch.nn.functional as F
import cv2
def padding_4x(seq_noise, multiple=16):
    """Reflect-pad the last two dims (H, W) of ``seq_noise`` up to a multiple of ``multiple``.

    Padding is added only on the bottom and right edges, so ``depadding`` can undo it
    by cropping. Despite the function name, the default multiple is 16.
    Any tensor with at least two dims is accepted. Leading dims are folded into one
    before padding because 2-D reflect padding in torch only accepts 3-D/4-D input;
    this lets e.g. 5-D (B, F, C, H, W) batches through.

    Args:
        seq_noise: tensor whose last two dims are H and W.
        multiple: positive int that the padded H and W must be divisible by.

    Returns:
        ``(padded, expanded_h, expanded_w)``: the padded tensor, the number of rows
        added at the bottom and the number of columns added on the right.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
        torch raises if a pad amount is not smaller than its dimension (a reflect
        padding requirement).
    """
    if multiple < 1:
        raise ValueError(f'multiple must be a positive integer, got {multiple}')
    height, width = seq_noise.shape[-2], seq_noise.shape[-1]
    # Distance to the next multiple; 0 when the size is already aligned.
    expanded_h = (-height) % multiple
    expanded_w = (-width) % multiple
    lead_shape = seq_noise.shape[:-2]
    # Fold every leading dim into one so F.pad sees a 3-D (N, H, W) tensor.
    planes = seq_noise.reshape(-1, height, width)
    planes = F.pad(input=planes, pad=(0, expanded_w, 0, expanded_h), mode='reflect')
    seq_noise = planes.reshape(*lead_shape, height + expanded_h, width + expanded_w)
    return seq_noise, expanded_h, expanded_w
def depadding(seq_denoise, expanded_h, expanded_w):
    """Undo ``padding_4x``: drop ``expanded_h`` trailing rows and ``expanded_w`` trailing columns.

    Only the last two dims (H, W) are cropped, so any tensor rank >= 2 works. The
    previous explicit ``[:, :, ...]`` indexing only hit H/W on 4-D input.

    Args:
        seq_denoise: tensor whose last two dims are H and W.
        expanded_h: number of rows to remove from the bottom (0 keeps all rows).
        expanded_w: number of columns to remove from the right (0 keeps all columns).

    Returns:
        A view of ``seq_denoise`` with the padding removed.
    """
    height, width = seq_denoise.shape[-2], seq_denoise.shape[-1]
    # Clamp at 0 so an over-large crop yields an empty slice, as ``[:-n]`` did.
    return seq_denoise[..., :max(height - expanded_h, 0), :max(width - expanded_w, 0)]
def _tile_starts(size, patch, overlap):
    """Return start offsets of length-``patch`` windows that together cover ``[0, size)``.

    Regular windows advance by ``patch - overlap``. The last window is flush with the
    end (start ``size - patch``), so its overlap with the previous one may exceed
    ``overlap``. Expects ``patch <= size``.
    """
    if patch >= size:
        return [0]
    step = patch - overlap
    if overlap < 0 or step <= 0:
        raise ValueError(f'overlap ({overlap}) must be in [0, patch size ({patch}))')
    return list(range(0, size - patch, step)) + [size - patch]


def _cross_fade(existing, incoming, dim):
    """Linearly blend two equally shaped overlap slabs along ``dim`` (slab length >= 2).

    Index i along ``dim`` gets ``existing*(n-1-i)/(n-1) + incoming*i/(n-1)``. Index 0
    is therefore purely ``existing`` and the last index is purely ``incoming``.
    """
    n = existing.shape[dim]
    denom = n - 1
    dtype = existing.dtype if existing.is_floating_point() else torch.get_default_dtype()
    ramp_shape = [1] * existing.dim()
    ramp_shape[dim] = n
    ramp = torch.arange(n, dtype=dtype).view(ramp_shape)
    return existing * (denom - ramp) / denom + incoming * ramp / denom


def _place_tile(dest, tile, start, prev_end, dim):
    """Write ``tile`` into ``dest`` (in place) along ``dim``, starting at ``start``.

    Positions ``[start, prev_end)`` already hold the previous tile's output. They are
    cross-faded with the head of ``tile``, and the rest of ``tile`` is copied verbatim.
    """
    length = tile.shape[dim]
    overlap = min(max(prev_end - start, 0), length)
    if overlap > 1:
        region = dest.narrow(dim, start, overlap)
        region.copy_(_cross_fade(region, tile.narrow(dim, 0, overlap), dim))
    # A one-sample overlap has no ramp to build: the existing value is kept
    # instead of computing 0/0 (NaN).
    dest.narrow(dim, start + overlap, length - overlap).copy_(
        tile.narrow(dim, overlap, length - overlap))


def chunkV3(net, input_data, option, patch_h=516, patch_w=516, patch_h_overlap=16, patch_w_overlap=16):
    """Run ``net`` over ``input_data`` in overlapping spatial tiles and stitch the outputs.

    The last two axes of ``input_data`` are treated as (H, W). Tiles of size
    ``patch_h x patch_w`` are processed one full-width row strip at a time.
    Consecutive tile outputs are linearly cross-faded over their overlap: first
    along W inside a strip, then along H between strips. The final tile in each
    direction is flush with the border. A patch size larger than the input is
    clamped to the input size, so small inputs are processed as a single tile.

    Args:
        net: callable mapping an input tile to an output tile with the same spatial size.
        input_data: tensor whose last two dims are H and W. NOTE(review): per the
            original comments, the layouts appear to be:
            - (B, C, H, W) for 'image'
            - (B, F, C, H, W) for 'RViDeformer'
            - (B, F*C, H, W) for 'three2one'
            Confirm against callers.
        option: one of:
            - 'image' or 'RViDeformer': the output has the input's shape and dtype.
            - 'three2one': the output is a default-dtype (B, 4, H, W) tensor, so
              ``net`` must return 4 channels.
        patch_h, patch_w: tile height and width.
        patch_h_overlap, patch_w_overlap: overlap between consecutive regular tiles.

    Returns:
        CPU tensor with the stitched network output.

    Raises:
        ValueError: if ``option`` is unknown, or an overlap is negative or not smaller
            than its patch size when more than one tile is needed.
    """
    shape = tuple(input_data.shape)
    height, width = shape[-2], shape[-1]
    if option in ('image', 'RViDeformer'):
        out_lead = shape[:-2]
        result = torch.zeros_like(input_data, device='cpu')
    elif option == 'three2one':
        out_lead = (shape[0], 4)
        result = torch.zeros(out_lead + (height, width))
    else:
        raise ValueError(f"unknown option {option!r}; expected 'image', 'RViDeformer' or 'three2one'")

    tile_h = min(patch_h, height)
    tile_w = min(patch_w, width)
    row_starts = _tile_starts(height, tile_h, patch_h_overlap)
    col_starts = _tile_starts(width, tile_w, patch_w_overlap)

    prev_row_end = 0
    for top in row_starts:
        # One full-width strip: same leading dims as the output, tile_h rows.
        strip = torch.zeros(out_lead + (tile_h, width))
        prev_col_end = 0
        for left in col_starts:
            tile = input_data[..., top:top + tile_h, left:left + tile_w]
            with torch.no_grad():
                tile_out = net(tile).detach().cpu()
            _place_tile(strip, tile_out, left, prev_col_end, dim=-1)
            prev_col_end = left + tile_w
        _place_tile(result, strip, top, prev_row_end, dim=-2)
        prev_row_end = top + tile_h
    return result
def calculate_psnr(img, img2, input_order='HWC', data_range=1.):
    """Return the PSNR in dB between two equally shaped images.

    The MSE is taken over every element, so axis order does not affect the result.
    Any rank (e.g. 2-D grayscale) is accepted.

    Args:
        img, img2: arrays of identical shape.
        input_order: 'HWC' or 'CHW'. It is only validated, since layout does not change PSNR.
        data_range: peak signal value. The default 1.0 assumes images scaled to [0, 1].

    Returns:
        PSNR in dB, or ``inf`` when the images are identical.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is unsupported.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    # No transpose: reordering axes cannot change the mean of squared differences,
    # and the previous unconditional (1, 2, 0) transpose rejected non-3-D input.
    diff = np.asarray(img, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(data_range ** 2 / mse)
def calculate_ssim(img, img2, input_order='HWC'):
    """Return the channel-averaged SSIM of two equally shaped 3-D images.

    NOTE(review): whatever ``input_order`` says, both arrays are transposed with
    ``(1, 2, 0)``. Input is therefore always treated as CHW, and SSIM is averaged
    over axis 0 of the input; ``input_order`` is only validated. Genuine HWC data
    would be compared over the wrong axes. Confirm what callers pass before
    changing this.

    Each channel pair is converted to float64 and scored with ``_ssim``.

    Raises:
        AssertionError: if the shapes differ.
        ValueError: if ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ('HWC', 'CHW'):
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    reference = img.transpose(1, 2, 0).astype(np.float64)
    candidate = img2.transpose(1, 2, 0).astype(np.float64)
    per_channel = [_ssim(reference[..., ch], candidate[..., ch]) for ch in range(reference.shape[2])]
    return np.array(per_channel).mean()
def _ssim(img, img2):
    """Return the mean SSIM between two single-channel (2-D) images.

    Called once per channel by ``calculate_ssim``. Local statistics use an 11x11
    Gaussian window (sigma 1.5). The stabilising constants are built for a data range
    of 1.0, so inputs are expected in [0, 1]. Only the "valid" region, with a 5-pixel
    border discarded, contributes to the mean.

    Args:
        img: first 2-D image.
        img2: second 2-D image, same shape as ``img``.

    Returns:
        float: mean of the SSIM map.
    """
    stab1 = (0.01 * 1) ** 2
    stab2 = (0.03 * 1) ** 2
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def local_mean(arr):
        # filter2D keeps the input size; trimming 5 px per side keeps the valid part.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mean_a = local_mean(img)
    mean_b = local_mean(img2)
    mean_a_sq = mean_a ** 2
    mean_b_sq = mean_b ** 2
    mean_ab = mean_a * mean_b
    var_a = local_mean(img ** 2) - mean_a_sq
    var_b = local_mean(img2 ** 2) - mean_b_sq
    cov_ab = local_mean(img * img2) - mean_ab

    numerator = (2 * mean_ab + stab1) * (2 * cov_ab + stab2)
    denominator = (mean_a_sq + mean_b_sq + stab1) * (var_a + var_b + stab2)
    return (numerator / denominator).mean()