File size: 5,747 Bytes
f555b43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os

import cv2
import torch

import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader, errors
from modules.paths import models_path

# The CodeFormer authors bundle a modified copy of the basicsr library with their project, which makes
# it impossible to use CodeFormer alongside other libraries that also depend on basicsr, such as GFPGAN.
# To work around this, some files from CodeFormer are included in this project instead.
model_dir = "Codeformer"  # name of the sub-directory under models_path
model_path = os.path.join(models_path, model_dir)  # created by setup_model() when it does not exist
# passed to modelloader.load_models() together with model_path when the checkpoint is looked up
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

# set to True by setup_model() once the CodeFormer imports and class definition succeeded
have_codeformer = False
# the restorer instance created by setup_model(); remains None until then
codeformer = None


def setup_model(dirname):
    """Register the CodeFormer face restorer.

    Makes sure the model directory exists and, when ``modules.paths.paths``
    has a "CodeFormer" entry, imports the CodeFormer dependencies, defines the
    restorer class and appends one instance to ``shared.face_restorers``.
    On success the module globals ``have_codeformer`` and ``codeformer`` are
    updated. Any exception raised during setup is reported through
    ``errors.report`` and not propagated.

    Args:
        dirname: Stored on the restorer and forwarded to
            ``modelloader.load_models`` as its third argument when the
            checkpoint is looked up.
    """
    global have_codeformer, codeformer

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(model_path, exist_ok=True)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return

    try:
        # Imported lazily so that a missing dependency only disables CodeFormer.
        from torchvision.transforms.functional import normalize
        from modules.codeformer.codeformer_arch import CodeFormer
        from basicsr.utils import img2tensor, tensor2img
        from facelib.utils.face_restoration_helper import FaceRestoreHelper
        from facelib.detection.retinaface import retinaface

        net_class = CodeFormer

        class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
            """Face restorer that lazily loads and runs the CodeFormer network."""

            def name(self):
                return "CodeFormer"

            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):
                """Return ``(net, face_helper)``, loading them on first use.

                Returns ``(None, None)`` when no checkpoint could be located.
                """
                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper

                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
                if not model_paths:
                    print("Unable to load codeformer model.")
                    return None, None
                ckpt_path = model_paths[0]

                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                # Deserialize onto the CPU so that the checkpoint loads regardless of the
                # device it was saved from; load_state_dict copies into the net's own device.
                checkpoint = torch.load(ckpt_path, map_location='cpu')['params_ema']
                net.load_state_dict(checkpoint)
                net.eval()

                if hasattr(retinaface, 'device'):
                    retinaface.device = devices.device_codeformer
                face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)

                self.net = net
                self.face_helper = face_helper

                return net, face_helper

            def send_model_to(self, device):
                """Move the network and the helper's detection/parsing models to ``device``."""
                self.net.to(device)
                self.face_helper.face_det.to(device)
                self.face_helper.face_parse.to(device)

            def restore(self, np_image, w=None):
                """Restore the faces found in ``np_image``.

                Args:
                    np_image: Image array indexed as (row, column, channel).
                    w: Weight passed to the network as ``w``; when None,
                        ``shared.opts.code_former_weight`` is used.

                Returns:
                    The image with restored faces pasted in, in the channel
                    order of the input and resized back to the input
                    resolution if it changed. When the models cannot be
                    loaded, the input image is returned unchanged.
                """
                # The face helper pipeline works on the channel-reversed image
                # (cropped faces are converted with bgr2rgb=True below).
                flipped_image = np_image[:, :, ::-1]

                original_resolution = flipped_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    # Hand back the caller's array, not the channel-reversed view,
                    # so a missing model never alters the image colours.
                    return np_image

                self.send_model_to(devices.device_codeformer)

                self.face_helper.clean_all()
                self.face_helper.read_image(flipped_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                for cropped_face in self.face_helper.cropped_faces:
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

                    try:
                        with torch.no_grad():
                            output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                        del output
                        torch.cuda.empty_cache()
                    except Exception:
                        # Best effort: fall back to the unrestored crop for this face.
                        errors.report('Failed inference for CodeFormer', exc_info=True)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    restored_face = restored_face.astype('uint8')
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                # Undo the channel reversal applied on entry.
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)

                self.face_helper.clean_all()

                if shared.opts.face_restoration_unload:
                    self.send_model_to(devices.cpu)

                return restored_img

        have_codeformer = True

        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)

    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)

   # sys.path = stored_sys_path