File size: 2,739 Bytes
d3fbdbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fdb1c1
d3fbdbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52c8d3f
d3fbdbe
12316d5
d3fbdbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation, SegformerImageProcessor, AutoModelForSemanticSegmentation

# Body/clothes segmentation model and its preprocessor; filled in by init_body().
model_body = extractor_body = None

# Face-parsing model and its preprocessor; filled in by init_face().
model_face = extractor_face = None

def init_body():
    """Load the body/clothes SegFormer checkpoint and move the model to CUDA.

    Sets the module-level ``extractor_body`` and ``model_body`` that
    ``get_mask`` uses when called with ``face=False``.
    """
    global model_body, extractor_body
    repo_id = "mattmdjaga/segformer_b2_clothes"
    extractor_body = AutoFeatureExtractor.from_pretrained(repo_id)
    network = SegformerForSemanticSegmentation.from_pretrained(repo_id)
    model_body = network.to("cuda")

def init_face():
    """Load the face-parsing segmentation checkpoint and move the model to CUDA.

    Sets the module-level ``extractor_face`` and ``model_face`` that
    ``get_mask`` uses when called with ``face=True``.
    """
    global model_face, extractor_face
    repo_id = "jonathandinu/face-parsing"
    extractor_face = SegformerImageProcessor.from_pretrained(repo_id)
    network = AutoModelForSemanticSegmentation.from_pretrained(repo_id)
    model_face = network.to("cuda")


def get_mask(img: Image.Image, body_part_id: int, inverse=False, face=False) -> Image.Image:
    """Segment *img* and return a binary mask for one segmentation label.

    Args:
        img: PIL image to segment.
        body_part_id: label index in the selected model's output classes.
        inverse: if True, select every pixel whose label is NOT
            ``body_part_id`` instead. Pixels labelled 0 are never included
            in the inverse mask (this preserves the module's historical
            behaviour; label 0 is presumably background in both
            checkpoints -- confirm against the model cards).
        face: use the face-parsing model (``init_face``) instead of the
            body/clothes model (``init_body``).

    Returns:
        An ``"L"``-mode PIL image with the same size as *img*. Selected
        pixels are 255 and all others are 0.

    Raises:
        RuntimeError: if the requested model has not been loaded yet.
    """
    if face:
        extractor, model, loader = extractor_face, model_face, "init_face()"
    else:
        extractor, model, loader = extractor_body, model_body, "init_body()"
    if extractor is None or model is None:
        raise RuntimeError(f"Segmentation model is not loaded; call {loader} first.")

    # Send inputs to wherever the model lives rather than assuming "cuda".
    inputs = extractor(images=img, return_tensors="pt").to(model.device)
    # Inference only: skip autograd bookkeeping to save (GPU) memory.
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()

    # Resize the logits to the image; PIL's size is (width, height) while
    # interpolate expects (height, width), hence the reversal.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    labels = upsampled_logits.argmax(dim=1)[0]

    # Build an explicit boolean selection instead of zeroing labels in place,
    # so "on" pixels are exactly 255 regardless of the label value and
    # label 0 can be requested directly.
    if inverse:
        selected = (labels != body_part_id) & (labels != 0)
    else:
        selected = labels == body_part_id

    mask = selected.numpy().astype(np.uint8) * 255
    return Image.fromarray(mask)


def get_cropped(img: Image.Image, body_part_id: int, inverse: bool = False, face: bool = False) -> Image.Image:
    """Cut one segmented region out of *img* as an RGBA image.

    The region comes from ``get_mask`` and is dilated once with scipy's
    default structuring element, so the cut-out keeps a thin border around
    the segment edge.

    Args:
        img: source PIL image.
        body_part_id: segmentation label to keep.
        inverse: forwarded to ``get_mask``; select everything except the label.
        face: forwarded to ``get_mask``; use the face-parsing model.

    Returns:
        A new RGBA image holding the RGB pixels of *img*. Its alpha is 255
        inside the dilated region and 0 outside. *img* itself is not modified.
    """
    seg = np.array(get_mask(img, body_part_id, inverse, face).convert("L"))
    # Binarize at mid-grey so any non-binary values in the mask are tolerated.
    region = binary_dilation(seg > 128, iterations=1)
    alpha = Image.fromarray((region * 255).astype(np.uint8))

    # convert() returns a new image, so putalpha() never touches the caller's img.
    cutout = img.convert("RGB")
    cutout.putalpha(alpha)
    return cutout


def get_blurred_mask(
    img: Image.Image,
    body_part_id: int,
    inverse=False,
    face=False,
    *,
    dilation_iterations: int = 25,
    blur_radius: float = 4,
) -> Image.Image:
    """Return a grown and feathered mask for one segmented region of *img*.

    The binary mask from ``get_mask`` is dilated ``dilation_iterations``
    times and then Gaussian-blurred with ``blur_radius``, giving a soft-edged
    ``"L"``-mode mask.

    Args:
        img: source PIL image.
        body_part_id: segmentation label to select.
        inverse: forwarded to ``get_mask``; select everything except the label.
        face: forwarded to ``get_mask``; use the face-parsing model.
        dilation_iterations: number of binary-dilation passes (default 25).
            Values <= 0 skip dilation entirely.
        blur_radius: Gaussian blur radius in pixels (default 4).

    Returns:
        A grayscale PIL image the same size as *img*.
    """
    seg = np.array(get_mask(img, body_part_id, inverse, face).convert("L"))
    region = seg > 128
    # scipy repeats dilation until convergence when iterations < 1, which would
    # grow any non-empty mask over the whole image; treat that as "no dilation".
    if dilation_iterations > 0:
        region = binary_dilation(region, iterations=dilation_iterations)
    hard_mask = Image.fromarray((region * 255).astype(np.uint8))
    return hard_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))