Spaces:
Running
Running
Add other style types
Browse files
app.py
CHANGED
@@ -129,15 +129,22 @@ def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
|
|
129 |
@torch.inference_mode()
|
130 |
def run(
|
131 |
image,
|
132 |
-
|
|
|
133 |
dlib_landmark_model,
|
134 |
encoder: nn.Module,
|
135 |
-
|
136 |
-
|
137 |
transform: Callable,
|
138 |
device: torch.device,
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
stylename = list(exstyles.keys())[style_id]
|
142 |
|
143 |
image = align_face(filepath=image.name, predictor=dlib_landmark_model)
|
@@ -181,7 +188,11 @@ def run(
|
|
181 |
img_gen1 = postprocess(img_gen[1])
|
182 |
img_gen2 = postprocess(img_gen2[0])
|
183 |
|
184 |
-
|
|
|
|
|
|
|
|
|
185 |
|
186 |
return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
|
187 |
|
@@ -192,43 +203,60 @@ def main():
|
|
192 |
args = parse_args()
|
193 |
device = torch.device(args.device)
|
194 |
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
download_cartoon_images()
|
199 |
dlib_landmark_model = create_dlib_landmark_model()
|
200 |
encoder = load_encoder(device)
|
201 |
-
generator = load_generator(style_type, device)
|
202 |
-
exstyles = load_exstylecode(style_type)
|
203 |
transform = create_transform()
|
204 |
|
205 |
func = functools.partial(run,
|
206 |
dlib_landmark_model=dlib_landmark_model,
|
207 |
encoder=encoder,
|
208 |
-
|
209 |
-
|
210 |
transform=transform,
|
211 |
-
device=device
|
212 |
-
style_image_dir=style_image_dir)
|
213 |
func = functools.update_wrapper(func, run)
|
214 |
|
215 |
repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
|
216 |
title = 'williamyang1991/DualStyleGAN'
|
217 |
description = f"""A demo for {repo_url}
|
218 |
|
219 |
-
You can select style images from the table below.
|
220 |
"""
|
221 |
article = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
|
222 |
|
223 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
224 |
-
examples = [[path.as_posix(), 26] for path in image_paths]
|
225 |
|
226 |
gr.Interface(
|
227 |
func,
|
228 |
[
|
229 |
-
gr.inputs.Image(type='file', label='Image'),
|
230 |
-
gr.inputs.
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
232 |
],
|
233 |
[
|
234 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
|
|
129 |
@torch.inference_mode()
|
130 |
def run(
|
131 |
image,
|
132 |
+
style_type: str,
|
133 |
+
style_id: float,
|
134 |
dlib_landmark_model,
|
135 |
encoder: nn.Module,
|
136 |
+
generator_dict: dict[str, nn.Module],
|
137 |
+
exstyle_dict: dict[str, dict[str, np.ndarray]],
|
138 |
transform: Callable,
|
139 |
device: torch.device,
|
140 |
+
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
|
141 |
+
PIL.Image, PIL.Image]:
|
142 |
+
generator = generator_dict[style_type]
|
143 |
+
exstyles = exstyle_dict[style_type]
|
144 |
+
|
145 |
+
style_id = int(style_id)
|
146 |
+
style_id = min(max(0, style_id), len(exstyles) - 1)
|
147 |
+
|
148 |
stylename = list(exstyles.keys())[style_id]
|
149 |
|
150 |
image = align_face(filepath=image.name, predictor=dlib_landmark_model)
|
|
|
188 |
img_gen1 = postprocess(img_gen[1])
|
189 |
img_gen2 = postprocess(img_gen2[0])
|
190 |
|
191 |
+
try:
|
192 |
+
style_image_dir = pathlib.Path(style_type)
|
193 |
+
style_image = PIL.Image.open(style_image_dir / stylename)
|
194 |
+
except Exception:
|
195 |
+
style_image = None
|
196 |
|
197 |
return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
|
198 |
|
|
|
203 |
args = parse_args()
|
204 |
device = torch.device(args.device)
|
205 |
|
206 |
+
style_types = [
|
207 |
+
'cartoon',
|
208 |
+
'caricature',
|
209 |
+
'anime',
|
210 |
+
'arcane',
|
211 |
+
'comic',
|
212 |
+
'pixar',
|
213 |
+
'slamdunk',
|
214 |
+
]
|
215 |
+
generator_dict = {
|
216 |
+
style_type: load_generator(style_type, device)
|
217 |
+
for style_type in style_types
|
218 |
+
}
|
219 |
+
exstyle_dict = {
|
220 |
+
style_type: load_exstylecode(style_type)
|
221 |
+
for style_type in style_types
|
222 |
+
}
|
223 |
|
224 |
download_cartoon_images()
|
225 |
dlib_landmark_model = create_dlib_landmark_model()
|
226 |
encoder = load_encoder(device)
|
|
|
|
|
227 |
transform = create_transform()
|
228 |
|
229 |
func = functools.partial(run,
|
230 |
dlib_landmark_model=dlib_landmark_model,
|
231 |
encoder=encoder,
|
232 |
+
generator_dict=generator_dict,
|
233 |
+
exstyle_dict=exstyle_dict,
|
234 |
transform=transform,
|
235 |
+
device=device)
|
|
|
236 |
func = functools.update_wrapper(func, run)
|
237 |
|
238 |
repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
|
239 |
title = 'williamyang1991/DualStyleGAN'
|
240 |
description = f"""A demo for {repo_url}
|
241 |
|
242 |
+
You can select style images for cartoon from the table below.
|
243 |
"""
|
244 |
article = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'
|
245 |
|
246 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
247 |
+
examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
|
248 |
|
249 |
gr.Interface(
|
250 |
func,
|
251 |
[
|
252 |
+
gr.inputs.Image(type='file', label='Input Image'),
|
253 |
+
gr.inputs.Radio(
|
254 |
+
style_types,
|
255 |
+
type='value',
|
256 |
+
default='cartoon',
|
257 |
+
label='Style Type',
|
258 |
+
),
|
259 |
+
gr.inputs.Number(default=26, label='Style Image Index'),
|
260 |
],
|
261 |
[
|
262 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|