"""Run ISO-adaptive NAF_Video denoising on a test set and render the results.

For every sample the capture ISO is estimated from the metadata noise profile,
the matching denoiser checkpoint is selected, the input is denoised in
1024x1024 chunks via ``chunkV3``, and the output is passed through the night
ISP pipeline and written to ``<save_folder>/<scene>.jpg``.
"""
import os
import time
import argparse
import torch
import torch.backends.cudnn as cudnn
from utils_ours.util import setup_logger, print_args
from torch.utils.data import DataLoader
from dataloader.dataset import imageSet
from models.archs.NAF_arch import NAF_Video
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
import torch.nn as nn
from models.utils import chunkV3
import pdb
from ISP_pipeline import process_pngs_isp
import json
import cv2
from skimage import io

# Calibration table: noise-profile coefficients `a` and `b` measured at each ISO.
ISO = [50, 125, 320, 640, 800]
a = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]
b = [2.32350645e-06, 3.1125155625e-06, 8.328992952e-06, 3.3315971808e-05, 5.205620595e-05]
# Fit a(ISO) with a line and b(ISO) with a quadratic. np.polyfit returns the
# highest power first, so a(ISO) ~= coeff_a[0] * ISO + coeff_a[1].
coeff_a = np.polyfit(ISO, a, 1)
coeff_b = np.polyfit(ISO, b, 2)

# Denoiser checkpoints, each paired with the exclusive upper bound of the
# estimated-ISO range it handles. The final entry is the catch-all bucket.
ISO_MODEL_TABLE = (
    (900, "denoise_model/low_iso.pth"),
    (1800, "denoise_model/mid_iso.pth"),
    (5600, "denoise_model/high_mid_iso.pth"),
    (float("inf"), "denoise_model/high_iso.pth"),
)


def _select_checkpoint(iso):
    """Return the checkpoint path from ISO_MODEL_TABLE for an estimated ISO.

    A NaN estimate fails every ``<`` comparison and falls through to the last
    (high-ISO) checkpoint.
    """
    for upper_bound, path in ISO_MODEL_TABLE:
        if iso < upper_bound:
            return path
    return ISO_MODEL_TABLE[-1][1]


def _strip_module_prefix(state_dict, prefix="module."):
    """Return *state_dict* with a leading *prefix* removed from its keys.

    Checkpoints saved from a DataParallel/DistributedDataParallel wrapper carry
    a ``'module.'`` prefix on every key, which the bare network does not
    expect. When no key has the prefix the dict is returned unchanged.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _load_checkpoint(path):
    """Load a state dict from *path* onto the CPU with key prefixes normalised.

    load_state_dict later copies the tensors onto the network's own device.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only use it on
    # trusted checkpoint files.
    return _strip_module_prefix(torch.load(path, map_location=torch.device("cpu")))


def main():
    """Parse CLI arguments, denoise every test sample and write rendered JPEGs."""
    parser = argparse.ArgumentParser(description='imageTest')
    parser.add_argument('--frame', default=1, type=int)
    parser.add_argument('--test_dir', default="/data/", type=str)
    parser.add_argument('--model_type', type=str, default='NAF_Video')
    parser.add_argument('--save_folder', default='/data/', type=str)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--testoption', default='image', type=str)
    parser.add_argument('--chunk', action='store_true')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    # Honour --save_folder. It was previously parsed and then ignored in favour
    # of a hard-coded '/data/', which is also the flag's default.
    args.src_save_folder = args.save_folder
    print(args.src_save_folder, '**********************')
    os.makedirs(args.src_save_folder, exist_ok=True)

    network = NAF_Video(args).cuda()
    # load_state_dict does not change train/eval mode, so switching once suffices.
    network.eval()
    # Load every checkpoint up front; switching models is then an in-memory copy.
    state_dicts = {path: _load_checkpoint(path) for _, path in ISO_MODEL_TABLE}
    active_checkpoint = None

    cudnn.benchmark = True
    test_dataset = imageSet(args)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)

    inference_time = []
    with torch.no_grad():
        for data in test_dataloader:
            noise = data['input'].cuda()
            json_path = data['json_path'][0]
            scene_name = os.path.splitext(os.path.basename(json_path))[0]

            # Read the capture metadata used by the ISP, then estimate the ISO
            # by inverting the linear a(ISO) fit with the first noise-profile
            # coefficient.
            json_cfa = process_pngs_isp.readjson(json_path)
            num_k = json_cfa['noise_profile']
            iso = (num_k[0] - coeff_a[1]) / coeff_a[0]

            # Only copy weights in when the ISO bucket changes. Assumes a
            # forward pass never mutates the network's parameters or buffers;
            # TODO confirm for NAF_Video.
            checkpoint = _select_checkpoint(iso)
            if checkpoint != active_checkpoint:
                network.load_state_dict(state_dicts[checkpoint], strict=True)
                active_checkpoint = checkpoint

            t0 = time.perf_counter()
            out = chunkV3(network, noise, args.testoption, patch_h=1024, patch_w=1024)
            out = torch.clamp(out, 0., 1.)
            out = out[0]  # drop the batch dimension (the loader uses batch_size=1)
            del noise
            torch.cuda.empty_cache()

            img_pro = process_pngs_isp.isp_night_imaging(
                out, json_cfa, iso,
                do_demosaic=True,  # H/2, W/2
                do_channel_gain_white_balance=True,
                do_xyz_transform=True,
                do_srgb_transform=True,
                do_gamma_correct=True,
                do_refinement=True,  # 32 bit
                do_to_uint8=True,
                do_resize_using_pil=True,  # H/8, W/8
                do_fix_orientation=True,
            )
            t1 = time.perf_counter()
            inference_time.append(t1 - t0)

            # scene_name is a basename, so the file lands directly in the
            # save folder created above.
            name_rgb = os.path.join(args.src_save_folder, scene_name + '.jpg')
            # img_pro is treated as RGB; cv2.imwrite expects BGR channel order.
            img_pro = cv2.cvtColor(img_pro, cv2.COLOR_RGB2BGR)
            # IMWRITE_PNG_COMPRESSION has no effect on .jpg output; spell out
            # OpenCV's default JPEG quality (95) instead. Raise to 100 if
            # maximum quality is wanted.
            if not cv2.imwrite(name_rgb, img_pro, [cv2.IMWRITE_JPEG_QUALITY, 95]):
                print(f"Failed to write {name_rgb}")
            print("Inference {} in {:.3f}s".format(scene_name, t1 - t0))

    if inference_time:
        print(f"Average inference time: {np.mean(inference_time)} seconds")
    else:
        print("No images were processed.")


if __name__ == '__main__':
    main()