File size: 2,935 Bytes
d3fbdbe
 
 
 
 
 
3500945
 
d3fbdbe
9f6cca4
 
d3fbdbe
2d7bec9
9f6cca4
 
 
d3fbdbe
 
3500945
 
 
d3fbdbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3500945
d3fbdbe
3500945
d3fbdbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457d44c
 
d3fbdbe
 
 
 
bdf8261
d3fbdbe
 
 
 
 
 
037f136
 
 
3500945
037f136
3500945
037f136
 
 
3500945
037f136
 
 
 
 
 
 
 
 
 
3500945
 
d3fbdbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation, SegformerImageProcessor, AutoModelForSemanticSegmentation

import numpy as np
import torch.nn as nn
from scipy.ndimage import binary_dilation
import cv2


# Lazily-initialised segmentation model and its preprocessor; both are
# assigned by init() and stay None until it has been called.
model, extractor = None, None

def init(
    model_name: str = "mattmdjaga/segformer_b2_clothes",
    device: str = "cuda",
) -> None:
    """Load the SegFormer segmentation model and its preprocessor.

    Assigns the module-level ``model`` and ``extractor`` globals, which the
    masking functions in this module read. Call once before using them.

    Args:
        model_name: Hugging Face hub id (or local path) of the checkpoint.
            Both the preprocessor and the model are loaded from it.
        device: torch device the model is moved to. Inputs fed to the model
            must live on the same device.
    """
    global model, extractor
    extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = SegformerForSemanticSegmentation.from_pretrained(model_name).to(device)
    # This module only runs inference; make evaluation mode explicit.
    model.eval()


def get_mask(img: Image.Image, body_part_id: int, inverse: bool = False) -> Image.Image:
    """Return a binary mask of the pixels the segmenter assigns to a class.

    Args:
        img: Input PIL image.
        body_part_id: Class index predicted by the segmentation model.
        inverse: If True, select every pixel whose predicted class is neither
            ``body_part_id`` nor 0 (class 0 is never selected in this mode;
            presumably it is the background label -- confirm against the
            checkpoint's label map).

    Returns:
        A mode "L" image the same size as ``img``: 255 where selected,
        0 elsewhere.

    Raises:
        RuntimeError: if ``init()`` has not been called yet.
    """
    import torch  # only torch.nn is imported at module level

    if model is None or extractor is None:
        raise RuntimeError("get_mask() called before init()")

    # Send inputs to wherever init() placed the model instead of assuming CUDA.
    inputs = extractor(images=img, return_tensors="pt").to(model.device)
    # Pure inference: skip autograd bookkeeping to save GPU memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # The model predicts at reduced resolution; upsample to (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    if inverse:
        selected = (pred_seg != body_part_id) & (pred_seg != 0)
    else:
        selected = pred_seg == body_part_id

    # Build 0/255 explicitly. Multiplying class ids by 255 in uint8 overflows
    # (id * 255 wraps to 256 - id).
    return Image.fromarray(np.where(selected, 255, 0).astype(np.uint8))


def get_cropped(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    dilation_iterations: int = 1,
) -> Image.Image:
    """Return an RGBA copy of ``img`` whose alpha keeps only one segmented region.

    Args:
        img: Input PIL image; it is not modified.
        body_part_id: Class index forwarded to ``get_mask``.
        inverse: Forwarded to ``get_mask`` to select the complementary region.
        dilation_iterations: How many times the binary mask is grown by one
            pixel before use. Values <= 0 disable dilation.

    Returns:
        An RGBA image with the original RGB data and the (dilated) mask as
        its alpha channel.
    """
    seg = get_mask(img, body_part_id, inverse)
    selected = np.asarray(seg.convert("L")) > 128

    # scipy treats iterations < 1 as "repeat until nothing changes", which
    # would flood the mask, so only dilate for a positive count.
    if dilation_iterations > 0:
        selected = binary_dilation(selected, iterations=dilation_iterations)
    alpha = Image.fromarray((selected * 255).astype(np.uint8))

    # convert() returns a new image, so the caller's img is left untouched.
    cropped = img.convert("RGB")
    cropped.putalpha(alpha)
    return cropped


def get_blurred_mask(
    img: Image.Image,
    body_part_id: int,
    inverse: bool = False,
    dilation_iterations: int = 10,
    blur_radius: float = 4,
) -> Image.Image:
    """Return a dilated, Gaussian-blurred (soft-edged) mask of one region.

    Args:
        img: Input PIL image.
        body_part_id: Class index forwarded to ``get_mask``.
        inverse: Forwarded to ``get_mask`` to select the complementary region.
        dilation_iterations: How many times the binary mask is grown by one
            pixel before blurring. Values <= 0 disable dilation.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur``.

    Returns:
        A mode "L" image the same size as ``img`` with feathered edges.
    """
    seg = get_mask(img, body_part_id, inverse)
    selected = np.asarray(seg.convert("L")) > 128

    # scipy treats iterations < 1 as "repeat until nothing changes".
    if dilation_iterations > 0:
        selected = binary_dilation(selected, iterations=dilation_iterations)

    hard_mask = Image.fromarray((selected * 255).astype(np.uint8))
    return hard_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))


# Cached OpenCV face detector; loaded on first use by _get_face_cascade().
_face_cascade = None


def _get_face_cascade():
    """Return OpenCV's frontal-face Haar cascade, loading it from disk once.

    Raises:
        RuntimeError: if the cascade XML could not be loaded.
    """
    global _face_cascade
    if _face_cascade is None:
        cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )
        if cascade.empty():
            raise RuntimeError("failed to load haarcascade_frontalface_default.xml")
        _face_cascade = cascade
    return _face_cascade


def get_cropped_face(pil_image: Image.Image) -> Image.Image:
    """Return the detected face region on a transparent canvas.

    The image is first masked with segmentation class 11 (presumably "Face"
    in this checkpoint's label map -- confirm). A Haar cascade then locates
    the face box, and the RGBA pixels inside the largest detected box are
    pasted at the same position onto a fully transparent canvas of the
    original size.

    Args:
        pil_image: Input PIL image.

    Returns:
        An RGBA image of ``pil_image.size``. If no face is detected, the
        unmodified ``pil_image`` is returned instead (in its original mode).
    """
    face_rgba = np.array(get_cropped(pil_image, 11, False))
    # NOTE(review): get_cropped only changes the alpha channel, and RGBA->GRAY
    # ignores alpha, so detection sees the full image -- confirm intended.
    gray = cv2.cvtColor(face_rgba, cv2.COLOR_RGBA2GRAY)
    faces = _get_face_cascade().detectMultiScale(
        gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
    )

    if len(faces) == 0:
        return pil_image

    # Detection order is not meaningful; keep the largest box. Cast NumPy ints
    # to plain ints for PIL.
    x, y, w, h = (int(v) for v in max(faces, key=lambda box: box[2] * box[3]))

    face_pil = Image.fromarray(face_rgba[y:y + h, x:x + w])
    if face_pil.size != (w, h):
        face_pil = face_pil.resize((w, h))

    result = Image.new("RGBA", pil_image.size, (255, 255, 255, 0))
    result.paste(face_pil, (x, y))
    return result