# Masking / app.py (FSFM-C3)
# Commit fa37ca3: "Update app.py" (verified)
# pip uninstall nvidia_cublas_cu11
import sys
sys.path.append('..')
import os
os.system(f'pip install dlib')
import torch
import numpy as np
from PIL import Image
import models_mae
from torch.nn import functional as F
import dlib
import gradio as gr
# Build the MAE ViT-B/16 architecture. This demo only uses its patchify/unpatchify
# helpers and patch_embed.patch_size; this file does not load a checkpoint into it.
model = models_mae.mae_vit_base_patch16()
class ITEM:
    """Plain mutable record holding a face image and its parsing map."""

    def __init__(self, img, parsing_map):
        # Both attributes are (re)assigned by the face-parsing handler.
        self.parsing_map = parsing_map
        self.image = img
# Module-level holder for the most recently parsed face: do_face_parsing writes it and
# every do_*_masking handler reads its parsing_map.
# NOTE(review): this single global is shared by all concurrent users of the app;
# a per-session gr.State would isolate them.
face_to_show = ITEM(img=None, parsing_map=None)
# Region name offered in the UI -> parsing-label ids it covers.
# [10, 1, 0] (hair, skin, background) is a sentinel the CRFR collate classes treat
# specially: it selects patches that contain skin together with hair or background.
check_region = dict([
    ('Eyebrows', [2, 3]),
    ('Eyes', [4, 5]),
    ('Nose', [6]),
    ('Mouth', [7, 8, 9]),
    ('Face Boundaries', [10, 1, 0]),
    ('Hair', [10]),
    ('Skin', [1]),
    ('Background', [0]),
])
def get_boundingbox(face, width, height, minsize=None, scale=1.3):
    """
    From FF++:
    https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
    Expects a dlib face to generate a quadratic (square) bounding box.

    :param face: dlib face rectangle; only left()/top()/right()/bottom() are used
    :param width: frame width in pixels
    :param height: frame height in pixels
    :param minsize: optional minimum side length of the bounding box
    :param scale: multiplier applied to the longer face side to get a bigger face
        region (default 1.3, the value FF++ uses)
    :return: x, y, bounding_box_size -- top-left corner and side length in
        OpenCV / array-slicing convention, clamped to the frame
    """
    x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
    # Keep the top-left corner inside the frame.
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Shrink the box so it does not run past the right or bottom edge.
    size_bb = min(width - x1, height - y1, size_bb)
    return x1, y1, size_bb
def extract_face(frame):
    """Detect the biggest face with dlib and return an FF++-style square crop of it.

    :param frame: PIL image (converted to RGB before detection).
    :return: PIL image of the square crop around the largest detected face
        (box from get_boundingbox), or None when dlib finds no face.
    """
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    # Second argument: number of times dlib upsamples the image before detecting.
    faces = face_detector(image, 1)
    if len(faces) == 0:
        return None
    # Only take the biggest face (largest bounding-box area); ties keep the first one.
    face = max(faces, key=lambda r: (r.right() - r.left()) * (r.bottom() - r.top()))
    # Face crop follows FF++ (enlarged square box clamped to the frame).
    x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
    cropped_face = image[y:y + size, x:x + size]
    return Image.fromarray(cropped_face)
from torchvision.transforms import transforms
def show_one_img_patchify(img, model, out_size=224):
    """Render an image as the grid of ViT patches the model sees.

    Each patch is pasted as its own tile, with a 3-pixel white gutter between
    tiles, and the mosaic is resized to ``out_size`` x ``out_size``.

    :param img: (H, W, 3) image array, e.g. a colorized parsing map. Its size must
        be accepted by ``model.patchify`` (presumably square and divisible by the
        patch size -- confirm against models_mae).
    :param model: MAE-style model providing ``patchify`` and ``patch_embed.patch_size``.
    :param out_size: side length of the returned image (default 224, the panel size
        used throughout this demo).
    :return: RGB PIL image.
    """
    patch_h = model.patch_embed.patch_size[0]
    patch_w = model.patch_embed.patch_size[1]
    # (H, W, C) -> (1, C, H, W), the batched channel-first layout patchify expects.
    x = torch.einsum('nhwc->nchw', torch.tensor(img).unsqueeze(dim=0))
    x_patches = model.patchify(x)  # (1, num_patches, patch_h * patch_w * 3)
    n = int(np.sqrt(x_patches.shape[1]))  # patches per side
    padding = 3
    # Tiles are exactly one patch in size, so the mosaic lines up for any input
    # resolution (previously the tile size was derived from a hard-coded 224).
    canvas = Image.new('RGB',
                       (n * patch_w + padding * (n - 1), n * patch_h + padding * (n - 1)),
                       'white')
    to_pil = transforms.ToPILImage()
    for i, patch in enumerate(x_patches[0]):
        row, col = divmod(i, n)
        patch_img_tensor = torch.reshape(patch, (patch_h, patch_w, 3))
        patch_img = to_pil(torch.einsum('hwc->chw', patch_img_tensor))
        canvas.paste(patch_img, (col * (patch_w + padding), row * (patch_h + padding)))
    return canvas.resize((out_size, out_size), Image.BICUBIC)
def show_one_img_parchify_mask(img, parsing_map, mask, model):
    """Render a patch-level mask three ways for display.

    Mask values: 0 = visible, 1 = masked, 2 = masked and part of the highlighted
    facial region (the CRFR handlers pass ``mask + fr_mask``).

    :param img: face image as an (H, W, 3) array; not modified.
    :param parsing_map: colorized parsing map with the same shape as ``img``; not modified.
    :param mask: (1, num_patches) tensor of the values above.
    :param model: MAE-style model providing ``unpatchify``, ``patchify`` and
        ``patch_embed.patch_size``.
    :return: [patchified masked parsing map, mask visualization, masked image],
        all PIL images.
    """
    mask = mask.detach()
    patch_size = model.patch_embed.patch_size[0]
    # Spread every patch value over all pixels/channels of its patch, then fold back
    # into image layout.
    mask_patches = mask.unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)  # (N, L, p*p*3)
    pixel_mask = model.unpatchify(mask_patches)  # (N, 3, H, W)
    pixel_mask = torch.einsum('nchw->nhwc', pixel_mask).detach().cpu()[0]  # (H, W, 3)

    # Mask visualization: 1 -> 127 (gray, masked), 2 -> -127 clipped to 0 (black,
    # highlighted region), 0 -> 254 (white, visible).
    vis_mask = pixel_mask.clone()
    vis_mask[vis_mask == 2] = -1
    vis_mask[vis_mask == 0] = 2
    vis_mask = torch.clip(vis_mask * 127, 0, 255).int()
    fasking_mask = Image.fromarray(vis_mask.numpy().astype(np.uint8))

    masked_sel = (pixel_mask == 1).numpy()
    region_sel = (pixel_mask == 2).numpy()

    # Paint copies so the caller's arrays are left untouched.
    im_masked = np.array(img, copy=True)
    im_masked[masked_sel] = 127
    im_masked[region_sel] = 0
    im_masked = Image.fromarray(im_masked)

    parsing_map_masked = np.array(parsing_map, copy=True)
    parsing_map_masked[masked_sel] = 127
    parsing_map_masked[region_sel] = 0
    return [show_one_img_patchify(parsing_map_masked, model), fasking_mask, im_masked]
# Random
class CollateFn_Random:
    """Random masking baseline.

    Masks ``int(mask_ratio * num_patches)`` uniformly chosen patches in every
    sample; the parsing map only supplies the batch size.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = self.num_patches_axis ** 2
        self.mask_ratio = mask_ratio

    def __call__(self, image, parsing_map):
        """Return ``{'image': image, 'random_mask': (BS, num_patches) float 0/1 tensor}``."""
        batch_size = parsing_map.size(0)
        empty_mask = torch.zeros(batch_size, self.num_patches, dtype=torch.float32)  # (BS, num_patches)
        return {'image': image, 'random_mask': self.masking(parsing_map, empty_mask)}

    def masking(self, parsing_map, random_mask):
        """Set ``int(mask_ratio * num_patches)`` random entries of each row to 1, in place.

        :param parsing_map: unused; kept for parity with the other collate classes.
        :param random_mask: (BS, num_patches) tensor, modified in place.
        :return: ``random_mask``.
        """
        n_masked = int(self.mask_ratio * self.num_patches)
        for row_idx in range(random_mask.size(0)):
            # The same count for every row keeps the batch rectangular.
            shuffled = torch.randperm(self.num_patches)
            for k in range(n_masked):
                random_mask[row_idx, shuffled[k]] = 1
        return random_mask
def do_random_masking(image, parsing_map_vis, ratio):
    """Gradio handler for the "Random Masking" button.

    :param image: face crop from the ori_image panel, as an (H, W, 3) array
        (assumed: Gradio's default numpy image value).
    :param parsing_map_vis: colorized parsing map from the parsing_map_vis panel.
    :param ratio: mask ratio from the slider.
    :return: (masked patchified parsing map, mask visualization, masked image).
    :raises gr.Error: when no face has been parsed yet.
    """
    # Without a parsed face the conversions below fail with an opaque TypeError.
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error('Please upload an image and click "Face Parsing" first.')
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    mask_method = CollateFn_Random(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['random_mask']
    random_patch_on_parsing, random_mask, random_mask_on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, mask, model)
    return random_patch_on_parsing, random_mask, random_mask_on_image
# Fasking
class CollateFn_Fasking:
    """Fasking-I masking.

    Patches touching the labels of ``facial_region_group[:-2]`` (brows/eyes, nose,
    mouth, hair; skin and background excluded) are masked label by label until the
    masked fraction reaches ``mask_ratio``. Random patches are then masked or
    unmasked so that every sample ends with exactly ``int(mask_ratio * num_patches)``
    masked patches, the same count across a batch.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        # --------------------------------------------------------------------------
        self.facial_region_group = [
            [2, 4],  # right brow + right eye
            [3, 5],  # left brow + left eye
            [6],  # nose
            [7, 8, 9],  # mouth
            [10],  # hair
            [1],  # skin
            [0]  # background
        ]  # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']

    def __call__(self, image, parsing_map):
        """Build the Fasking-I mask for a batch.

        :param image: passed through unchanged.
        :param parsing_map: (BS, H, W) integer label map.
        :return: {'image': image, 'fasking_mask': (BS, num_patches) float 0/1 tensor}
        """
        fasking_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                   dtype=torch.float32)  # (BS, H/P, W/P)
        fasking_mask = self.fasking(parsing_map, fasking_mask)
        return {'image': image, 'fasking_mask': fasking_mask}

    def fasking(self, parsing_map, fasking_mask):
        """Fill ``fasking_mask`` and return it flattened.

        :param parsing_map: (BS, H, W) integer label map.
        :param fasking_mask: (BS, H/P, W/P) zero-initialized float tensor, filled in place.
        :return: (BS, num_patches) float 0/1 mask with exactly
            ``int(mask_ratio * num_patches)`` ones per row.
        """
        batch_size = parsing_map.size(0)
        # Labels in covering-priority order; skin and background are never covered here.
        priority_labels = [label for group in self.facial_region_group[:-2] for label in group]
        for i in range(batch_size):
            labels_i = parsing_map[i].unsqueeze(0)  # (1, H, W)
            for label in priority_labels:
                # A patch is covered if any of its pixels carries this label.
                touched = F.max_pool2d((labels_i == label).float(), kernel_size=self.patch_size)[0]
                fasking_mask[i] = torch.maximum(fasking_mask[i], touched)
                if fasking_mask[i].mean() >= self.mask_ratio:
                    break

        fasking_mask = fasking_mask.view(batch_size, -1)
        # Normalize to one integer target for every sample. The previous
        # int(ratio * N - sum) truncated toward zero, giving floor(ratio * N) when
        # under budget but ceil(ratio * N) when over, i.e. unequal counts per batch.
        target = int(self.mask_ratio * self.num_patches)
        for i in range(batch_size):
            num_mask_to_change = target - int(fasking_mask[i].sum().item())
            mask_change_to = 1 if num_mask_to_change > 0 else 0
            select_indices = (fasking_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:abs(num_mask_to_change)]
            fasking_mask[i, select_indices[change_indices]] = mask_change_to
        return fasking_mask
def do_fasking_masking(image, parsing_map_vis, ratio):
    """Gradio handler for the "Fasking-I" button.

    :param image: face crop from the ori_image panel, as an (H, W, 3) array
        (assumed: Gradio's default numpy image value).
    :param parsing_map_vis: colorized parsing map from the parsing_map_vis panel.
    :param ratio: mask ratio from the slider.
    :return: (masked patchified parsing map, mask visualization, masked image).
    :raises gr.Error: when no face has been parsed yet.
    """
    # Without a parsed face the conversions below fail with an opaque TypeError.
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error('Please upload an image and click "Face Parsing" first.')
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    mask_method = CollateFn_Fasking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['fasking_mask']
    fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, mask, model)
    return fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image
# FRP
class CollateFn_FR_P_Masking:
    """FRP masking: mask the same fraction of patches inside every facial region.

    Each region of ``facial_region_group[:-2]`` (skin and background excluded) gets
    ``int(mask_ratio * |region| - already_masked_in_region)`` more patches masked.
    Random patches are then added or removed so every sample ends with exactly
    ``int(mask_ratio * num_patches)`` masked patches.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],  # eyebrows
            [4, 5],  # eyes
            [6],  # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],  # hair
            [1],  # facial skin
            [0]  # background
        ]  # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']

    def __call__(self, image, parsing_map):
        """Build the FRP mask for a batch.

        :param image: passed through unchanged.
        :param parsing_map: (BS, H, W) integer label map.
        :return: {'image': image, 'P_mask': (BS, num_patches) float 0/1 tensor}
        """
        P_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                             dtype=torch.float32)  # (BS, H/P, W/P)
        P_mask = self.random_variable_facial_semantics_masking(parsing_map, P_mask)
        return {'image': image, 'P_mask': P_mask}

    def _group_patch_mask(self, region_map, seg_group):
        """Return a flat (num_patches,) 0/1 float mask of patches touching ``seg_group``.

        ``region_map`` is one (1, H, W) label map. The sentinel group [10, 1, 0] means
        "face boundaries": patches containing skin together with hair or background.
        """
        if seg_group == [10, 1, 0]:
            patch_hair_bg = F.max_pool2d(((region_map == 10) | (region_map == 0)).float(),
                                         kernel_size=self.patch_size)
            patch_skin = F.max_pool2d((region_map == 1).float(), kernel_size=self.patch_size)
            group_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            group_mask = torch.zeros(1, self.num_patches_axis, self.num_patches_axis, dtype=torch.float32)
            for comp_value in seg_group:
                group_mask = torch.maximum(
                    group_mask,
                    F.max_pool2d((region_map == comp_value).float(), kernel_size=self.patch_size))
        return group_mask.view(-1)

    def random_variable_facial_semantics_masking(self, parsing_map, P_mask):
        """Fill ``P_mask`` region by region, then normalize the masked count.

        :param parsing_map: (BS, H, W) integer label map.
        :param P_mask: (BS, H/P, W/P) zero-initialized float tensor.
        :return: (BS, num_patches) float 0/1 mask with exactly
            ``int(mask_ratio * num_patches)`` ones per row.
        """
        P_mask = P_mask.view(P_mask.size(0), -1)
        # One integer target for every sample keeps batch counts identical.
        target = int(self.mask_ratio * self.num_patches)
        for i in range(parsing_map.size(0)):
            region_map = parsing_map[i].unsqueeze(0)
            for seg_group in self.facial_region_group[:-2]:
                group_mask = self._group_patch_mask(region_map, seg_group)
                # Patches of this group that are not masked yet.
                unmasked_in_group = (group_mask - P_mask[i]) > 0
                group_total = group_mask.sum(dim=-1)
                already_masked = group_total - unmasked_in_group.sum(dim=-1)
                mask_num = (group_total * self.mask_ratio - already_masked).int()
                if mask_num > 0:
                    select_indices = unmasked_in_group.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[:mask_num]
                    P_mask[i, select_indices[change_indices]] = 1
            # Top up / trim with random patches to reach exactly `target`.
            num_mask_to_change = target - int(P_mask[i].sum().item())
            mask_change_to = 1 if num_mask_to_change > 0 else 0
            select_indices = (P_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:abs(num_mask_to_change)]
            P_mask[i, select_indices[change_indices]] = mask_change_to
        return P_mask
def do_FRP_masking(image, parsing_map_vis, ratio):
    """Gradio handler for the "FRP" button.

    :param image: face crop from the ori_image panel, as an (H, W, 3) array
        (assumed: Gradio's default numpy image value).
    :param parsing_map_vis: colorized parsing map from the parsing_map_vis panel.
    :param ratio: mask ratio from the slider.
    :return: (masked patchified parsing map, mask visualization, masked image).
    :raises gr.Error: when no face has been parsed yet.
    """
    # Without a parsed face the conversions below fail with an opaque TypeError.
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error('Please upload an image and click "Face Parsing" first.')
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    mask_method = CollateFn_FR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio)
    mask = mask_method(img, parsing_map)['P_mask']
    FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, mask, model)
    return FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image
# CRFR_R
class CollateFn_CRFR_R_Masking:
    """CRFR-R masking.

    (1) Cover every patch of one chosen facial region. (2) Mask random other patches
    until exactly ``int(mask_ratio * num_patches)`` patches are masked. If the region
    alone exceeds that budget, random region patches are unmasked instead, and the
    returned region mask shrinks to match.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],  # eyebrows
            [4, 5],  # eyes
            [6],  # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],  # hair
            [1],  # facial skin
            [0]  # background
        ]  # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        """Build the CRFR-R mask for a batch.

        :param image: passed through unchanged.
        :param parsing_map: (BS, H, W) integer label map.
        :return: {'image': image, 'CRFR_R_mask': (BS, num_patches), 'fr_mask': (BS, num_patches)}
            float 0/1 tensors; fr_mask marks the covered region patches.
        """
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # (BS, H/P, W/P)
        facial_region_mask, _ = self.masking_all_patches_in_random_specific_facial_region(
            parsing_map, facial_region_mask)
        CRFR_R_mask, facial_region_mask = self.random_variable_masking(facial_region_mask)
        return {'image': image, 'CRFR_R_mask': CRFR_R_mask, 'fr_mask': facial_region_mask}

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        :param parsing_map: (BS, img_size, img_size) label map
        :param facial_region_mask: (BS, num_patches ** .5, num_patches ** .5) zero tensor
        :return: facial_region_mask flattened to (BS, num_patches), the region's label list
        """
        region = self.random_specific_facial_region
        if region == [10, 1, 0]:  # facial boundaries, 10-hair 1-skin 0-background
            # Patches containing hair(10) or background(0):
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) | (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            # Patches containing skin(1):
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            # skin&hair or skin&bg is defined as facial boundaries:
            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            for facial_region_index in region:
                facial_region_mask = torch.maximum(
                    facial_region_mask,
                    F.max_pool2d((parsing_map == facial_region_index).float(), kernel_size=self.patch_size))
        return facial_region_mask.view(parsing_map.size(0), -1), region

    def random_variable_masking(self, facial_region_mask):
        """Top up (or trim) each row to exactly ``int(mask_ratio * num_patches)`` masked patches.

        :param facial_region_mask: (BS, num_patches) 0/1 region mask; rows are replaced
            by the trimmed mask when the region exceeded the budget.
        :return: (CRFR_R_mask, facial_region_mask)
        """
        CRFR_R_mask = facial_region_mask.clone()
        # One integer target for every sample. The previous int(ratio * N - sum)
        # truncated toward zero and kept one extra patch whenever the region was
        # over budget, so counts differed across a batch.
        target = int(self.mask_ratio * self.num_patches)
        for i in range(facial_region_mask.size(0)):
            num_mask_to_change = target - int(facial_region_mask[i].sum().item())
            mask_change_to = 1 if num_mask_to_change >= 0 else 0
            select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
            change_indices = torch.randperm(len(select_indices))[:abs(num_mask_to_change)]
            CRFR_R_mask[i, select_indices[change_indices]] = mask_change_to
            if num_mask_to_change < 0:
                # The region itself was trimmed, so report the trimmed region.
                facial_region_mask[i] = CRFR_R_mask[i]
        return CRFR_R_mask, facial_region_mask
def do_CRFR_R_masking(image, parsing_map_vis, ratio, region):
    """Gradio handler for the "CRFR-R" button.

    :param image: face crop from the ori_image panel, as an (H, W, 3) array
        (assumed: Gradio's default numpy image value).
    :param parsing_map_vis: colorized parsing map from the parsing_map_vis panel.
    :param ratio: mask ratio from the slider.
    :param region: region name chosen in the radio (a key of check_region).
    :return: (masked patchified parsing map, mask visualization, masked image);
        the covered region is drawn black.
    :raises gr.Error: when no face has been parsed yet.
    """
    # Without a parsed face the conversions below fail with an opaque TypeError.
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error('Please upload an image and click "Face Parsing" first.')
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    mask_method = CollateFn_CRFR_R_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    masks = mask_method(img, parsing_map)
    # Summing gives 2 on covered-region patches, which the renderer highlights.
    combined = masks['CRFR_R_mask'] + masks['fr_mask']
    CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, combined, model)
    return CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image
# CRFR_P
class CollateFn_CRFR_P_Masking:
    """CRFR-P masking (used in FSFM-C3).

    (1) Cover every patch of one chosen facial region. (2) Spend the remaining budget
    proportionally on every other region, then top up / trim with random patches
    outside the covered region so each sample ends with exactly
    ``int(mask_ratio * num_patches)`` masked patches. If the region alone exceeds the
    budget, random region patches are unmasked instead and the returned region mask
    shrinks to match.
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75, region='Nose'):
        self.img_size = input_size
        self.patch_size = patch_size
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        self.facial_region_group = [
            [2, 3],  # eyebrows
            [4, 5],  # eyes
            [6],  # nose
            [7, 8, 9],  # mouth
            [10, 1, 0],  # face boundaries
            [10],  # hair
            [1],  # facial skin
            [0]  # background
        ]  # ['background', 'face', 'rb', 'lb', 're', 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
        self.random_specific_facial_region = check_region[region]

    def __call__(self, image, parsing_map):
        """Build the CRFR-P mask for a batch.

        :param image: passed through unchanged.
        :param parsing_map: (BS, H, W) integer label map.
        :return: {'image': image, 'CRFR_P_mask': (BS, num_patches), 'fr_mask': (BS, num_patches)}
            float 0/1 tensors; fr_mask marks the covered region patches.
        """
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)  # (BS, H/P, W/P)
        facial_region_mask, random_specific_facial_region = \
            self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)
        CRFR_P_mask, facial_region_mask = self.random_variable_masking(
            parsing_map, facial_region_mask, random_specific_facial_region)
        return {'image': image, 'CRFR_P_mask': CRFR_P_mask, 'fr_mask': facial_region_mask}

    def _region_patch_mask(self, parsing_map, region, init=None):
        """Return a (B, H/P, W/P) 0/1 float map of patches touching ``region``.

        ``parsing_map`` is a (B, H, W) label map. The sentinel region [10, 1, 0] means
        "face boundaries": patches containing skin(1) together with hair(10) or
        background(0); ``init`` is ignored for it. For other regions the per-label
        patch maps are max-combined, starting from ``init`` when given.
        """
        if region == [10, 1, 0]:
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) | (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)
            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)
            return (patch_hair_bg.bool() & patch_skin.bool()).float()
        combined = init
        for label in region:
            hit = F.max_pool2d((parsing_map == label).float(), kernel_size=self.patch_size)
            combined = hit if combined is None else torch.maximum(combined, hit)
        return combined

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask):
        """
        :param parsing_map: (BS, img_size, img_size) label map
        :param facial_region_mask: (BS, num_patches ** .5, num_patches ** .5) zero tensor
        :return: facial_region_mask flattened to (BS, num_patches), the region's label list
        """
        region = self.random_specific_facial_region
        facial_region_mask = self._region_patch_mask(parsing_map, region, init=facial_region_mask)
        return facial_region_mask.view(parsing_map.size(0), -1), region

    def random_variable_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        """Mask the other regions proportionally and normalize to the target count.

        :param parsing_map: (BS, H, W) label map.
        :param facial_region_mask: (BS, num_patches) 0/1 mask of the covered region;
            rows are replaced by the trimmed mask when the region exceeded the budget.
        :param random_specific_facial_region: label list of the covered region.
        :return: (CRFR_P_mask, facial_region_mask)
        """
        CRFR_P_mask = facial_region_mask.clone()
        other_facial_region_group = [region for region in self.facial_region_group
                                     if region != random_specific_facial_region]
        # One integer target for every sample. The previous int(ratio * N - sum)
        # truncated toward zero and kept one extra patch whenever the count was over
        # budget, so counts differed across a batch.
        target = int(self.mask_ratio * self.num_patches)
        for i in range(facial_region_mask.size(0)):  # iterate each map in BS
            num_mask_to_change = target - int(facial_region_mask[i].sum().item())
            if num_mask_to_change > 0:
                # Fraction of the still-unmasked patches that must be masked.
                mask_ratio_other_fr = num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1))
                region_map = parsing_map[i].unsqueeze(0)
                # Patches already assigned to some region (prevents double counting).
                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = self._region_patch_mask(region_map, other_fr).view(-1)
                    to_mask_patches = (to_mask_patches - masked_patches) > 0  # ignore assigned patches
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    n_pick = torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()
                    change_indices = torch.randperm(len(select_indices))[:n_pick]
                    CRFR_P_mask[i, select_indices[change_indices]] = 1
                    masked_patches = masked_patches + to_mask_patches.float()
                # Rounding per region can miss the target: mask/unmask patches outside
                # the covered region (their CRFR_P + fr sum is 0 or 1, never 2).
                num_mask_to_change = target - int(CRFR_P_mask[i].sum().item())
                mask_change_to = 1 if num_mask_to_change > 0 else 0
                select_indices = ((CRFR_P_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)
                                  ).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:abs(num_mask_to_change)]
                CRFR_P_mask[i, select_indices[change_indices]] = mask_change_to
            else:
                # The covered region alone meets or exceeds the budget: unmask some of it.
                select_indices = (facial_region_mask[i] == 1).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:abs(num_mask_to_change)]
                CRFR_P_mask[i, select_indices[change_indices]] = 0
                facial_region_mask[i] = CRFR_P_mask[i]
        return CRFR_P_mask, facial_region_mask
def do_CRFR_P_masking(image, parsing_map_vis, ratio, region):
    """Gradio handler for the "CRFR-P (Ours)" button.

    :param image: face crop from the ori_image panel, as an (H, W, 3) array
        (assumed: Gradio's default numpy image value).
    :param parsing_map_vis: colorized parsing map from the parsing_map_vis panel.
    :param ratio: mask ratio from the slider.
    :param region: region name chosen in the radio (a key of check_region).
    :return: (masked patchified parsing map, mask visualization, masked image);
        the covered region is drawn black.
    :raises gr.Error: when no face has been parsed yet.
    """
    # Without a parsed face the conversions below fail with an opaque TypeError.
    if image is None or parsing_map_vis is None or face_to_show.parsing_map is None:
        raise gr.Error('Please upload an image and click "Face Parsing" first.')
    img = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)
    parsing_map = torch.tensor(face_to_show.parsing_map)
    mask_method = CollateFn_CRFR_P_Masking(input_size=224, patch_size=16, mask_ratio=ratio, region=region)
    masks = mask_method(img, parsing_map)
    # Summing gives 2 on covered-region patches, which the renderer highlights.
    combined = masks['CRFR_P_mask'] + masks['fr_mask']
    CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image = show_one_img_parchify_mask(
        image, parsing_map_vis, combined, model)
    return CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image
def vis_parsing_maps(parsing_anno):
    """Colorize a face-parsing label map for display.

    :param parsing_anno: 2-D array of integer labels (cast to uint8); label 0 is
        drawn white and a label above 10 raises IndexError.
    :return: (H, W, 3) uint8 RGB image.
    """
    palette = [[255, 255, 255],
               [0, 0, 255], [255, 128, 0], [255, 255, 0],
               [0, 255, 0], [0, 255, 128],
               [0, 255, 255], [255, 0, 255], [255, 0, 128],
               [128, 0, 255], [255, 0, 0]]
    labels = parsing_anno.copy().astype(np.uint8)
    height, width = labels.shape[0], labels.shape[1]
    colored = np.full((height, width, 3), 255.0)  # start all white
    top_label = np.max(labels)
    for label in range(1, top_label + 1):
        hit = np.nonzero(labels == label)
        colored[hit[0], hit[1], :] = palette[label]
    return colored.astype(np.uint8)
#from facer import facer
import facer
# Lazily created facer models, keyed by device string, so repeated "Face Parsing"
# clicks reuse the loaded networks instead of rebuilding them on every call.
_FACER_MODELS = {}


def _get_facer_models(device):
    """Return ``(face_detector, face_parser)`` for *device*, building them on first use."""
    key = str(device)
    if key not in _FACER_MODELS:
        _FACER_MODELS[key] = (
            facer.face_detector('retinaface/mobilenet', device=device, threshold=0.3),  # 0.3 for FF++
            facer.face_parser('farl/lapa/448', device=device),  # 'farl/lapa/448' parser (11-class logits)
        )
    return _FACER_MODELS[key]


def do_face_parsing(img):
    """Gradio handler for the "Face Parsing" button.

    Crops the face with dlib, resizes it to 224x224, runs facer detection + parsing,
    and stores the crop and its label map in the module-level ``face_to_show``
    (read later by the masking handlers).

    :param img: uploaded PIL image, or None when nothing was uploaded.
    :return: (face crop, colorized parsing map, patchified parsing map); three no-op
        ``gr.update()`` values when facer raises KeyError (presumably: no face found).
    :raises gr.Error: when no image was uploaded or dlib finds no face.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    face_detector, face_parser = _get_facer_models(device)
    img = extract_face(img)
    if img is None:
        raise gr.Error("No face detected in the uploaded image.")
    img = img.resize((224, 224), Image.BICUBIC)
    with torch.inference_mode():
        image = torch.from_numpy(np.array(img.convert('RGB')))
        image = image.unsqueeze(0).permute(0, 3, 1, 2).to(device=device)  # (1, 3, 224, 224)
        try:
            faces = face_detector(image)
            faces = face_parser(image, faces)
            seg_logits = faces['seg']['logits']
        except KeyError:
            # One no-op update per output component leaves all three panels unchanged.
            return gr.update(), gr.update(), gr.update()
        seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
        parsing = seg_probs.argmax(1)  # [1, 224, 224]
        parsing_map = parsing.cpu().numpy().astype(np.int8)  # int64 -> int8 for smaller space
        parsing_map_vis = vis_parsing_maps(parsing_map.squeeze(0))
    face_to_show.image = img
    face_to_show.parsing_map = parsing_map
    return img, parsing_map_vis, show_one_img_patchify(parsing_map_vis, model)
# WebUI
with gr.Blocks() as demo:
    gr.Markdown("# Masking Strategy Demo")
    gr.Markdown(
        "This is a demo for different masking strategies:"
    )
    gr.Markdown(
        "<b>Random Masking</b>: Randomly select patches in the image for masking."
    )
    # Raw string: the backslashes are kept literally so Markdown renders plain braces
    # (a non-raw "\{" is an invalid escape sequence that Python warns about).
    gr.Markdown(
        r"<b>Fasking-I</b>: Use a face parser to divide facial parts into "
        r"\{left-eye, right-eye, nose, mouth, hair, skin, background\} regions, "
        r"and then prioritize non-skin and non-background regions for masking."
    )
    gr.Markdown(
        "<b>FRP</b>: Mask the same proportion of patches in each facial region as the overall masking ratio."
    )
    gr.Markdown(
        "<b>CRFR-R</b>: (1) Cover a Random Facial Region. (2) Randomly mask other patches."
    )
    gr.Markdown(
        "<b>CRFR-P (in FSFM-C3)</b>: (1) Cover a Random Facial Region. (2) Proportionally mask other regions."
    )

    # Face parsing: upload -> crop + parse -> crop, parsing map, patchified parsing map.
    with gr.Column():
        image = gr.Image(label="Upload your image", type="pil")
        image_submit_btn = gr.Button("Face Parsing")
        with gr.Row():
            ori_image = gr.Image(interactive=False)
            parsing_map_vis = gr.Image(interactive=False)
            patch_parsing_map = gr.Image(interactive=False)

    with gr.Column():  # Random
        random_submit_btn = gr.Button("Random Masking")
        ratio_random = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for Random Masking")
        with gr.Row():
            random_patch_on_parsing = gr.Image(interactive=False)
            random_mask = gr.Image(interactive=False)
            random_mask_on_image = gr.Image(interactive=False)

    with gr.Column():  # Fasking-I
        fasking_submit_btn = gr.Button("Fasking-I")
        ratio_fasking = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for Fasking")
        with gr.Row():
            fasking_patch_on_parsing = gr.Image(interactive=False)
            fasking_mask = gr.Image(interactive=False)
            fasking_mask_on_image = gr.Image(interactive=False)

    with gr.Column():  # FRP
        FRP_submit_btn = gr.Button("FRP")
        ratio_FRP = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for FRP")
        with gr.Row():
            FRP_patch_on_parsing = gr.Image(interactive=False)
            FRP_mask = gr.Image(interactive=False)
            FRP_mask_on_image = gr.Image(interactive=False)

    with gr.Column():  # CRFR-R
        CRFR_R_submit_btn = gr.Button("CRFR-R")
        ratio_CRFR_R = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for CRFR-R")
        # Choices come from check_region so the UI and the masking code cannot drift apart.
        mask_region_CRFR_R = gr.Radio(choices=list(check_region), value='Eyes', label="Facial Region")
        with gr.Row():
            CRFR_R_patch_on_parsing = gr.Image(interactive=False)
            CRFR_R_mask = gr.Image(interactive=False)
            CRFR_R_mask_on_image = gr.Image(interactive=False)

    with gr.Column():  # CRFR-P
        CRFR_P_submit_btn = gr.Button("CRFR-P (Ours)")
        ratio_CRFR_P = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Mask Ratio for CRFR-P")
        mask_region_CRFR_P = gr.Radio(choices=list(check_region), value='Eyes', label="Facial Region")
        with gr.Row():
            CRFR_P_patch_on_parsing = gr.Image(interactive=False)
            CRFR_P_mask = gr.Image(interactive=False)
            CRFR_P_mask_on_image = gr.Image(interactive=False)

    # Event wiring (must stay inside the Blocks context).
    image_submit_btn.click(
        fn=do_face_parsing,
        inputs=image,
        outputs=[ori_image, parsing_map_vis, patch_parsing_map],
    )
    random_submit_btn.click(
        fn=do_random_masking,
        inputs=[ori_image, parsing_map_vis, ratio_random],
        outputs=[random_patch_on_parsing, random_mask, random_mask_on_image],
    )
    fasking_submit_btn.click(
        fn=do_fasking_masking,
        inputs=[ori_image, parsing_map_vis, ratio_fasking],
        outputs=[fasking_patch_on_parsing, fasking_mask, fasking_mask_on_image],
    )
    FRP_submit_btn.click(
        fn=do_FRP_masking,
        inputs=[ori_image, parsing_map_vis, ratio_FRP],
        outputs=[FRP_patch_on_parsing, FRP_mask, FRP_mask_on_image],
    )
    CRFR_R_submit_btn.click(
        fn=do_CRFR_R_masking,
        inputs=[ori_image, parsing_map_vis, ratio_CRFR_R, mask_region_CRFR_R],
        outputs=[CRFR_R_patch_on_parsing, CRFR_R_mask, CRFR_R_mask_on_image],
    )
    CRFR_P_submit_btn.click(
        fn=do_CRFR_P_masking,
        inputs=[ori_image, parsing_map_vis, ratio_CRFR_P, mask_region_CRFR_P],
        outputs=[CRFR_P_patch_on_parsing, CRFR_P_mask, CRFR_P_mask_on_image],
    )
if __name__ == "__main__":
    # Shut down other Gradio servers started in this process, enable the request
    # queue (Blocks.queue() returns the Blocks itself), then serve the demo.
    gr.close_all()
    demo.queue().launch()