it@M InnovationLab
Fix demo launch again.
d6054e5
raw
history blame
2.12 kB
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFilter
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
import torchvision.transforms
import torch
# One checkpoint name shared by the processor and the model so the two
# cannot drift apart.
_ONEFORMER_CHECKPOINT = "shi-labs/oneformer_cityscapes_swin_large"

# Both objects are created once, when the module is imported.
person_processor = OneFormerProcessor.from_pretrained(_ONEFORMER_CHECKPOINT)
person_model = OneFormerForUniversalSegmentation.from_pretrained(_ONEFORMER_CHECKPOINT)

# Tensor -> PIL converter, used to turn the predicted class map into an image.
transform = torchvision.transforms.ToPILImage()
# Class id compared against the predicted semantic map to find person pixels
# (id 11 in the Cityscapes label set this checkpoint was trained on).
_PERSON_CLASS_ID = 11


def detect_person(image: Image.Image) -> Image.Image:
    """Build a mask marking every pixel that is segmented as a person.

    Args:
        image: Picture to analyse. Any PIL mode that can be converted to RGB.

    Returns:
        A single-channel mask with the same size as ``image``: 0 where the
        model predicts the person class, 255 everywhere else.
    """
    # Feed the processor a three-channel copy so RGBA / grayscale inputs are
    # handled too; the caller's image is left untouched.
    rgb_image = image.convert("RGB")
    semantic_inputs = person_processor(
        images=rgb_image, task_inputs=["semantic"], return_tensors="pt"
    )

    # Pure inference: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        semantic_outputs = person_model(**semantic_inputs)

    # PIL reports size as (width, height); target_sizes expects (height, width).
    predicted_semantic_map = person_processor.post_process_semantic_segmentation(
        semantic_outputs, target_sizes=[image.size[::-1]]
    )[0]

    class_map = transform(predicted_semantic_map.to(torch.uint8))
    return Image.eval(
        class_map,
        lambda class_id: 0 if class_id == _PERSON_CLASS_ID else 255,
    )
def detect_dummy(image: Image):
    """Fallback detector: return an all-255 "L" mask sized like ``image``.

    Used when a requested detector name has no implementation; a mask of
    255 everywhere selects no region for blurring.
    """
    mask_size = image.size
    return Image.new("L", mask_size, 255)
# Registry of available detectors, keyed by their display name.
# TODO: register "License Plate": detect_license_plate here when available.
detectors = {"Person": detect_person}
def _resolve_detector(name: str):
    """Return the detector registered under ``name``, or the no-op fallback."""
    # Lives outside ``anonymize`` because that function's ``detectors``
    # parameter shadows the module-level registry of the same name.
    return detectors.get(name, detect_dummy)


def anonymize(path: str, detectors: list):
    """Blur the regions of an image file found by the requested detectors.

    Args:
        path: Location of the image file to read.
        detectors: Names of the detectors to run. Unknown names fall back to
            ``detect_dummy`` and therefore select nothing.

    Returns:
        A new PIL image in which every pixel flagged by at least one detector
        is taken from a Gaussian-blurred copy. With an empty ``detectors``
        list the loaded image is returned unchanged.
    """
    # Read image; copy() loads the pixel data so the file can be closed here.
    with Image.open(path) as opened:
        image = opened.copy()

    # Run requested detectors
    masks = [_resolve_detector(name)(image) for name in detectors]
    if not masks:
        # Nothing requested: np.minimum.reduce would raise on an empty list.
        return image

    # Combine masks: a pixel is blurred (0) as soon as any detector flags it.
    combined = np.minimum.reduce([np.array(m) for m in masks])
    mask = Image.fromarray(combined)

    # Apply blur through mask
    blurred = image.filter(ImageFilter.GaussianBlur(15))
    return Image.composite(image, blurred, mask)
def test_gradio(image):
masks = [detect_person(image)]
combined = np.minimum.reduce([np.array(m) for m in masks])
mask = Image.fromarray(combined)
# Apply blur through mask
blurred = image.filter(ImageFilter.GaussianBlur(15))
anonymized = Image.composite(image, blurred, mask)
return anonymized
# Webcam frames are passed to the callback as PIL images, and the PIL image
# it returns is displayed as the output.
webcam_input = gr.Image(source="webcam", type="pil")
anonymized_output = gr.Image(type="pil")
demo = gr.Interface(fn=test_gradio, inputs=webcam_input, outputs=anonymized_output)

demo.launch()
# demo.launch(server_name="localhost", server_port=8080)