fffiloni committed
Commit 8999ad1
1 Parent(s): b65ff1f

Update app.py

Files changed (1)
  1. app.py +150 -3
app.py CHANGED
@@ -1,5 +1,152 @@
+import os
+from huggingface_hub import hf_hub_download
+os.system("pip -qq install facenet_pytorch")
+from facenet_pytorch import MTCNN
+from torchvision import transforms
+import torch, PIL
+from tqdm.notebook import tqdm
 import gradio as gr
+import torch
 
-gr.Interface.load(
-    "spaces/akhaliq/ArcaneGAN", inputs="webcam", title="Remove your webcam background!"
-).launch()
+modelarcanev4 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.4", filename="ArcaneGANv0.4.jit")
+modelarcanev3 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.3", filename="ArcaneGANv0.3.jit")
+modelarcanev2 = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.2", filename="ArcaneGANv0.2.jit")
+
+
+mtcnn = MTCNN(image_size=256, margin=80)
+
+# simplest ye olde trustworthy MTCNN for face detection with landmarks
+def detect(img):
+
+    # Detect faces
+    batch_boxes, batch_probs, batch_points = mtcnn.detect(img, landmarks=True)
+    # Select faces
+    if not mtcnn.keep_all:
+        batch_boxes, batch_probs, batch_points = mtcnn.select_boxes(
+            batch_boxes, batch_probs, batch_points, img, method=mtcnn.selection_method
+        )
+
+    return batch_boxes, batch_points
+
+# my version of isOdd, should make a separate repo for it :D
+def makeEven(_x):
+    return _x if (_x % 2 == 0) else _x + 1
+
+# the actual scaler function
+def scale(boxes, _img, max_res=1_500_000, target_face=256, fixed_ratio=0, max_upscale=2, VERBOSE=False):
+
+    x, y = _img.size
+
+    ratio = 2  # initial ratio
+
+    # scale to desired face size
+    if (boxes is not None):
+        if len(boxes) > 0:
+            ratio = target_face / max(boxes[0][2:] - boxes[0][:2])
+            ratio = min(ratio, max_upscale)
+            if VERBOSE: print('up by', ratio)
+
+    if fixed_ratio > 0:
+        if VERBOSE: print('fixed ratio')
+        ratio = fixed_ratio
+
+    x *= ratio
+    y *= ratio
+
+    # downscale to fit into max res
+    res = x * y
+    if res > max_res:
+        ratio = pow(res / max_res, 1/2)
+        if VERBOSE: print(ratio)
+        x = int(x / ratio)
+        y = int(y / ratio)
+
+    # make dimensions even, because NNs usually fail on uneven dimensions due to skip connection size mismatches
+    x = makeEven(int(x))
+    y = makeEven(int(y))
+
+    size = (x, y)
+
+    return _img.resize(size)
+
+"""
+A useful scaler algorithm, based on face detection.
+Takes a PIL.Image, returns a uniformly scaled PIL.Image.
+boxes: a list of detected bboxes
+_img: PIL.Image
+max_res: maximum pixel area to fit into. Use it to stay below the VRAM limits of your GPU.
+target_face: desired face size. Upscales or downscales the whole image to fit the detected face into that dimension.
+fixed_ratio: fixed scale. Ignores the face size, but doesn't ignore the max_res limit.
+max_upscale: maximum upscale ratio. Prevents images with tiny faces from being blown up into a blurry mess.
+"""
+
+def scale_by_face_size(_img, max_res=1_500_000, target_face=256, fix_ratio=0, max_upscale=2, VERBOSE=False):
+    boxes = None
+    boxes, _ = detect(_img)
+    if VERBOSE: print('boxes', boxes)
+    img_resized = scale(boxes, _img, max_res, target_face, fix_ratio, max_upscale, VERBOSE)
+    return img_resized
+
+
+size = 256
+
+means = [0.485, 0.456, 0.406]
+stds = [0.229, 0.224, 0.225]
+
+t_stds = torch.tensor(stds).cuda().half()[:, None, None]
+t_means = torch.tensor(means).cuda().half()[:, None, None]
+
+def makeEven(_x):
+    return int(_x) if (_x % 2 == 0) else int(_x + 1)
+
+img_transforms = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize(means, stds)])
+
+def tensor2im(var):
+    return var.mul(t_stds).add(t_means).mul(255.).clamp(0, 255).permute(1, 2, 0)
+
+def proc_pil_img(input_image, model):
+    transformed_image = img_transforms(input_image)[None, ...].cuda().half()
+
+    with torch.no_grad():
+        result_image = model(transformed_image)[0]
+        output_image = tensor2im(result_image)
+        output_image = output_image.detach().cpu().numpy().astype('uint8')
+        output_image = PIL.Image.fromarray(output_image)
+    return output_image
+
+
+
+modelv4 = torch.jit.load(modelarcanev4).eval().cuda().half()
+modelv3 = torch.jit.load(modelarcanev3).eval().cuda().half()
+modelv2 = torch.jit.load(modelarcanev2).eval().cuda().half()
+
+def process(im, version):
+    if version == 'version 0.4':
+        im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
+        res = proc_pil_img(im, modelv4)
+    elif version == 'version 0.3':
+        im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
+        res = proc_pil_img(im, modelv3)
+    else:
+        im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
+        res = proc_pil_img(im, modelv2)
+    return res
+
+title = "ArcaneGAN"
+description = "Gradio demo for ArcaneGAN, portrait to Arcane style. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+article = "<div style='text-align: center;'>ArcaneGAN by <a href='https://twitter.com/devdef' target='_blank'>Alexander S</a> | <a href='https://github.com/Sxela/ArcaneGAN' target='_blank'>Github Repo</a> | <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_arcanegan' alt='visitor badge'></center></div>"
+
+gr.Interface(
+    process,
+    [gr.inputs.Image(type="pil", label="Input"),
+     gr.inputs.Radio(choices=['version 0.2', 'version 0.3', 'version 0.4'], type="value", default='version 0.4', label='version')],
+    gr.outputs.Image(type="pil", label="Output"),
+    title=title,
+    description=description,
+    article=article,
+    examples=[['bill.png', 'version 0.3'], ['keanu.png', 'version 0.4'], ['will.jpeg', 'version 0.4']],
+    allow_flagging=False,
+    allow_screenshot=False
).launch()
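
A note on the new scaling step: scale() first picks a ratio that fits the detected face to target_face pixels (capped by max_upscale), then shrinks the result to stay under max_res. A worked example with hypothetical numbers:

# Hypothetical walkthrough of scale() with its defaults target_face=256, max_res=1_500_000:
# a 2000x3000 portrait whose detected face box is 400 px on its longer side.
#   ratio = target_face / face_size = 256 / 400 = 0.64    (a downscale, so max_upscale is not hit)
#   scaled size: 1280 x 1920 = 2,457,600 px  >  max_res
#   ratio = sqrt(2_457_600 / 1_500_000) = 1.28  ->  final size 1000 x 1500
# makeEven() would then bump any odd dimension up by one pixel.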
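Because the committed code calls .cuda().half() unconditionally, the Space needs a CUDA GPU and will not run on CPU as written. For sanity-checking the pipeline outside the Gradio UI, here is a minimal sketch; it assumes the helpers defined in app.py above (scale_by_face_size, proc_pil_img) are in scope, and the file names are hypothetical stand-ins, not part of the commit:

# Minimal local smoke test (a sketch, not part of the commit).
# Assumes a CUDA GPU and the helper functions from app.py above;
# "portrait.jpg" / "arcane.png" are hypothetical file names.
import PIL.Image
import torch
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(repo_id="akhaliq/ArcaneGANv0.4", filename="ArcaneGANv0.4.jit")
model = torch.jit.load(ckpt).eval().cuda().half()   # same fp16/GPU setup as the app

im = PIL.Image.open("portrait.jpg").convert("RGB")
im = scale_by_face_size(im, target_face=256, max_res=1_500_000, max_upscale=1)
out = proc_pil_img(im, model)
out.save("arcane.png")

The fp16 cast roughly halves VRAM use at inference time, which is presumably why both the TorchScript models and the input batch are moved to half precision.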