# OutfitChanger / segmentation.py
# Last change by altayavci: "Update segmentation.py" (commit 457d44c, 2.72 kB)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
from transformers import (
    AutoFeatureExtractor,
    AutoModelForSemanticSegmentation,
    SegformerForSemanticSegmentation,
    SegformerImageProcessor,
)
# Clothes segmentation processor/model; None until init_body() populates them.
model_body = extractor_body = None
# Face-parsing processor/model; None until init_face() populates them.
model_face = extractor_face = None
def init_body(device="cuda"):
    """Load the clothes-segmentation processor and model into module globals.

    Populates ``extractor_body`` / ``model_body`` from the
    ``mattmdjaga/segformer_b2_clothes`` checkpoint; ``get_mask`` reads them
    when called with ``face=False``.

    Args:
        device: torch device the model is moved to. Defaults to ``"cuda"``,
            the previously hard-coded value.
    """
    global model_body, extractor_body
    # SegformerImageProcessor is the current class for Segformer
    # preprocessing; AutoFeatureExtractor resolves this checkpoint to the
    # deprecated SegformerFeatureExtractor. It also matches init_face.
    extractor_body = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    model_body = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes").to(device)
def init_face(device="cuda"):
    """Load the face-parsing processor and model into module globals.

    Populates ``extractor_face`` / ``model_face`` from the
    ``jonathandinu/face-parsing`` checkpoint; ``get_mask`` reads them when
    called with ``face=True``.

    Args:
        device: torch device the model is moved to. Defaults to ``"cuda"``,
            the previously hard-coded value.
    """
    global model_face, extractor_face
    extractor_face = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model_face = AutoModelForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing").to(device)
def get_mask(img: Image.Image, body_part_id: int, inverse=False, face=False):
    """Segment *img* and return a binary mask for one label.

    Args:
        img: PIL image to segment.
        body_part_id: label id to select, in the label space of the chosen
            model (face-parsing when *face* is true, clothes otherwise).
        inverse: when true, select every labelled pixel *except*
            *body_part_id*.
        face: use the face-parsing model (loaded by ``init_face``) instead
            of the clothes model (loaded by ``init_body``).

    Returns:
        A mode ``"L"`` PIL image the same size as *img*: 255 where a pixel is
        selected, 0 elsewhere. Pixels predicted as label 0 are never selected
        in either mode (presumably the background class of both models —
        TODO confirm against each model's ``id2label``).

    Raises:
        RuntimeError: if the required model has not been loaded yet.
    """
    if face:
        extractor, model, loader = extractor_face, model_face, "init_face"
    else:
        extractor, model, loader = extractor_body, model_body, "init_body"
    if extractor is None or model is None:
        # Otherwise the call below fails with an opaque
        # "'NoneType' object is not callable".
        raise RuntimeError(f"segmentation model not loaded; call {loader}() first")

    # Send the inputs to wherever the model lives instead of assuming CUDA.
    inputs = extractor(images=img, return_tensors="pt").to(model.device)
    # Inference only: skip autograd bookkeeping so activations are not kept.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits.cpu()
    # Resize the logits to the input image (PIL size is (W, H); interpolate
    # expects (H, W)).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    selected = (pred_seg != body_part_id) if inverse else (pred_seg == body_part_id)
    selected &= pred_seg != 0
    # Build 0/255 from a boolean mask: multiplying raw label ids by 255 in
    # uint8 wraps modulo 256 (e.g. label 4 -> 252) instead of giving 255.
    arr_seg = selected.numpy().astype(np.uint8) * 255
    return Image.fromarray(arr_seg)
def get_cropped(img: Image.Image, body_part_id: int, inverse: bool = False, face: bool = False):
    """Return an RGBA copy of *img* whose alpha keeps only the segmented region.

    The mask from ``get_mask`` is binarised (> 128), dilated once, and used
    as the alpha channel of an RGB copy of *img*.

    Args:
        img: source PIL image; it is not modified.
        body_part_id: label id forwarded to ``get_mask``.
        inverse: forwarded to ``get_mask`` (default ``False``, as there).
        face: forwarded to ``get_mask`` (default ``False``, as there).

    Returns:
        A new ``RGBA`` PIL image the same size as *img*.
    """
    pil_seg = get_mask(img, body_part_id, inverse, face)
    mask_binary = np.asarray(pil_seg.convert("L")) > 128
    # One dilation step with scipy's default (cross-shaped) structuring
    # element grows the region by one pixel — presumably so the crop does
    # not clip the region's edge.
    dilated = binary_dilation(mask_binary, iterations=1)
    alpha = Image.fromarray(dilated.astype(np.uint8) * 255)
    # convert() always returns a new image, so no extra copy is needed to
    # keep the caller's img untouched.
    cropped = img.convert("RGB")
    cropped.putalpha(alpha)
    return cropped
def get_blurred_mask(img: Image, body_part_id: int):
    """Return a feathered mask covering *body_part_id* plus a wide margin.

    Takes the non-inverted clothes-model mask from ``get_mask``, thresholds
    it at > 128, and dilates it 25 times with scipy's default structuring
    element. It then softens the edges with a radius-4 Gaussian blur.

    Returns:
        A mode ``"L"`` PIL image with soft edges.
    """
    seg_bits = np.array(get_mask(img, body_part_id).convert('L')) > 128
    grown = binary_dilation(seg_bits, iterations=25)
    grown_img = Image.fromarray((grown * 255).astype(np.uint8))
    return grown_img.filter(ImageFilter.GaussianBlur(radius=4))