Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from bisnet import BiSeNet | |
from huggingface_hub import snapshot_download | |
from utils import vis_parsing_maps, decode_segmentation_masks, image_to_tensor | |
# Print the installed package versions to the process output (debug aid for
# the hosted Space's logs).
os.system("pip freeze")

REPO_ID = "leonelhs/faceparser"
MODEL_NAME = "79999_iter.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 19 output classes, matching the checkpoint below.
model = BiSeNet(n_classes=19)
snapshot_folder = snapshot_download(repo_id=REPO_ID)
model_path = os.path.join(snapshot_folder, MODEL_NAME)
# NOTE(review): torch.load unpickles the checkpoint downloaded from the Hub;
# consider passing weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(model_path, map_location=device))
# load_state_dict copies values into the module's existing (CPU) parameters;
# it does not move the module. Move it explicitly so the weights end up on
# the same device the inference input is sent to.
model.to(device)
model.eval()
def makeOverlay(image, mask):
    """Render a face-parsing mask on top of the input image.

    Args:
        image: PIL image; it is resized to 512x512 (bilinear) before drawing.
        mask: class-index map (anything ``np.asarray`` accepts), presumably
            the output of ``makeMask`` — TODO confirm for other callers.

    Returns:
        ``(overlay, colormap)``: the overlay returned by ``vis_parsing_maps``
        and the result of ``decode_segmentation_masks`` applied to its
        companion map.
    """
    resized = image.resize((512, 512), Image.BILINEAR)
    dark_map, overlay = vis_parsing_maps(resized, np.asarray(mask))
    return overlay, decode_segmentation_masks(dark_map)
def makeMask(image):
    """Run the face-parsing model on *image* and return a class-index map.

    Args:
        image: PIL image; it is resized to 512x512 (bilinear) before inference.

    Returns:
        numpy array of per-pixel class indices obtained by taking ``argmax``
        over axis 0 of the first model output after dropping its leading
        (batch) axis. Presumably shape (512, 512) — TODO confirm against the
        BiSeNet output layout.
    """
    with torch.no_grad():
        image = image.resize((512, 512), Image.BILINEAR)
        input_tensor = image_to_tensor(image)
        # Add a leading batch dimension.
        input_tensor = torch.unsqueeze(input_tensor, 0)
        # Follow the model's actual device rather than guessing from
        # torch.cuda.is_available(): if the model was not moved to the GPU,
        # sending the input there raises a device-mismatch error.
        model_device = next(model.parameters()).device
        input_tensor = input_tensor.to(model_device)
        output = model(input_tensor)[0]
        return output.squeeze(0).cpu().numpy().argmax(0)
def predict(image):
    """Gradio callback: parse the face(s) in *image* and return the overlay.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        The overlay produced by ``makeOverlay``, or ``None`` for a missing
        input (leaves the output component empty instead of failing inside
        ``makeMask`` on ``None.resize``).
    """
    if image is None:
        return None
    mask = makeMask(image)
    # makeOverlay also returns a colormap, which this UI does not display.
    overlay, _colormap = makeOverlay(image, mask)
    return overlay
# Text shown in the Gradio UI (title, Markdown description, footer article).
title = "Face Parser"
description = r"""
## Image face parser for research
This is an implementation of <a href='https://github.com/zllrunning/face-parsing.PyTorch' target='_blank'>face-parsing.PyTorch</a>.
It has no particular purpose other than to start research on AI models.
"""
# NOTE(review): the GitHub link and the visitor-badge page_id below refer to
# "zeroscratches", which looks copied from another Space — confirm they are
# intended for this face-parser demo.
article = r"""
Questions, doubts, comments, please email 📧 `leonelhs@gmail.com`
This demo is running on a CPU, if you like this project please make us a donation to run on a GPU or just give us a <a href='https://github.com/leonelhs/zeroscratches/' target='_blank'>Github ⭐</a>
<a href="https://www.buymeacoffee.com/leonelhs"><img src="https://img.buymeacoffee.com/button-api/?text=Buy me a coffee&emoji=&slug=leonelhs&button_colour=FFDD00&font_colour=000000&font_family=Cookie&outline_colour=000000&coffee_colour=ffffff" /></a>
<center><img src='https://visitor-badge.glitch.me/badge?page_id=zeroscratches.visitor-badge' alt='visitor badge'></center>
"""
# One PIL image in, one numpy image (the parsed overlay) out.
_input_components = [gr.Image(type="pil", label="Input")]
_output_components = [gr.Image(type="numpy", label="Image face parsed")]

demo = gr.Interface(
    fn=predict,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    description=description,
    article=article,
)
# Enable request queuing, then start the web server.
demo.queue().launch()