# OutfitChanger / segmentation.py
# altayavci's picture
# Upload 17 files
# d3fbdbe
# raw
# history blame
# 2.75 kB
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation, SegformerImageProcessor, AutoModelForSemanticSegmentation
import numpy as np
import torch.nn as nn
from scipy.ndimage import binary_dilation
# Module-level handles for the two segmentation pipelines. They stay None
# until init_body() / init_face() assign the loaded model and its
# preprocessing object.
model_body = extractor_body = None
model_face = extractor_face = None
def init_body():
    """Load the body/clothes segmentation checkpoint into the module globals.

    Assigns ``extractor_body`` (the preprocessing object) and ``model_body``
    (the SegFormer model, moved to the ``"cuda"`` device). Both are loaded
    from the ``mattmdjaga/segformer_b2_clothes`` checkpoint.
    """
    global model_body, extractor_body
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    extractor_body = AutoFeatureExtractor.from_pretrained(checkpoint)
    segmenter = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    model_body = segmenter.to("cuda")
def init_face():
    """Load the face-parsing checkpoint into the module globals.

    Assigns ``extractor_face`` (the image processor) and ``model_face`` (the
    segmentation model, moved to the ``"cuda"`` device). Both are loaded from
    the ``jonathandinu/face-parsing`` checkpoint.
    """
    global model_face, extractor_face
    checkpoint = "jonathandinu/face-parsing"
    extractor_face = SegformerImageProcessor.from_pretrained(checkpoint)
    # Place the model on the GPU like init_body does, so it shares a device
    # with the inputs get_mask feeds it. Leaving it on the CPU makes the
    # forward pass fail with a device-mismatch error.
    model_face = AutoModelForSemanticSegmentation.from_pretrained(checkpoint).to("cuda")
def get_mask(img: Image.Image, body_part_id: int, inverse=False, face=False):
    """Segment *img* and return a black/white mask for one class.

    Args:
        img: PIL image to segment.
        body_part_id: Class index to select, as defined by the label map of
            the model being used.
        inverse: When false, the mask is white where the predicted class
            equals ``body_part_id``. When true, the mask is white where the
            predicted class is neither ``body_part_id`` nor class 0, so
            class-0 pixels are never part of an inverse mask.
        face: Use the face-parsing model (``model_face``/``extractor_face``)
            instead of the body model (``model_body``/``extractor_body``).

    Returns:
        A single-channel 8-bit PIL image with the same width and height as
        *img*, containing only the values 0 and 255.

    Raises:
        RuntimeError: If the selected model has not been loaded yet.
    """
    # Function-scope import: the file only imports ``torch.nn`` at top level.
    import torch

    if face:
        model, extractor, loader = model_face, extractor_face, "init_face"
    else:
        model, extractor, loader = model_body, extractor_body, "init_body"
    if model is None or extractor is None:
        raise RuntimeError(
            f"segmentation model is not loaded; call {loader}() before get_mask()"
        )

    # Send the inputs to whichever device the model lives on instead of
    # assuming a device, so model and inputs can never disagree.
    inputs = extractor(images=img, return_tensors="pt").to(model.device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # PIL reports size as (width, height); interpolate wants (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]

    # Build the mask as an explicit boolean selection. Scaling the raw class
    # ids by 255 in uint8 wraps around (e.g. id 4 -> 252) and yields an empty
    # mask for class 0.
    if inverse:
        selected = (pred_seg != body_part_id) & (pred_seg != 0)
    else:
        selected = pred_seg == body_part_id

    arr_seg = selected.numpy().astype(np.uint8) * 255
    return Image.fromarray(arr_seg)
def get_cropped(img: Image.Image, body_part_id: int, inverse=False, *,
                face=False, dilation_iterations=1):
    """Return *img* as RGBA with the segmentation mask as its alpha channel.

    The mask comes from ``get_mask``; it is thresholded at ``> 128``, dilated,
    and then installed as the alpha band of an RGB copy of *img*.

    Args:
        img: PIL image to cut from. The original author's note says this is
            expected to be the OpenPose-generated image.
        body_part_id: Class index forwarded to ``get_mask``.
        inverse: Forwarded to ``get_mask``.
        face: Forwarded to ``get_mask`` (keyword-only; defaults to the body
            model, which is what this function always used before).
        dilation_iterations: ``iterations`` argument passed to
            ``scipy.ndimage.binary_dilation`` (keyword-only; default 1).

    Returns:
        A new RGBA PIL image; *img* itself is not modified.
    """
    pil_seg = get_mask(img, body_part_id, inverse=inverse, face=face)
    mask_binary = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(mask_binary, iterations=dilation_iterations)
    # A 2-D uint8 array already becomes a single-channel ("L") image, so no
    # further array/image round-trip is needed before using it as alpha.
    alpha = Image.fromarray((dilated * 255).astype(np.uint8))

    # convert() returns a new image, so putalpha() below cannot touch *img*.
    cropped = img.convert("RGB")
    cropped.putalpha(alpha)
    return cropped
def get_blurred_mask(img: Image.Image, body_part_id: int, inverse=False, *,
                     face=False, dilation_iterations=25, blur_radius=4):
    """Return a dilated, Gaussian-blurred segmentation mask for *img*.

    The mask comes from ``get_mask``; it is thresholded at ``> 128``, dilated,
    converted back to an 8-bit image, and blurred.

    Args:
        img: PIL image to segment.
        body_part_id: Class index forwarded to ``get_mask``.
        inverse: Forwarded to ``get_mask``.
        face: Forwarded to ``get_mask`` (keyword-only; defaults to the body
            model, which is what this function always used before).
        dilation_iterations: ``iterations`` argument passed to
            ``scipy.ndimage.binary_dilation`` (keyword-only; default 25).
        blur_radius: Radius given to ``ImageFilter.GaussianBlur``
            (keyword-only; default 4).

    Returns:
        A single-channel PIL image holding the softened mask.
    """
    pil_seg = get_mask(img, body_part_id, inverse=inverse, face=face)
    mask_binary = np.array(pil_seg.convert("L")) > 128
    dilated = binary_dilation(mask_binary, iterations=dilation_iterations)
    dilated_img = Image.fromarray((dilated * 255).astype(np.uint8))
    return dilated_img.filter(ImageFilter.GaussianBlur(radius=blur_radius))