# Spaces:
# Configuration error
# Configuration error
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
from transformers import (
    AutoFeatureExtractor,
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# Module-level singletons filled in by init(); both remain None until it runs.
extractor = None  # image preprocessor used by get_mask()
model = None  # segmentation model used by get_mask()
def init(device: str = "cuda", checkpoint: str = "mattmdjaga/segformer_b2_clothes"):
    """Load the SegFormer image processor and model into the module globals.

    Must be called before get_mask() or any helper built on it.

    Args:
        device: torch device the model is moved to. The default "cuda"
            preserves the original hard-coded behaviour.
        checkpoint: Hugging Face Hub id (or local path) used for both the
            image processor and the model, so the two always match.
    """
    global model, extractor
    # SegformerImageProcessor is the non-deprecated replacement for the
    # feature-extractor API; it is called the same way in get_mask().
    extractor = SegformerImageProcessor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint).to(device)
    # Inference only: make sure training-time layers (dropout etc.) are off.
    model.eval()
def get_mask(img: Image.Image, body_part_id: int, inverse: bool = False) -> Image.Image:
    """Segment *img* and return a binary mask for one predicted label.

    Requires init() to have been called (reads the module-level
    ``extractor`` and ``model``).

    Args:
        img: input PIL image.
        body_part_id: label id in the segmentation model's label map.
        inverse: if False, select pixels predicted as ``body_part_id``.
            If True, select every pixel whose predicted label is neither
            ``body_part_id`` nor 0 (label 0 is excluded in both modes'
            "other" set, matching the original implementation).

    Returns:
        A mode "L" image the same size as *img*: 255 where selected, 0 elsewhere.
    """
    # Send inputs to whichever device init() placed the model on.
    inputs = extractor(images=img, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only; don't build an autograd graph
        logits = model(**inputs).logits.cpu()
    # Resize logits to the input resolution. PIL's size is (width, height)
    # while interpolate expects (height, width), hence the reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
    if inverse:
        selected = (pred_seg != body_part_id) & (pred_seg != 0)
    else:
        selected = pred_seg == body_part_id
    # Build an explicit 0/255 mask. Multiplying raw label ids by 255 in uint8
    # (the previous approach) wrapped around (e.g. label 11 -> 245) and made
    # a request for label 0 always come out empty.
    return Image.fromarray(np.where(selected, 255, 0).astype(np.uint8))
def get_cropped(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    dilation_iterations: int = 1,
) -> Image.Image:
    """Return an RGBA copy of *img* whose alpha keeps only the selected region.

    Args:
        img: input PIL image (any mode; converted to RGB).
        body_part_id: label id passed through to get_mask().
        inverse: passed through to get_mask().
        dilation_iterations: how many binary-dilation passes grow the mask
            before it becomes the alpha channel (default 1, as before).
            Values < 1 make scipy dilate until the mask stops changing.

    Returns:
        An RGBA image: RGB from *img*, alpha 255 inside the dilated mask and
        0 elsewhere. *img* itself is not modified.
    """
    pil_seg = get_mask(img, body_part_id, inverse)
    selected = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(selected, iterations=dilation_iterations)
    alpha = Image.fromarray((dilated * 255).astype(np.uint8))  # mode "L"
    # convert() always returns a new image, so putalpha cannot touch *img*.
    cropped = img.convert("RGB")
    cropped.putalpha(alpha)
    return cropped
def get_blurred_mask(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    dilation_iterations: int = 10,
    blur_radius: float = 4,
) -> Image.Image:
    """Return a soft-edged mask for *body_part_id*.

    The binary mask from get_mask() is dilated, then Gaussian-blurred so its
    border fades from 255 to 0 instead of stepping.

    Args:
        img: input PIL image.
        body_part_id: label id passed through to get_mask().
        inverse: passed through to get_mask() (default False, as before).
        dilation_iterations: binary-dilation passes applied before blurring
            (default 10, as before). Values < 1 make scipy dilate until the
            mask stops changing.
        blur_radius: GaussianBlur radius in pixels (default 4, as before).

    Returns:
        A mode "L" image the same size as *img*.
    """
    pil_seg = get_mask(img, body_part_id, inverse)
    selected = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(selected, iterations=dilation_iterations)
    hard_mask = Image.fromarray((dilated * 255).astype(np.uint8))
    return hard_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
# Cached Haar cascade; loaded lazily by _get_face_cascade().
_face_cascade = None


def _get_face_cascade():
    """Return OpenCV's bundled frontal-face Haar cascade, loading it only once."""
    global _face_cascade
    if _face_cascade is None:
        _face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )
    return _face_cascade


def get_cropped_face(pil_image: Image.Image) -> Image.Image:
    """Isolate the face region of *pil_image* on a transparent canvas.

    Segments label 11 via get_cropped() (presumably the "face" class of this
    checkpoint; confirm against model.config.id2label). It then runs a Haar
    face detector and pastes only the detected box of the masked RGBA image
    onto a fully transparent canvas the size of *pil_image*.

    Returns:
        The RGBA canvas with the face box, or *pil_image* unchanged when no
        face is detected.
    """
    face = get_cropped(pil_image, 11, False)
    image = np.array(face)  # RGBA array, since get_cropped returns RGBA
    gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    faces = _get_face_cascade().detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )
    if len(faces) == 0:
        return pil_image
    # Detection order is not documented as meaningful; prefer the largest box
    # (most likely the real face rather than a small false positive).
    x, y, w, h = (int(v) for v in max(faces, key=lambda box: box[2] * box[3]))
    face_pil = Image.fromarray(image[y:y + h, x:x + w])
    if face_pil.size != (w, h):
        face_pil = face_pil.resize((w, h))
    result = Image.new("RGBA", pil_image.size, (255, 255, 255, 0))
    result.paste(face_pil, (x, y))
    return result