import torch
import numpy as np
import torch.nn.functional as F
import cv2


def padding_4x(seq_noise):
    """Reflect-pad the last two dims so that H and W become multiples of 16."""
    sh_im = seq_noise.size()
    expanded_h = sh_im[-2] % 16
    if expanded_h:
        expanded_h = 16 - expanded_h
    expanded_w = sh_im[-1] % 16
    if expanded_w:
        expanded_w = 16 - expanded_w
    padexp = (0, expanded_w, 0, expanded_h)
    seq_noise = F.pad(input=seq_noise, pad=padexp, mode='reflect')
    return seq_noise, expanded_h, expanded_w


def depadding(seq_denoise, expanded_h, expanded_w):
    """Crop away the rows/columns added by padding_4x."""
    if expanded_h:
        seq_denoise = seq_denoise[:, :, :-expanded_h, :]
    if expanded_w:
        seq_denoise = seq_denoise[:, :, :, :-expanded_w]
    return seq_denoise


def chunkV3(net, input_data, option, patch_h=516, patch_w=516, patch_h_overlap=16, patch_w_overlap=16):
    """Run `net` over overlapping patches and linearly blend the overlaps.

    Expected input_data shapes:
        'image':       [B, C, H, W],     e.g. [1, 4, 1500, 2000]
        'RViDeformer': [B, F, C, H, W],  e.g. [1, 6, 4, 1500, 2000]
        'three2one':   [B, F*C, H, W],   e.g. [1, 12, 1500, 2000] -> output [B, 4, H, W]
    """
    shape_list = input_data.shape
    if option == 'image':
        B, C, H, W = shape_list
    if option == 'RViDeformer':
        B, n_frames, C, H, W = shape_list
    if option == 'three2one':
        B, FC, H, W = shape_list

    # Output buffer with the same spatial size as the input.
    if option in ('image', 'RViDeformer'):
        test_result = torch.zeros_like(input_data).cpu()
    if option == 'three2one':
        test_result = torch.zeros((B, 4, H, W)).cpu()

    h_index = 1
    while (patch_h * h_index - patch_h_overlap * (h_index - 1)) < H:
        # Buffer for one horizontal strip of patches (height patch_h, full width W).
        if option == 'image':
            test_horizontal_result = torch.zeros((B, C, patch_h, W)).cpu()
        if option == 'RViDeformer':
            test_horizontal_result = torch.zeros((B, n_frames, C, patch_h, W)).cpu()
        if option == 'three2one':
            test_horizontal_result = torch.zeros((B, 4, patch_h, W)).cpu()

        h_begin = patch_h * (h_index - 1) - patch_h_overlap * (h_index - 1)
        h_end = patch_h * h_index - patch_h_overlap * (h_index - 1)

        w_index = 1
        while (patch_w * w_index - patch_w_overlap * (w_index - 1)) < W:
            w_begin = patch_w * (w_index - 1) - patch_w_overlap * (w_index - 1)
            w_end = patch_w * w_index - patch_w_overlap * (w_index - 1)
            test_patch = input_data[..., h_begin:h_end, w_begin:w_end]
            with torch.no_grad():
                test_patch_result = net(test_patch).detach().cpu()
            if w_index == 1:
                test_horizontal_result[..., w_begin:w_end] = test_patch_result
            else:
                # Linearly blend the horizontal overlap with the previous patch.
                for i in range(patch_w_overlap):
                    test_horizontal_result[..., w_begin + i] = test_horizontal_result[..., w_begin + i] * (patch_w_overlap - 1 - i) / (patch_w_overlap - 1) + test_patch_result[..., i] * i / (patch_w_overlap - 1)
                test_horizontal_result[..., w_begin + patch_w_overlap:w_end] = test_patch_result[..., patch_w_overlap:]
            w_index += 1

        # Last column: a full patch aligned to the right border.
        test_patch = input_data[..., h_begin:h_end, -patch_w:]
        with torch.no_grad():
            test_patch_result = net(test_patch).detach().cpu()
        last_range = w_end - (W - patch_w)
        for i in range(last_range):
            test_horizontal_result[..., W - patch_w + i] = test_horizontal_result[..., W - patch_w + i] * (last_range - 1 - i) / (last_range - 1) + test_patch_result[..., i] * i / (last_range - 1)
        test_horizontal_result[..., w_end:] = test_patch_result[..., last_range:]

        if h_index == 1:
            test_result[..., h_begin:h_end, :] = test_horizontal_result
        else:
            # Linearly blend the vertical overlap with the previous strip.
            for i in range(patch_h_overlap):
                test_result[..., h_begin + i, :] = test_result[..., h_begin + i, :] * (patch_h_overlap - 1 - i) / (patch_h_overlap - 1) + test_horizontal_result[..., i, :] * i / (patch_h_overlap - 1)
            test_result[..., h_begin + patch_h_overlap:h_end, :] = test_horizontal_result[..., patch_h_overlap:, :]
        h_index += 1

    # Last row: buffer for a strip aligned to the bottom border.
    if option == 'image':
        test_horizontal_result = torch.zeros((B, C, patch_h, W)).cpu()
    if option == 'RViDeformer':
        test_horizontal_result = torch.zeros((B, n_frames, C, patch_h, W)).cpu()
    if option == 'three2one':
        test_horizontal_result = torch.zeros((B, 4, patch_h, W)).cpu()

    w_index = 1
    while (patch_w * w_index - patch_w_overlap * (w_index - 1)) < W:
        w_begin = patch_w * (w_index - 1) - patch_w_overlap * (w_index - 1)
        w_end = patch_w * w_index - patch_w_overlap * (w_index - 1)
        test_patch = input_data[..., -patch_h:, w_begin:w_end]
        with torch.no_grad():
            test_patch_result = net(test_patch).detach().cpu()
        if w_index == 1:
            test_horizontal_result[..., w_begin:w_end] = test_patch_result
        else:
            # Linearly blend the horizontal overlap with the previous patch.
            for i in range(patch_w_overlap):
                test_horizontal_result[..., w_begin + i] = test_horizontal_result[..., w_begin + i] * (patch_w_overlap - 1 - i) / (patch_w_overlap - 1) + test_patch_result[..., i] * i / (patch_w_overlap - 1)
            test_horizontal_result[..., w_begin + patch_w_overlap:w_end] = test_patch_result[..., patch_w_overlap:]
        w_index += 1

    # Bottom-right corner: a full patch aligned to both borders.
    test_patch = input_data[..., -patch_h:, -patch_w:]
    with torch.no_grad():
        test_patch_result = net(test_patch).detach().cpu()
    last_range = w_end - (W - patch_w)
    for i in range(last_range):
        test_horizontal_result[..., W - patch_w + i] = test_horizontal_result[..., W - patch_w + i] * (last_range - 1 - i) / (last_range - 1) + test_patch_result[..., i] * i / (last_range - 1)
    test_horizontal_result[..., w_end:] = test_patch_result[..., last_range:]

    # Blend the bottom strip into the rows already written.
    last_last_range = h_end - (H - patch_h)
    for i in range(last_last_range):
        test_result[..., H - patch_h + i, :] = test_result[..., H - patch_h + i, :] * (last_last_range - 1 - i) / (last_last_range - 1) + test_horizontal_result[..., i, :] * i / (last_last_range - 1)
    test_result[..., h_end:, :] = test_horizontal_result[..., last_last_range:, :]

    return test_result


def calculate_psnr(img, img2, input_order='CHW'):
    """PSNR in dB between two images with values in [0, 1]."""
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    if input_order == 'CHW':
        # Convert CHW -> HWC before computing the metric.
        img = img.transpose(1, 2, 0)
        img2 = img2.transpose(1, 2, 0)
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img - img2)**2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(1. * 1. / mse)


def calculate_ssim(img, img2, input_order='CHW'):
    """Mean SSIM over channels for images with values in [0, 1]."""
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
    if input_order == 'CHW':
        # Convert CHW -> HWC before the per-channel loop.
        img = img.transpose(1, 2, 0)
        img2 = img2.transpose(1, 2, 0)
    img = img.astype(np.float64)
    img2 = img2.astype(np.float64)
    ssims = []
    for i in range(img.shape[2]):
        ssims.append(_ssim(img[..., i], img2[..., i]))
    return np.array(ssims).mean()


def _ssim(img, img2):
    """Calculate SSIM (structural similarity) for single-channel images.

    Called by :func:`calculate_ssim`.

    Args:
        img (ndarray): Single-channel image with values in [0, 1].
        img2 (ndarray): Single-channel image with values in [0, 1].

    Returns:
        float: SSIM result.
""" c1 = (0.01 * 1)**2 c2 = (0.03 * 1)**2 kernel = cv2.getGaussianKernel(11, 1.5) window = np.outer(kernel, kernel.transpose()) mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11 mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] mu1_sq = mu1**2 mu2_sq = mu2**2 mu1_mu2 = mu1 * mu2 sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)) return ssim_map.mean()