'''
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
@author: yangxy (yangtao9009@gmail.com)
https://github.com/yangxy/GPEN
@inproceedings{Yang2021GPEN,
    title={GAN Prior Embedded Network for Blind Face Restoration in the Wild},
    author={Tao Yang, Peiran Ren, Xuansong Xie, and Lei Zhang},
    booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
    year={2021}
}
© Alibaba, 2021. For academic and non-commercial use only.
==================================================
slightly modified by Kai Zhang (2021-06-03)
https://github.com/cszn/KAIR

How to run:

step 1: Download the RetinaFace-R50.pth and GPEN-512.pth models and put them into `model_zoo`.
RetinaFace-R50.pth: https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth
GPEN-512.pth: https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-512.pth

step 2: Install ninja by `pip install ninja`; set `inputdir` for your own testing images

step 3: `python main_test_face_enhancement.py`
==================================================
'''


import os
import cv2
import glob
import numpy as np
import torch

from utils.utils_alignfaces import warp_and_crop_face, get_reference_facial_points
from utils import utils_image as util

from retinaface.retinaface_detection import RetinaFaceDetection
from models.network_faceenhancer import FullGenerator as enhancer_net


class faceenhancer(object):
    '''
    Wraps the GPEN generator: loads the weights once and enhances one
    (already cropped) face image per call to `process`.
    '''

    def __init__(self, model_path='model_zoo/GPEN-512.pth', size=512, channel_multiplier=2):
        '''
        model_path: path of the generator checkpoint
        size: side length the input is resized to before it is fed to the generator
        channel_multiplier: forwarded unchanged to the generator constructor
        '''
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.size = size
        self.model = enhancer_net(self.size, 512, 8, channel_multiplier).to(self.device)
        # map_location is required so that a checkpoint saved from a GPU can
        # still be loaded when self.device has fallen back to the CPU.
        state_dict = torch.load(self.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def process(self, img):
        '''
        img: uint8 RGB image, (H, W, 3)
        out: uint8 RGB image generated from the input resized to (size, size)
        '''
        img = cv2.resize(img, (self.size, self.size))

        # presumably uint2tensor4 yields a float tensor in [0, 1]; the next
        # line then rescales it to [-1, 1] -- TODO confirm against utils_image
        img = util.uint2tensor4(img)
        img = (img - 0.5) / 0.5
        img = img.to(self.device)

        with torch.no_grad():
            # the second value returned by the generator is not used here
            out, __ = self.model(img)

        # undo the [-1, 1] scaling before converting back to uint8
        out = util.tensor2uint(out * 0.5 + 0.5)
        return out


class faceenhancer_with_detection_alignment(object):
    '''
    Detects faces with RetinaFace, enhances every aligned face crop with
    `faceenhancer`, and blends the enhanced crops back into the full image.
    '''

    def __init__(self, model_path, size=512, channel_multiplier=2):
        '''
        model_path: path of the generator checkpoint (passed to `faceenhancer`)
        size: side length of the aligned face crops
        channel_multiplier: forwarded unchanged to `faceenhancer`
        '''
        self.facedetector = RetinaFaceDetection('model_zoo/RetinaFace-R50.pth')
        self.faceenhancer = faceenhancer(model_path, size, channel_multiplier)
        self.size = size
        self.threshold = 0.9

        # Soft blending mask: a filled rectangle with a 26-pixel border left
        # at zero, blurred twice so that its edges fall off smoothly.
        self.mask = np.zeros((512, 512), np.float32)
        cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, cv2.LINE_AA)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)
        self.mask = cv2.GaussianBlur(self.mask, (101, 101), 11)

        # 3x3 smoothing kernel (weights sum to 1)
        self.kernel = np.array((
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625]), dtype="float32")

        # get the reference 5 landmarks position in the crop settings
        default_square = True
        inner_padding_factor = 0.25
        outer_padding = (0, 0)
        self.reference_5pts = get_reference_facial_points(
            (self.size, self.size), inner_padding_factor, outer_padding, default_square)

    def process(self, img):
        '''
        img: uint8 RGB image, (H, W, 3)
        img, orig_faces, enhanced_faces: blended uint8 image / lists of the
        cropped and of the enhanced face images
        '''
        # NOTE(review): the image is converted to BGR here and the visible code
        # never converts it back before returning -- confirm the channel order
        # that callers expect.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        facebs, landms = self.facedetector.detect(img)

        orig_faces, enhanced_faces = [], []
        height, width = img.shape[:2]
        full_mask = np.zeros((height, width), dtype=np.float32)
        full_img = np.zeros(img.shape, dtype=np.uint8)

        for faceb, facial5points in zip(facebs, landms):
            # NOTE(review): in the file as received, the text between
            # `if faceb[4]` and `0)] = tmp_mask[...]` was missing (apparently
            # stripped as markup). The statements from here down to the
            # `full_mask[...]` assignment are a reconstruction built on the
            # names this file defines but does not otherwise use
            # (self.threshold, warp_and_crop_face, self.reference_5pts,
            # self.faceenhancer, self.mask, self.kernel). Exact arguments and
            # constants (the 100-pixel limit, flags=3, the return values of
            # warp_and_crop_face) -- TODO confirm against upstream.
            if faceb[4] < self.threshold:
                continue
            fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0])

            facial5points = np.reshape(facial5points, (2, 5))

            of, tfm_inv = warp_and_crop_face(
                img, facial5points,
                reference_pts=self.reference_5pts,
                crop_size=(self.size, self.size))

            # enhance the aligned face crop
            ef = self.faceenhancer.process(of)

            orig_faces.append(of)
            enhanced_faces.append(ef)

            # bring the blending mask to the crop size, then warp it back
            # into the coordinates of the full image
            tmp_mask = cv2.resize(self.mask, (ef.shape[1], ef.shape[0]))
            tmp_mask = cv2.warpAffine(tmp_mask, tfm_inv, (width, height), flags=3)

            if min(fh, fw) < 100:
                # smooth small faces
                ef = cv2.filter2D(ef, -1, self.kernel)

            tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3)

            # only overwrite pixels where this face's mask is stronger than
            # what earlier faces have already written
            update = (tmp_mask - full_mask) > 0
            full_mask[update] = tmp_mask[update]
            full_img[update] = tmp_img[update]

        full_mask = full_mask[:, :, np.newaxis]
        img = cv2.convertScaleAbs(img * (1 - full_mask) + full_img * full_mask)

        return img, orig_faces, enhanced_faces


if __name__ == '__main__':

    inputdir = os.path.join('testsets', 'real_faces')
    outdir = os.path.join('testsets', 'real_faces_results')
    os.makedirs(outdir, exist_ok=True)

    # whether use the face detection&alignment or not
    need_face_detection = True

    model_path = os.path.join('model_zoo', 'GPEN-512.pth')
    if need_face_detection:
        enhancer = faceenhancer_with_detection_alignment(
            model_path=model_path, size=512, channel_multiplier=2)
    else:
        enhancer = faceenhancer(model_path=model_path, size=512, channel_multiplier=2)

    for idx, img_file in enumerate(util.get_image_paths(inputdir)):
        img_name, ext = os.path.splitext(os.path.basename(img_file))
        img_L = util.imread_uint(img_file, n_channels=3)

        # NOTE(review): the file as received ends inside this format string
        # (`print('{:->4d} --> {:`). The rest of this statement and the
        # remainder of the loop body are a minimal reconstruction: enhance the
        # image and write it to `outdir`. `util.imsave` and the output file
        # name are assumptions -- TODO confirm against upstream.
        print('{:->4d} --> {:<s}'.format(idx + 1, img_name + ext))

        if need_face_detection:
            img_E, orig_faces, enhanced_faces = enhancer.process(img_L)
        else:
            img_E = enhancer.process(img_L)

        util.imsave(img_E, os.path.join(outdir, img_name + '_gpen.png'))