import cv2
import numpy as np
import torch
from PIL import Image
from torch import nn

from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
                         resize_image, show_config)


class CYCLEGAN(object):
    """Inference wrapper around the project's CycleGAN ``Generator``.

    Configuration comes from ``_defaults`` and may be overridden per instance
    through keyword arguments; every configuration key becomes an instance
    attribute (``self.model_path``, ``self.input_shape``, ...).
    """

    _defaults = {
        # Path of the generator weights file to load (e.g. a checkpoint
        # saved under the logs folder).
        "model_path": 'model_data/G_model_B2A_last_epoch_weights.pth',
        # Network input size as [height, width].
        "input_shape": [112, 112],
        # Whether to resize without distortion (letterbox padding).
        "letterbox_image": True,
        # Whether to run on CUDA; set to False when no GPU is available.
        "cuda": True,
    }

    def __init__(self, **kwargs):
        """Build the configuration, load the generator and print the config.

        Args:
            **kwargs: Overrides for any key of ``_defaults``. Unknown keys are
                accepted as well and stored as instance attributes.
        """
        # Merge overrides into a per-instance copy. Writing them back into the
        # class-level dict would leak one instance's settings into every
        # instance created afterwards.
        config = dict(type(self)._defaults)
        config.update(kwargs)
        self.__dict__.update(config)
        # Shadow the class attribute so ``instance._defaults`` still reports
        # the effective configuration of this instance.
        self._defaults = config

        self.generate()
        show_config(**config)

    def generate(self):
        """Create the generator network and load its weights.

        Sets ``self.net`` to the generator in eval mode. When ``self.cuda`` is
        true, the network is wrapped in ``nn.DataParallel`` and moved to the
        GPU.
        """
        self.net = Generator(
            upscale=1,
            img_size=tuple(self.input_shape),
            window_size=7,
            img_range=1.,
            depths=[3, 3, 3, 3],
            embed_dim=60,
            num_heads=[3, 3, 3, 3],
            mlp_ratio=1,
            upsampler='1conv',
        )

        # The freshly built network lives on the CPU, so the checkpoint is
        # loaded there as well. This keeps a ``cuda=False`` run from touching
        # the GPU; the move to CUDA happens below only when requested.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        self.net = self.net.eval()
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def detect_image(self, image):
        """Run the generator on one image and return the translated image.

        Args:
            image: Input image; it is passed through ``cvtColor`` first
                (presumably a PIL image that gets converted to RGB - confirm
                against ``utils.utils.cvtColor``).

        Returns:
            PIL.Image.Image: The generated image, resized back to the height
            and width of the input.
        """
        # Convert to RGB so that grayscale inputs do not fail at prediction
        # time; the code only supports RGB prediction.
        image = cvtColor(image)

        # Remember the input height and width to restore them at the end.
        original_h, original_w = np.array(image).shape[:2]

        # Resize to the network input size (optionally letterboxed).
        # ``resize_image`` takes the size as (width, height).
        image_data, nw, nh = resize_image(
            image,
            (self.input_shape[1], self.input_shape[0]),
            self.letterbox_image,
        )

        # Preprocess, reorder the axes with (2, 0, 1) and add the batch
        # dimension in front.
        image_data = preprocess_input(np.array(image_data, dtype='float32'))
        image_data = np.expand_dims(np.transpose(image_data, (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # Forward pass; keep the first (only) element of the batch.
            pr = self.net(images)[0]
            # Back to a NumPy array with the channel axis last.
            pr = pr.permute(1, 2, 0).cpu().numpy()

            # Cut away the padded border when ``resize_image`` reported the
            # size of the centred content (nw is presumably None when no
            # letterboxing was applied - confirm against resize_image).
            if nw is not None:
                top = int((self.input_shape[0] - nh) // 2)
                left = int((self.input_shape[1] - nw) // 2)
                pr = pr[top:int(top + nh), left:int(left + nw)]

            # Resize the prediction back to the input resolution.
            pr = cv2.resize(pr, (original_w, original_h),
                            interpolation=cv2.INTER_LINEAR)

        image = postprocess_output(pr)
        # Clamp to the valid 8-bit range before the uint8 conversion.
        image = np.clip(image, 0, 255)
        image = Image.fromarray(np.uint8(image))

        return image