|
|
|
|
|
import sys |
|
sys.path.append('..') |
|
import os |
|
os.system(f'pip install dlib') |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
import models_mae |
|
from torch.nn import functional as F |
|
import dlib |
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
# Instantiate an MAE ViT-B/16 (weights not loaded here); it is used only for
# its patchify()/unpatchify() helpers and patch_embed.patch_size metadata in
# the visualisations below.
model = getattr(models_mae, 'mae_vit_base_patch16')()
|
|
|
|
|
class ITEM:
    """Mutable holder pairing a face image with its parsing map.

    Used as a module-level cache so the masking callbacks can reuse the
    parsing result produced by the face-parsing callback.
    """

    def __init__(self, img, parsing_map):
        # image: the cropped/resized face; parsing_map: per-pixel label map
        self.image = img
        self.parsing_map = parsing_map
|
|
|
# Module-level cache of the most recently parsed face, filled in by
# do_face_parsing() and read by the masking callbacks.
face_to_show = ITEM(None, None)

# Facial-region name -> parsing-map label ids (labelling of the
# 'farl/lapa/448' parser used below: 0 background, 1 skin, 2/3 eyebrows,
# 4/5 eyes, 6 nose, 7-9 mouth, 10 hair — presumably; confirm against facer).
check_region = {'Eyebrows': [2, 3],
                'Eyes': [4, 5],
                'Nose': [6],
                'Mouth': [7, 8, 9],
                'Face Boundaries': [10, 1, 0],
                'Hair': [10],
                'Skin': [1],
                'Background': [0]}
|
|
|
|
|
def get_boundingbox(face, width, height, minsize=None, scale=1.3):
    """
    From FF++:
    https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
    Expects a dlib face to generate a quadratic bounding box.

    :param face: dlib face class (anything exposing left/top/right/bottom)
    :param width: frame width
    :param height: frame height
    :param minsize: set minimum bounding box size (optional)
    :param scale: bounding box size multiplier to get a bigger face region
        (previously hard-coded to 1.3; now a parameter, default unchanged)
    :return: x, y, bounding_box_size in opencv form
    """
    x1 = face.left()
    y1 = face.top()
    x2 = face.right()
    y2 = face.bottom()
    # quadratic box: side = scale * the larger edge of the detection
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # re-anchor the top-left corner on the detection centre, clamped to frame
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)

    # shrink the box if it would run past the right/bottom frame edge
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)

    return x1, y1, size_bb
|
|
|
|
|
def extract_face(frame):
    """Detect the first face in a PIL image and return the cropped face.

    :param frame: PIL image (any mode; converted to RGB for detection)
    :return: PIL image of the (enlarged, square) face crop, or None when the
        dlib frontal detector finds no face
    """
    detector = dlib.get_frontal_face_detector()
    rgb = np.array(frame.convert('RGB'))
    detections = detector(rgb, 1)
    if not detections:
        return None

    # take the first detection and expand it to a square crop
    x, y, size = get_boundingbox(detections[0], rgb.shape[1], rgb.shape[0])
    return Image.fromarray(rgb[y:y + size, x:x + size])
|
|
|
|
|
from torchvision.transforms import transforms |
|
def show_one_img_patchify(img, model):
    """Visualise the ViT patch grid of `img`: split it into the model's
    patches and re-assemble them with small white gaps in between.

    :param img: array-like image — assumed (224, 224, 3) uint8 (e.g. the
        parsing-map visualisation); TODO confirm callers never pass floats
    :param model: MAE model providing patchify() and patch_embed.patch_size
    :return: a 224x224 PIL image of the padded patch mosaic
    """
    x = torch.tensor(img)
    x = x.unsqueeze(dim=0)  # -> [1, H, W, C]
    x = torch.einsum('nhwc->nchw', x)  # channels-first for patchify
    x_patches = model.patchify(x)  # -> [1, L, patch_size**2 * C]

    n = int(np.sqrt(x_patches.shape[1]))  # patches per side
    image_size = int(224/n)  # pixel size of one patch tile in the mosaic
    padding = 3  # white gap (px) between tiles
    new_img = Image.new('RGB', (n * image_size + padding*(n-1), n * image_size + padding*(n-1)), 'white')
    for i, patch in enumerate(x_patches[0]):
        ax = i % n  # column index
        ay = int(i / n)  # row index
        # flat patch vector -> (p, p, 3) image tile
        patch_img_tensor = torch.reshape(patch, (model.patch_embed.patch_size[0], model.patch_embed.patch_size[1], 3))
        patch_img_tensor = torch.einsum('hwc->chw', patch_img_tensor)
        patch_img = transforms.ToPILImage()(patch_img_tensor)
        new_img.paste(patch_img, (ax * image_size + padding * ax, ay * image_size + padding * ay))

    # scale the mosaic back down to the display size
    new_img = new_img.resize((224, 224), Image.BICUBIC)
    return new_img
|
|
|
|
|
def show_one_img_parchify_mask(img, parsing_map, mask, model):
    """Render a patch mask three ways: on the parsing map (patchified), as a
    standalone grey-scale image, and applied to the face image.

    :param img: (224, 224, 3) uint8 array of the face image
    :param parsing_map: (224, 224, 3) uint8 array (parsing-map visualisation)
    :param mask: [1, num_patches] float per-patch values; 0 = keep,
        1 = masked (shown grey), 2 = covered facial-region patch (shown
        black, used by the CRFR strategies which pass mask + fr_mask)
    :param model: MAE model providing unpatchify() and patch metadata
    :return: [patchified masked parsing map, mask image, masked face image]
    """
    mask = mask.detach()
    # expand per-patch values to per-pixel: [1, L] -> [1, L, p*p*3] -> [1, H, W, C]
    mask_patches = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 * 3)
    mask = model.unpatchify(mask_patches)
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # grey-scale rendering of the mask itself:
    # 1 -> 127 (grey), 2 -> clipped to 0 (black), 0 (kept) -> 254 (near white)
    vis_mask = mask[0].clone()
    vis_mask[vis_mask == 1] = 1
    vis_mask[vis_mask == 2] = -1
    vis_mask[vis_mask == 0] = 2
    vis_mask = torch.clip(vis_mask * 127, 0, 255).int()
    fasking_mask = vis_mask.numpy().astype(np.uint8)
    fasking_mask = Image.fromarray(fasking_mask)

    # BUG FIX: copy before writing — the original aliased the caller's numpy
    # arrays (`img`, `parsing_map`) and mutated them in place.
    im_masked = img.copy()
    im_masked[mask[0] == 1] = 127
    im_masked[mask[0] == 2] = 0
    im_masked = Image.fromarray(im_masked)

    parsing_map_masked = parsing_map.copy()
    parsing_map_masked[mask[0] == 1] = 127
    parsing_map_masked[mask[0] == 2] = 0

    return [show_one_img_patchify(parsing_map_masked, model), fasking_mask, im_masked]
|
|
|
|
|
|
|
class CollateFn_Random:
    """Random-masking baseline: mark mask_ratio of the ViT patches, chosen
    uniformly at random and independently for every sample."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        # patches per side and total patch count of the ViT grid
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

    def __call__(self, image, parsing_map):
        # parsing_map is only consulted for the batch size here
        blank = torch.zeros(parsing_map.size(0), self.num_patches, dtype=torch.float32)
        return {'image': image, 'random_mask': self.masking(parsing_map, blank)}

    def masking(self, parsing_map, random_mask):
        """Set int(mask_ratio * num_patches) randomly chosen entries of each
        row of `random_mask` to 1 and return the mask."""
        budget = int(self.mask_ratio * self.num_patches)
        fill_value = 1 if budget >= 0 else 0  # budget is never negative
        for row in range(random_mask.size(0)):
            order = torch.randperm(self.num_patches)
            for j in range(budget):
                random_mask[row, order[j]] = fill_value
        return random_mask
|
|
|
|
|
def do_random_masking(image, parsing_map_vis, ratio):
    """Gradio callback: apply random masking at `ratio` to the cached face.

    Uses the parsing map stored in the module-level `face_to_show`.
    Returns (patchified masked parsing map, mask image, masked face image).
    """
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)

    collate = CollateFn_Random(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = collate(img, parsing_map)['random_mask']

    on_parsing, mask_img, on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
    return on_parsing, mask_img, on_image
|
|
|
|
|
|
|
class CollateFn_Fasking:
    """Fasking-I masking: greedily mask whole facial-component regions
    (brows/eyes, nose, mouth, hair) in priority order until the target ratio
    is reached, then randomly top up / trim to the exact overall ratio."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        # patches per side and total patch count of the ViT grid
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # parsing-label groups in masking priority order; [2, 4] / [3, 5]
        # pair a brow with an eye (presumably left/right — confirm against
        # the parser's labelling). The last two groups (skin, background)
        # are excluded from the greedy pass in fasking().
        self.facial_region_group = [
            [2, 4],
            [3, 5],
            [6],
            [7, 8, 9],
            [10],
            [1],
            [0]
        ]

    def __call__(self, image, parsing_map):
        # parsing_map: [N, img_size, img_size] integer label map
        fasking_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
        fasking_mask = self.fasking(parsing_map, fasking_mask)

        return {'image': image, 'fasking_mask': fasking_mask}

    def fasking(self, parsing_map, fasking_mask):
        """
        Greedy region pass, then per-sample random adjustment so that each
        row ends with exactly int(mask_ratio * num_patches) ones.
        :return: fasking_mask flattened to [N, num_patches]
        """
        for i in range(parsing_map.size(0)):
            terminate = False
            # skip the last two groups (skin, background): they are only ever
            # masked by the random top-up below
            for seg_group in self.facial_region_group[:-2]:
                if terminate:
                    break
                for comp_value in seg_group:
                    # a patch is marked if any of its pixels carries this
                    # label (max-pool over each patch-size window)
                    fasking_mask[i] = torch.maximum(
                        fasking_mask[i], F.max_pool2d((parsing_map[i].unsqueeze(0) == comp_value).float(), kernel_size=self.patch_size))
                    # mean() == fraction of patches masked; stop once the
                    # target ratio is reached or exceeded
                    if fasking_mask[i].mean() >= ((self.mask_ratio * self.num_patches) / self.num_patches):
                        terminate = True
                        break

        fasking_mask = fasking_mask.view(parsing_map.size(0), -1)
        for i in range(fasking_mask.size(0)):
            # >0: still short of the budget (add patches, value 1);
            # <0: overshoot (remove patches, value 0)
            num_mask_to_change = (self.mask_ratio * self.num_patches - fasking_mask[i].sum(dim=-1)).int()
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
            select_indices = (fasking_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
            fasking_mask[i, select_indices[change_indices]] = mask_change_to

        return fasking_mask
|
|
|
|
|
def do_fasking_masking(image, parsing_map_vis, ratio):
    """Gradio callback: apply Fasking-I masking at `ratio` to the cached face.

    Uses the parsing map stored in the module-level `face_to_show`.
    Returns (patchified masked parsing map, mask image, masked face image).
    """
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)

    collate = CollateFn_Fasking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = collate(img, parsing_map)['fasking_mask']

    on_parsing, mask_img, on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
    return on_parsing, mask_img, on_image
|
|
|
|
|
|
|
class CollateFn_FR_P_Masking:
    """FRP masking: mask approximately mask_ratio of the patches inside every
    facial region individually, then randomly top up / trim so the overall
    ratio is hit exactly."""

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        # patches per side and total patch count of the ViT grid
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        # parsing-label groups; [10, 1, 0] denotes face boundaries (patches
        # where hair/background meet skin). The last two groups (skin,
        # background) are excluded from the per-region pass below.
        self.facial_region_group = [
            [2, 3],
            [4, 5],
            [6],
            [7, 8, 9],
            [10, 1, 0],
            [10],
            [1],
            [0]
        ]

    def __call__(self, image, parsing_map):
        # parsing_map: [N, img_size, img_size] integer label map
        P_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
        P_mask = self.random_variable_facial_semantics_masking(parsing_map, P_mask)

        return {'image': image, 'P_mask': P_mask}

    def random_variable_facial_semantics_masking(self, parsing_map, P_mask):
        """
        For each facial-region group (skin and background excluded), randomly
        mask ~mask_ratio of that region's patches, then adjust randomly so
        each sample ends with exactly int(mask_ratio * num_patches) ones.
        :return: P_mask flattened to [N, num_patches]
        """
        P_mask = P_mask.view(P_mask.size(0), -1)
        for i in range(parsing_map.size(0)):
            for seg_group in self.facial_region_group[:-2]:
                mask_in_seg_group = torch.zeros(1, self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
                if seg_group == [10, 1, 0]:
                    # face boundaries: patches where hair/background pixels
                    # and skin pixels co-occur (straddle the face outline)
                    patch_hair_bg = F.max_pool2d(
                        ((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
                        kernel_size=self.patch_size)
                    patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(), kernel_size=self.patch_size)

                    mask_in_seg_group = torch.maximum(mask_in_seg_group,
                                                      (patch_hair_bg.bool() & patch_skin.bool()).float())
                else:
                    # a patch belongs to the region if any of its pixels
                    # carries one of the region's labels
                    for comp_value in seg_group:
                        mask_in_seg_group = torch.maximum(mask_in_seg_group,
                                                          F.max_pool2d(
                                                              (parsing_map[i].unsqueeze(0) == comp_value).float(),
                                                              kernel_size=self.patch_size))

                mask_in_seg_group = mask_in_seg_group.view(-1)

                # region patches not yet masked by earlier groups
                to_mask_patches_in_seg_group = (mask_in_seg_group - P_mask[i]) > 0
                # additional patches needed so that mask_ratio of this region
                # is covered, counting overlap already masked
                mask_num = (mask_in_seg_group.sum(dim=-1) * self.mask_ratio -
                            (mask_in_seg_group.sum(dim=-1)-to_mask_patches_in_seg_group.sum(dim=-1))).int()
                if mask_num > 0:
                    select_indices = (to_mask_patches_in_seg_group == 1).nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[:mask_num]
                    P_mask[i, select_indices[change_indices]] = 1

            # final random top-up / trim to the exact global budget:
            # >0 means add patches (value 1), <0 means remove (value 0)
            num_mask_to_change = (self.mask_ratio * self.num_patches - P_mask[i].sum(dim=-1)).int()
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
            select_indices = (P_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
            P_mask[i, select_indices[change_indices]] = mask_change_to

        return P_mask
|
|
|
|
|
def do_FRP_masking(image, parsing_map_vis, ratio):
    """Gradio callback: apply FRP masking at `ratio` to the cached face.

    Uses the parsing map stored in the module-level `face_to_show`.
    Returns (patchified masked parsing map, mask image, masked face image).
    """
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)

    collate = CollateFn_FR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = collate(img, parsing_map)['P_mask']

    on_parsing, mask_img, on_image = show_one_img_parchify_mask(image, parsing_map_vis, mask, model)
    return on_parsing, mask_img, on_image
|
|
|
|
|
|
|
class CollateFn_CRFR_R_Masking:
    """CRFR-R masking: (1) Cover a specific Facial Region completely, then
    (2) Randomly mask other patches up to the overall mask ratio.

    In this demo the covered region is user-selected via `region` rather
    than chosen at random.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        # patches per side and total patch count of the ViT grid
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        # parsing-label groups (not iterated by this strategy; kept for
        # parity with the other collate classes)
        self.facial_region_group = [
            [2, 3],
            [4, 5],
            [6],
            [7, 8, 9],
            [10, 1, 0],
            [10],
            [1],
            [0]
        ]
        # label ids of the user-selected region to cover completely
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        # parsing_map: [N, img_size, img_size] integer label map
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
        facial_region_mask, random_specific_facial_region = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)

        # random top-up / trim to the exact overall budget
        CRFR_R_mask, facial_region_mask = self.random_variable_masking(facial_region_mask)

        return {'image': image, 'CRFR_R_mask': CRFR_R_mask, 'fr_mask': facial_region_mask}

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        Mark every patch that contains the selected facial region.
        :param parsing_map: [1, img_size, img_size])
        :param facial_region_mask: [1, num_patches ** .5, num_patches ** .5]
        :return: facial_region_mask, random_specific_facial_region
        """
        if self.random_specific_facial_region == [10, 1, 0]:
            # face boundaries: patches where hair/background pixels co-occur
            # with skin pixels, i.e. patches straddling the face outline
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)

            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)

            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            # a patch belongs to the region if any of its pixels carries one
            # of the region's labels (max-pool over each patch window)
            for facial_region_index in self.random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region

    def random_variable_masking(self, facial_region_mask):
        # Adjust each sample to exactly int(mask_ratio * num_patches) masked
        # patches: add random patches outside the region, or — if the region
        # alone already exceeds the budget — unmask random region patches.
        CRFR_R_mask = facial_region_mask.clone()

        for i in range(facial_region_mask.size(0)):
            # >=0: add patches (value 1); <0: region overshoot, remove (0)
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            mask_change_to = 1 if num_mask_to_change >= 0 else 0

            select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
            CRFR_R_mask[i, select_indices[change_indices]] = mask_change_to

            # when region patches were unmasked, sync fr_mask with the final
            # mask so fr_mask stays a subset of CRFR_R_mask
            facial_region_mask[i] = CRFR_R_mask[i] if num_mask_to_change < 0 else facial_region_mask[i]

        return CRFR_R_mask, facial_region_mask
|
|
|
|
|
def do_CRFR_R_masking(image, parsing_map_vis, ratio, region):
    """Gradio callback: CRFR-R masking at `ratio`, covering `region` fully.

    Uses the parsing map stored in the module-level `face_to_show`.
    Returns (patchified masked parsing map, mask image, masked face image).
    """
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)

    collate = CollateFn_CRFR_R_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    outputs = collate(img, parsing_map)

    # adding fr_mask bumps covered-region patches to value 2 (drawn black)
    on_parsing, mask_img, on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, outputs['CRFR_R_mask'] + outputs['fr_mask'], model)
    return on_parsing, mask_img, on_image
|
|
|
|
|
|
|
class CollateFn_CRFR_P_Masking:
    """CRFR-P masking (used in FSFM-C3): (1) Cover a specific Facial Region
    completely, then (2) spread the remaining masking budget Proportionally
    across the other facial regions.

    In this demo the covered region is user-selected via `region`.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        # patches per side and total patch count of the ViT grid
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio

        # parsing-label groups used when spreading the remaining budget;
        # [10, 1, 0] denotes face boundaries (hair/background meeting skin)
        self.facial_region_group = [
            [2, 3],
            [4, 5],
            [6],
            [7, 8, 9],
            [10, 1, 0],
            [10],
            [1],
            [0]
        ]
        # label ids of the user-selected region to cover completely
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        # parsing_map: [N, img_size, img_size] integer label map
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)
        facial_region_mask, random_specific_facial_region = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)

        # proportional masking of the other regions (plus final adjustment)
        CRFR_P_mask, facial_region_mask = self.random_variable_masking(parsing_map, facial_region_mask, random_specific_facial_region)

        return {'image': image, 'CRFR_P_mask': CRFR_P_mask, 'fr_mask': facial_region_mask}

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        Mark every patch that contains the selected facial region.
        :param parsing_map: [1, img_size, img_size])
        :param facial_region_mask: [1, num_patches ** .5, num_patches ** .5]
        :return: facial_region_mask, random_specific_facial_region
        """
        if self.random_specific_facial_region == [10, 1, 0]:
            # face boundaries: patches where hair/background pixels co-occur
            # with skin pixels, i.e. patches straddling the face outline
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(), kernel_size=self.patch_size)

            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)

            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            # a patch belongs to the region if any of its pixels carries one
            # of the region's labels (max-pool over each patch window)
            for facial_region_index in self.random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), self.random_specific_facial_region

    def random_variable_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        CRFR_P_mask = facial_region_mask.clone()
        # every region group except the one already fully covered
        other_facial_region_group = [region for region in self.facial_region_group if
                                     region != random_specific_facial_region]

        for i in range(facial_region_mask.size(0)):
            # remaining budget after covering the selected region
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()

            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()  # 1: add, 0: remove

            if mask_change_to == 1:
                # fraction of each remaining region to mask so the total
                # reaches the overall ratio
                mask_ratio_other_fr = (
                        num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))

                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        # face boundary patches, as above
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(), kernel_size=self.patch_size)

                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(to_mask_patches,
                                                            F.max_pool2d((parsing_map[i].unsqueeze(0) == facial_region_index).float(),
                                                                         kernel_size=self.patch_size))

                    # patches of this region not already masked by earlier
                    # groups or the covered region
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0

                    # randomly mask mask_ratio_other_fr of the free patches
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                                     :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to

                    masked_patches = masked_patches + to_mask_patches.float()

                # correct rounding drift with random patches, excluding the
                # covered region (CRFR_P_mask + fr_mask == 2 there)
                num_mask_to_change = (self.mask_ratio * self.num_patches - CRFR_P_mask[i].sum(dim=-1)).int()

                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

                select_indices = ((CRFR_P_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to

            else:
                # the covered region alone meets/exceeds the budget: randomly
                # unmask region patches and sync fr_mask with the final mask
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = CRFR_P_mask[i]

        return CRFR_P_mask, facial_region_mask
|
|
|
|
|
def do_CRFR_P_masking(image, parsing_map_vis, ratio, region):
    """Gradio callback: CRFR-P masking at `ratio`, covering `region` fully.

    Uses the parsing map stored in the module-level `face_to_show`.
    Returns (patchified masked parsing map, mask image, masked face image).
    """
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)

    collate = CollateFn_CRFR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    outputs = collate(img, parsing_map)

    # adding fr_mask bumps covered-region patches to value 2 (drawn black)
    on_parsing, mask_img, on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, outputs['CRFR_P_mask'] + outputs['fr_mask'], model)
    return on_parsing, mask_img, on_image
|
|
|
|
|
def vis_parsing_maps(parsing_anno):
    """Colour-code an integer parsing map for display.

    :param parsing_anno: 2-D array of parsing labels (0..10)
    :return: (H, W, 3) uint8 RGB image; label 0 (and any unvisited pixel)
        stays white
    """
    part_colors = [[255, 255, 255],
                   [0, 0, 255], [255, 128, 0], [255, 255, 0],
                   [0, 255, 0], [0, 255, 128],
                   [0, 255, 255], [255, 0, 255], [255, 0, 128],
                   [128, 0, 255], [255, 0, 0]]

    anno = parsing_anno.copy().astype(np.uint8)
    # start from an all-white canvas and paint each label's pixels
    colored = np.zeros((anno.shape[0], anno.shape[1], 3)) + 255
    for label in range(1, int(np.max(anno)) + 1):
        colored[anno == label] = part_colors[label]

    return colored.astype(np.uint8)
|
|
|
|
|
|
|
import facer |
|
def do_face_parsing(img):
    """Gradio callback: crop the face from `img`, run the facer
    detector + parser, cache the result in `face_to_show`, and return
    (cropped face, parsing-map visual, patchified parsing-map visual).

    Returns gr.update() (leave UI unchanged) when parsing fails with a
    KeyError — presumably when no face is found; confirm against facer.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): detector and parser are re-created on every submission;
    # hoisting them to module level would speed up repeated calls.
    face_detector = facer.face_detector('retinaface/mobilenet', device=device, threshold=0.3)
    face_parser = facer.face_parser('farl/lapa/448', device=device)

    # NOTE(review): extract_face may return None (no dlib detection), which
    # would make img.resize below raise AttributeError — consider guarding.
    img = extract_face(img)
    with torch.inference_mode():
        img = img.resize((224, 224), Image.BICUBIC)
        image = torch.from_numpy(np.array(img.convert('RGB')))
        image = image.unsqueeze(0).permute(0, 3, 1, 2).to(device=device)  # [1, 3, 224, 224]
        try:
            faces = face_detector(image)
            faces = face_parser(image, faces)

            seg_logits = faces['seg']['logits']
            seg_probs = seg_logits.softmax(dim=1)
            seg_probs = seg_probs.data
            parsing = seg_probs.argmax(1)  # per-pixel label map

            parsing_map = parsing.data.cpu().numpy()
            parsing_map = parsing_map.astype(np.int8)
            parsing_map_vis = vis_parsing_maps(parsing_map.squeeze(0))

        except KeyError:
            # parsing output missing -> no-op update for all three outputs
            return gr.update()

    # cache for the masking callbacks
    face_to_show.image = img
    face_to_show.parsing_map = parsing_map
    return img, parsing_map_vis, show_one_img_patchify(parsing_map_vis, model)
|
|
|
|
|
|
|
# Region names offered for the CRFR strategies; keys into check_region.
FACIAL_REGION_CHOICES = ['Eyebrows',
                         'Eyes',
                         'Nose',
                         'Mouth',
                         'Face Boundaries',
                         'Hair',
                         'Skin',
                         'Background']

with gr.Blocks() as demo:
    gr.Markdown("# Masking Strategy Demo")
    gr.Markdown(
        "This is a demo for different masking strategies:"
    )
    gr.Markdown(
        "<b>Random Masking</b>: Random select patches in the image for masking."
    )
    gr.Markdown(
        # fixed: "\{...\}" rendered literal backslashes (and "\{" is an
        # invalid escape sequence in Python source)
        "<b>Fasking-I</b>: Use a face parser to divide facial parts into {left-eye, right-eye, nose, mouth, hair, skin, background} regions, and then prioritizes non-skin and non-background regions for masking."
    )
    gr.Markdown(
        "<b>FRP</b>: Mask an equal portion of patches in each facial region to the overall masking ratio."
    )
    gr.Markdown(
        # typo fixed: "Faical" -> "Facial"
        "<b>CRFR-R</b>: (1) Cover a Random Facial Region. (2) Random mask other patches."
    )
    gr.Markdown(
        "<b>CRFR-P (in FSFM-C3)</b>: (1) Cover a Random Facial Region. (2) Proportional mask in other regions."
    )

    # Face parsing: upload -> cropped face, parsing map, patchified parsing map
    with gr.Column():
        image = gr.Image(label="Upload your image", type="pil")
        image_submit_btn = gr.Button("Face Parsing")
        with gr.Row():
            ori_image = gr.Image(interactive=False)
            parsing_map_vis = gr.Image(interactive=False)
            patch_parsing_map = gr.Image(interactive=False)

    # One column per masking strategy: trigger button, ratio slider
    # (plus a region selector for the CRFR variants), three output images
    with gr.Column():
        random_submit_btn = gr.Button("Random Masking")
        ratio_random = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for Random Masking")
        with gr.Row():
            random_patch_on_parsing = gr.Image(interactive=False)
            random_mask = gr.Image(interactive=False)
            random_mask_on_image = gr.Image(interactive=False)

    with gr.Column():
        fasking_submit_btn = gr.Button("Fasking-I")
        ratio_fasking = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for Fasking")
        with gr.Row():
            fasking_patch_on_parsing = gr.Image(interactive=False)
            fasking_mask = gr.Image(interactive=False)
            fasking_mask_on_image = gr.Image(interactive=False)

    with gr.Column():
        FRP_submit_btn = gr.Button("FRP")
        ratio_FRP = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for FRP")
        with gr.Row():
            FRP_patch_on_parsing = gr.Image(interactive=False)
            FRP_mask = gr.Image(interactive=False)
            FRP_mask_on_image = gr.Image(interactive=False)

    with gr.Column():
        CRFR_R_submit_btn = gr.Button("CRFR-R")
        ratio_CRFR_R = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for CRFR-R")
        # typo fixed: label "Faical Region" -> "Facial Region"
        mask_region_CRFR_R = gr.Radio(choices=FACIAL_REGION_CHOICES, value='Eyes', label="Facial Region")
        with gr.Row():
            CRFR_R_patch_on_parsing = gr.Image(interactive=False)
            CRFR_R_mask = gr.Image(interactive=False)
            CRFR_R_mask_on_image = gr.Image(interactive=False)

    with gr.Column():
        CRFR_P_submit_btn = gr.Button("CRFR-P (Ours)")
        ratio_CRFR_P = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for CRFR-P")
        mask_region_CRFR_P = gr.Radio(choices=FACIAL_REGION_CHOICES, value='Eyes', label="Facial Region")
        with gr.Row():
            CRFR_P_patch_on_parsing = gr.Image(interactive=False)
            CRFR_P_mask = gr.Image(interactive=False)
            CRFR_P_mask_on_image = gr.Image(interactive=False)

    # (removed: unused, misspelled `parseing_map = []` left from development)
    image_submit_btn.click(
        fn=do_face_parsing,
        inputs=image,
        outputs=[ori_image, parsing_map_vis, patch_parsing_map]
    )
    random_submit_btn.click(
        fn=do_random_masking,
        inputs=[ori_image, parsing_map_vis, ratio_random],
        outputs=[random_patch_on_parsing, random_mask, random_mask_on_image],
    )
    fasking_submit_btn.click(
        fn=do_fasking_masking,
        inputs=[ori_image, parsing_map_vis, ratio_fasking],
        outputs=[fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image],
    )
    FRP_submit_btn.click(
        fn=do_FRP_masking,
        inputs=[ori_image, parsing_map_vis, ratio_FRP],
        outputs=[FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image],
    )
    CRFR_R_submit_btn.click(
        fn=do_CRFR_R_masking,
        inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
        outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
    )
    CRFR_P_submit_btn.click(
        fn=do_CRFR_P_masking,
        inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
        outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
    )

if __name__ == "__main__":
    gr.close_all()
    demo.queue()
    demo.launch()