Spaces:
Configuration error
Configuration error
File size: 2,715 Bytes
d3fbdbe 3fdb1c1 d3fbdbe 52c8d3f d3fbdbe 12316d5 d3fbdbe 457d44c d3fbdbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
from transformers import (
    AutoFeatureExtractor,
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# Model/processor handles shared by this module. They stay None until
# init_body() / init_face() assign them; get_mask() reads them.
model_body = extractor_body = None
model_face = extractor_face = None
def init_body():
    """Load the clothes/body SegFormer processor and model into module globals.

    Assigns ``extractor_body`` and ``model_body``. The model is moved to
    ``"cuda"``, so this fails on a machine without a CUDA device.
    """
    global model_body, extractor_body
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    # Use the image-processor class directly: the vision "feature extractor"
    # classes that AutoFeatureExtractor resolves to are deprecated in
    # transformers. This is also the processor class used for the face model.
    extractor_body = SegformerImageProcessor.from_pretrained(checkpoint)
    model_body = SegformerForSemanticSegmentation.from_pretrained(checkpoint).to("cuda")
def init_face():
    """Populate ``extractor_face`` / ``model_face`` from the face-parsing checkpoint.

    The model is placed on ``"cuda"``; a CUDA device must be available.
    """
    global model_face, extractor_face
    repo_id = "jonathandinu/face-parsing"
    extractor_face = SegformerImageProcessor.from_pretrained(repo_id)
    face_net = AutoModelForSemanticSegmentation.from_pretrained(repo_id)
    model_face = face_net.to("cuda")
def get_mask(img: Image.Image, body_part_id: int, inverse=False, face=False):
    """Segment *img* and return a 0/255 mask for one predicted label.

    Args:
        img: PIL image to segment.
        body_part_id: Label index (argmax class) to select.
        inverse: If True, select every pixel whose predicted label is neither
            ``body_part_id`` nor 0, instead of the pixels labelled
            ``body_part_id``.
        face: If True use ``extractor_face``/``model_face`` (set by
            ``init_face``); otherwise ``extractor_body``/``model_body``
            (set by ``init_body``).

    Returns:
        A mode "L" PIL image the same size as *img*: 255 on selected pixels,
        0 elsewhere.
    """
    if face:
        extractor, model = extractor_face, model_face
    else:
        extractor, model = extractor_body, model_body

    # Inference only: without no_grad, autograd keeps intermediate
    # activations alive on the device for every call.
    with torch.no_grad():
        # Send inputs to wherever the model lives rather than assuming "cuda".
        inputs = extractor(images=img, return_tensors="pt").to(model.device)
        logits = model(**inputs).logits.cpu()

    # PIL sizes are (width, height); interpolate expects (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    if inverse:
        # Label 0 is left out as well, as in the previous zero-out logic
        # (pixels already labelled 0 stayed 0).
        selected = (pred_seg != body_part_id) & (pred_seg != 0)
    else:
        selected = pred_seg == body_part_id

    # Build the mask from booleans: multiplying raw label ids by 255 in uint8
    # wraps around (e.g. label 4 -> 252), so the mask was not truly binary.
    arr_seg = selected.numpy().astype(np.uint8) * 255
    return Image.fromarray(arr_seg)
def get_cropped(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    face: bool = False,
    dilation_iterations: int = 1,
):
    """Return an RGBA copy of *img* whose alpha is the (dilated) segment mask.

    Args:
        img: Source PIL image; it is not modified.
        body_part_id, inverse, face: Forwarded to ``get_mask``.
        dilation_iterations: ``iterations`` argument for
            ``scipy.ndimage.binary_dilation``; each iteration grows the
            selected region by one pixel (4-connected structuring element).
            Values below 1 make scipy dilate until the mask stops changing.

    Returns:
        A new RGBA image: fully opaque inside the dilated mask, fully
        transparent elsewhere.
    """
    pil_seg = get_mask(img, body_part_id, inverse, face)
    selected = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(selected, iterations=dilation_iterations)
    # A 2-D uint8 array becomes a mode "L" image, as putalpha requires.
    mask = Image.fromarray(dilated.astype(np.uint8) * 255)
    # convert() always returns a new image, so img itself is left untouched.
    cropped = img.convert("RGB")
    cropped.putalpha(mask)
    return cropped
def get_blurred_mask(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    face: bool = False,
    dilation_iterations: int = 25,
    blur_radius: float = 4,
):
    """Return a heavily dilated, Gaussian-blurred mask for one segment.

    Args:
        img: PIL image to segment.
        body_part_id, inverse, face: Forwarded to ``get_mask``. Previously
            this function always used the body model with inverse=False;
            those remain the defaults.
        dilation_iterations: ``iterations`` for
            ``scipy.ndimage.binary_dilation`` (default 25).
        blur_radius: Radius for ``ImageFilter.GaussianBlur`` (default 4).

    Returns:
        A mode "L" PIL image: 255 deep inside the grown region, 0 far
        outside it, with a soft ramp at the boundary.
    """
    pil_seg = get_mask(img, body_part_id, inverse, face)
    selected = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(selected, iterations=dilation_iterations)
    mask = Image.fromarray(dilated.astype(np.uint8) * 255)
    # Feather the hard dilated edge.
    return mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
|