"""Raw-to-sRGB inference script.

Input: every ``.png`` raw capture in ``Rpath``, each with a sidecar ``.json``
file of the same stem.

For each capture, this script:

1. Normalizes the black and white levels.
2. Bins the raw data and downsamples it.
3. Denoises it with the MWRCANet model (``net.mwrcanet.Net``).
4. Applies white balance, a color matrix, XYZ->sRGB and gamma.
5. Refines the result with ``LiteAWBISPNet``.
6. Writes an 8-bit image to ``Opath``.

The mean per-image processing time is printed at the end.
"""
import os
import cv2
import json
import torch
import torchvision.transforms as transforms
from CPNet_model import LiteAWBISPNet
import torchvision
import numpy as np
from Utiles import (white_balance, apply_color_space_transform, transform_xyz_to_srgb,
                    apply_gamma, fix_orientation, binning, Four2One, One2Four)
import time
from net.mwrcanet import Net
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F

####### Set input / output paths ###########
Rpath = './Input'
Opath = './Output'
####### Per-image processing times (seconds) ###############
infer_times = []
####### Color matrix from the baseline #############
# 9 coefficients, presumably a row-major 3x3 matrix -- see apply_color_space_transform.
color_matrix = [1.06835938, -0.29882812, -0.14257812,
                -0.43164062, 1.35546875, 0.05078125,
                -0.1015625, 0.24414062, 0.5859375]
####### Tensor transforms ###########################
# Refinement-network input: array -> tensor, resized to 768 x 1024 (H x W).
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize([768, 1024])])
# Denoiser input: array -> tensor, no resizing.
transformo = transforms.Compose([transforms.ToTensor()])

######## Load the pretrained refinement model ####
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth'))
# Inference only: disable training-mode behaviour (e.g. BatchNorm batch
# statistics, Dropout) so outputs are deterministic and running stats are untouched.
model.eval()

######## Load the pretrained denoising model ##############
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt)
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()
# The checkpoint keys must match the DataParallel-wrapped model's keys exactly.
# Raise instead of assert so the check is not stripped under `python -O`.
missing_keys = sorted(set(model_dict) - set(pretrained_dict))
unexpected_keys = sorted(set(pretrained_dict) - set(model_dict))
if missing_keys or unexpected_keys:
    raise RuntimeError(
        f"Checkpoint {last_ckpt} does not match the denoiser: "
        f"missing keys {missing_keys[:5]}, unexpected keys {unexpected_keys[:5]}"
    )
model_dict.update(pretrained_dict)
dn_model.load_state_dict(model_dict)
dn_model.eval()

############################ Start processing! #########
# Sorted so images are processed in a deterministic order (os.listdir order is arbitrary).
image_files = sorted(
    filename for filename in os.listdir(Rpath)
    if os.path.splitext(filename)[-1].lower() == ".png"
)
# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(Opath, exist_ok=True)

with torch.no_grad():
    for fp in image_files:
        fp = os.path.join(Rpath, fp)
        # Metadata sidecar: same path with the extension replaced by .json.
        mf = os.path.splitext(fp)[0] + '.json'
        # -1 == cv2.IMREAD_UNCHANGED: keep the original bit depth.
        raw_image = cv2.imread(fp, -1)
        if raw_image is None:
            raise OSError(f"Could not read image {fp}")
        with open(mf, 'r') as file:
            data = json.load(file)
        ############ Black & white level normalization ##########################
        # Timing covers normalization through the orientation fix; file reading and
        # writing are excluded. The .cpu() copies below wait for the GPU results,
        # so GPU work is included.
        time_BL_S = time.time()
        # Subtract black level 256, scale by the (white - black) range, clip to [0, 1].
        raw_image = (raw_image.astype(np.float32) - 256.) / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)
        ############# Binning ############################
        raw_image = binning(raw_image, data)
        ############# Downsample to 1024 x 768 (W x H) ###########################
        raw_image = cv2.resize(raw_image, (1024, 768))
        ############ Raw denoise ##########################
        # Four2One / One2Four (Utiles) presumably convert between the multi-channel
        # raw layout and the single plane the denoiser consumes -- TODO confirm.
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)
        ############# White balance, color matrix, XYZ->sRGB, gamma #########
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)
        ############# Refinement #############################
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)
        ################# Saving #############################
        Out = Out.clip(0, 1)
        # CHW tensor -> HWC array.
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        # Scale to [0, 255]; astype(np.uint8) truncates toward zero.
        OA = (OA * 255.).astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])
        time_Save_F = time.time()
        # Pipeline output is RGB; cv2.imwrite expects BGR channel order.
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        out_path = os.path.join(Opath, os.path.basename(fp))
        if not cv2.imwrite(out_path, OA):
            raise OSError(f"Could not write image {out_path}")
        infer_times.append(time_Save_F - time_BL_S)

# np.mean of an empty list would warn and print nan.
if infer_times:
    print(f"Average inference time: {np.mean(infer_times)} seconds")
else:
    print(f"No .png images found in {Rpath}")