import gradio as gr
import numpy as np
import cv2
import os
from PIL import Image
from scipy.interpolate import PchipInterpolator
import torchvision
import time
from tqdm import tqdm
import imageio
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from einops import repeat
from pydub import AudioSegment
from packaging import version
from accelerate.utils import set_seed
from transformers import CLIPVisionModelWithProjection
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.utils.import_utils import is_xformers_available
from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
from pipeline.pipeline import FlowControlNetPipeline
from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo
from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet
from utils.flow_viz import flow_to_image
from utils.utils import split_filename, image2arr, image2pil, ensure_dirname

output_dir = "Output_audio_driven"
ensure_dirname(output_dir)


def draw_landmarks_cv2(image, landmarks):
    for i, point in enumerate(landmarks):
        cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1)
    return image


def sample_optical_flow(A, B, h, w):
    b, l, k, _ = A.shape

    sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device)
    mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device)

    x_coords = A[..., 0].long()
    y_coords = A[..., 1].long()

    x_coords = torch.clip(x_coords, 0, h - 1)
    y_coords = torch.clip(y_coords, 0, w - 1)

    b_idx = torch.arange(b)[:, None, None].repeat(1, l, k)
    l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k)

    sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B
    mask[b_idx, l_idx, x_coords, y_coords] = 1

    mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2)

    return sparse_optical_flow, mask


@torch.no_grad()
def get_sparse_flow(landmarks, h, w, t):
    landmarks = torch.flip(landmarks, dims=[3])
    pose_flow = (landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:]  # forward optical flow
    according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1)
    pose_flow = torch.flip(pose_flow, dims=[3])

    b, t, K, _ = pose_flow.shape
    sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w)

    return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3)


def sample_inputs_face(first_frame, landmarks):
    pc, ph, pw = first_frame.shape
    landmarks = landmarks.unsqueeze(0)
    pl = landmarks.shape[1]

    sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl)

    if ph != 384 or pw != 384:
        first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384))  # [1, 3, 384, 384]
        landmarks_384 = torch.zeros_like(landmarks)
        landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384
        landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384
        sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl)
    else:
        first_frame_384, landmarks_384 = first_frame, landmarks
        sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask

    controlnet_image = first_frame.unsqueeze(0)

    return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384
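
# Illustrative shape sketch only, never called by the demo: under the assumption of a
# hypothetical 25-frame, 68-point landmark track on a 384x384 frame, it shows how
# get_sparse_flow scatters per-frame displacements (relative to frame 0) into a dense
# [b, t-1, 2, h, w] flow volume plus a matching validity mask.
def _sketch_landmark_sparse_flow():
    dummy_landmarks = torch.rand(1, 25, 68, 2) * 383  # [b, t, K, 2], hypothetical pixel coordinates
    flow, mask = get_sparse_flow(dummy_landmarks, 384, 384, 25)
    assert flow.shape == (1, 24, 2, 384, 384)  # one flow map per frame after the first
    assert mask.shape == (1, 24, 2, 384, 384)  # 1 where a landmark seeded the flow, 0 elsewhere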

PARTS = [
    ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)),
    ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)),
    ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)),
    ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)),
    ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)),
    ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)),
    ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)),
    ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)),
    ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)),
    ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)),
    ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)),
    ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)),
    ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)),
    ('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)),
    ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)),
]


def draw_landmarks(keypoints, h, w):
    image = np.zeros((h, w, 3))
    for name, indices, color in PARTS:
        indices = np.array(indices) - 1
        current_part_keypoints = keypoints[indices]
        for i in range(len(indices) - 1):
            x1, y1 = current_part_keypoints[i]
            x2, y2 = current_part_keypoints[i + 1]
            cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2)
    return image


def divide_points_afterinterpolate(resized_all_points, motion_brush_mask):
    k = resized_all_points.shape[0]
    starts = resized_all_points[:, 0]  # [K, 2]

    in_masks = []
    out_masks = []

    for i in range(k):
        x, y = int(starts[i][1]), int(starts[i][0])
        if motion_brush_mask[x][y] == 255:
            in_masks.append(resized_all_points[i])
        else:
            out_masks.append(resized_all_points[i])

    in_masks = np.array(in_masks)
    out_masks = np.array(out_masks)

    return in_masks, out_masks


def get_sparseflow_and_mask_forward(
        resized_all_points,
        n_steps, H, W,
        is_backward_flow=False
):
    K = resized_all_points.shape[0]

    starts = resized_all_points[:, 0]
    interpolated_ends = resized_all_points[:, 1:]

    s_flow = np.zeros((K, n_steps, H, W, 2))
    mask = np.zeros((K, n_steps, H, W))

    for k in range(K):
        for i in range(n_steps):
            start, end = starts[k], interpolated_ends[k][i]
            flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1)
            s_flow[k][i][int(start[1]), int(start[0])] = flow
            mask[k][i][int(start[1]), int(start[0])] = 1

    s_flow = np.sum(s_flow, axis=0)
    mask = np.sum(mask, axis=0)

    return s_flow, mask
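
# Illustrative sketch only, never called by the app: with one hypothetical 16-point
# trajectory it shows how get_sparseflow_and_mask_forward turns an interpolated drag
# into a per-step sparse flow stack seeded at the trajectory's start pixel, plus the
# hit mask marking the seeded locations.
def _sketch_drag_sparse_flow():
    traj = np.stack([np.linspace((10.0, 10.0), (100.0, 60.0), 16)])  # hypothetical [K=1, 16, 2] trajectory
    s_flow, hit_mask = get_sparseflow_and_mask_forward(traj, 15, 384, 384)
    assert s_flow.shape == (15, 384, 384, 2)  # summed over trajectories
    assert hit_mask.shape == (15, 384, 384)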

def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False):

    drag_ckpt = "./ckpts/mofa/traj_controlnet"
    face_ckpt = "./ckpts/mofa/ldmk_controlnet"

    print('start loading models...')

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16"
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16")
    unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        low_cpu_mem_usage=True,
        variant="fp16",
    )

    drag_controlnet = DragControlNet.from_pretrained(drag_ckpt)
    face_controlnet = FaceControlNet.from_pretrained(face_ckpt)

    cmp = CMP_demo(
        './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
        42000
    ).to(device)
    cmp.requires_grad_(False)

    # Freeze vae and image_encoder
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    drag_controlnet.requires_grad_(False)
    face_controlnet.requires_grad_(False)

    # Move image_encoder and vae to gpu and cast to weight_dtype
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)
    drag_controlnet.to(device, dtype=weight_dtype)
    face_controlnet.to(device, dtype=weight_dtype)

    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly")

    if allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    pipeline = FlowControlNetPipeline.from_pretrained(
        pretrained_model_name_or_path,
        unet=unet,
        face_controlnet=face_controlnet,
        drag_controlnet=drag_controlnet,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(device)

    print('models loaded.')

    return pipeline, cmp


def interpolate_trajectory(points, n_points):
    x = [point[0] for point in points]
    y = [point[1] for point in points]

    t = np.linspace(0, 1, len(points))

    fx = PchipInterpolator(t, x)
    fy = PchipInterpolator(t, y)

    new_t = np.linspace(0, 1, n_points)

    new_x = fx(new_t)
    new_y = fy(new_t)

    new_points = list(zip(new_x, new_y))

    return new_points


def visualize_drag_v2(background_image_path, splited_tracks, width, height):
    trajectory_maps = []

    background_image = Image.open(background_image_path).convert('RGBA')
    background_image = background_image.resize((width, height))
    w, h = background_image.size
    transparent_background = np.array(background_image)
    transparent_background[:, :, -1] = 128
    transparent_background = Image.fromarray(transparent_background)

    # Create a transparent layer with the same size as the background image
    transparent_layer = np.zeros((h, w, 4))
    for splited_track in splited_tracks:
        if len(splited_track) > 1:
            splited_track = interpolate_trajectory(splited_track, 16)
            splited_track = splited_track[:16]
            for i in range(len(splited_track) - 1):
                start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
                end_point = (int(splited_track[i + 1][0]), int(splited_track[i + 1][1]))
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(splited_track) - 2:
                    cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
        else:
            cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    trajectory_maps.append(trajectory_map)
    return trajectory_maps, transparent_layer
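
# Illustrative sketch only, not used by the demo: interpolate_trajectory resamples a
# sparse, user-drawn drag to a fixed number of points with monotone PCHIP curves,
# which is how both the hint rendering above and the flow construction in Drag.run
# obtain model_length points per trajectory. The click coordinates are hypothetical.
def _sketch_resample_drag():
    clicks = [(0.0, 0.0), (30.0, 10.0), (60.0, 40.0), (90.0, 45.0)]
    dense = interpolate_trajectory(clicks, 25)
    assert len(dense) == 25  # one (x, y) pair per output frame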

class Drag:
    def __init__(self, device, height, width, model_length):
        self.device = device

        pretrained_model_name_or_path = "/apdcephfs/share_1290939/vg_zoo/huggingface/stable-video-diffusion-img2vid-xt-1-1"

        self.device = 'cuda'
        self.weight_dtype = torch.float16

        self.pipeline, self.cmp = init_models(
            pretrained_model_name_or_path,
            weight_dtype=self.weight_dtype,
            device=self.device,
        )

        self.height = height
        self.width = width
        self.model_length = model_length

    def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
        b, t, c, h, w = frames.shape
        assert h == 384 and w == 384
        frames = frames.flatten(0, 1)  # [b*t, 3, 384, 384]
        sparse_optical_flow = sparse_optical_flow.flatten(0, 1)  # [b*t, 2, 384, 384]
        mask = mask.flatten(0, 1)  # [b*t, 2, 384, 384]
        cmp_flow = []
        for i in range(b * t):
            tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1])  # [1, 2, 384, 384]
            cmp_flow.append(tmp_flow)
        cmp_flow = torch.cat(cmp_flow, dim=0)  # [b*t, 2, 384, 384]
        if brush_mask is not None:
            brush_mask = torch.from_numpy(brush_mask) / 255.
            brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype)
            brush_mask = brush_mask.unsqueeze(0).unsqueeze(0)
            cmp_flow = cmp_flow * brush_mask
        cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
        return cmp_flow

    def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None):

        fb, fl, fc, _, _ = pixel_values_384.shape

        controlnet_flow = self.get_cmp_flow(
            pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1),
            sparse_optical_flow_384,
            mask_384, motion_brush_mask
        )

        if self.height != 384 or self.width != 384:
            scales = [self.height / 384, self.width / 384]
            controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width)
            controlnet_flow[:, :, 0] *= scales[1]
            controlnet_flow[:, :, 1] *= scales[0]

        return controlnet_flow
    @torch.no_grad()
    def forward_sample(self, save_root, first_frame_path, audio_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'):

        seed = 42
        num_frames = self.model_length

        set_seed(seed)

        input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
        input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
        input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0) * 255))

        height, width = input_first_frame.shape[-2:]

        input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3)  # [1, 13, 2, 384, 384]
        mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1)  # [1, 13, 2, 384, 384]
        input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3)  # [1, 13, 2, 384, 384]
        mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1)  # [1, 13, 2, 384, 384]

        input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
        mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
        input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
        mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)

        input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)

        if in_mask_flag:
            flow_inmask = self.get_flow(
                input_first_frame_384,
                input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
            )
        else:
            fb, fl = mask_384_inmask.shape[:2]
            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)

        if out_mask_flag:
            flow_outmask = self.get_flow(
                input_first_frame_384,
                input_drag_384_outmask, mask_384_outmask
            )
        else:
            fb, fl = mask_384_outmask.shape[:2]
            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)

        inmask_no_zero = (flow_inmask != 0).all(dim=2)
        inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)

        controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)

        ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, audio_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render)

        ldmk_flow_len = ldmk_controlnet_flow.shape[1]
        drag_flow_len = controlnet_flow.shape[1]

        repeat_num = ldmk_flow_len // drag_flow_len + 1
        drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1)
        drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len]

        ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0)  # [1, 1, h, w]

        val_output = self.pipeline(
            input_first_frame_pil,
            input_first_frame_pil,
            ldmk_controlnet_flow,
            ldmk_pose_imgs,
            drag_controlnet_flow,
            ldmk_mask_mask_origin,
            height=height,
            width=width,
            num_frames=num_frames,
            decode_chunk_size=8,
            motion_bucket_id=127,
            fps=7,
            noise_aug_strength=0.02,
            ctrl_scale_traj=ctrl_scale_traj,
            ctrl_scale_ldmk=ctrl_scale_ldmk,
        )

        video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow

        for i in range(num_frames):
            img = video_frames[i]
            video_frames[i] = np.array(img)
        video_frames = np.array(video_frames)

        outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow)

        return outputs

    def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()):

        pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

        cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB)

        viz_landmarks = []
        for k in tqdm(range(len(landmarks))):
            im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k])
            viz_landmarks.append(im)
        viz_landmarks = np.stack(viz_landmarks)

        viz_esti_flows = []
        for i in range(estimated_flow.shape[1]):
            temp_flow = estimated_flow[0][i].permute(1, 2, 0)
            viz_esti_flows.append(flow_to_image(temp_flow))
        viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows
        viz_esti_flows = np.stack(viz_esti_flows)  # [t-1, h, w, c]

        viz_drag_flows = []
        for i in range(drag_controlnet_flow.shape[1]):
            temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0)
            viz_drag_flows.append(flow_to_image(temp_flow))
        viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows
        viz_drag_flows = np.stack(viz_drag_flows)  # [t-1, h, w, c]

        out_nps = []
        for plen in range(video_frames.shape[0]):
            out_nps.append(video_frames[plen])
        out_nps = np.stack(out_nps)

        first_frames = np.stack([cv2_firstframe] * out_nps.shape[0])
        hints = np.stack([cv2_hint] * out_nps.shape[0])

        total_nps = np.concatenate([
            first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps
        ], axis=2)

        video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.

        outputs['logits_imgs'] = video_frames_tensor
        outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
        outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
        outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
        outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
        outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.

        return outputs
    @torch.no_grad()
    def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path):

        original_width, original_height = self.width, self.height

        flow_div = self.model_length

        input_all_points = tracking_points.constructor_args['value']

        if len(input_all_points) == 0 or len(input_all_points[-1]) == 1:
            return np.uint8(np.ones((original_width, original_height, 3)) * 255)

        resized_all_points = [tuple([tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)]) for e1 in e]) for e in input_all_points]
        resized_all_points_384 = [tuple([tuple([int(e1[0] * 384 / original_width), int(e1[1] * 384 / original_height)]) for e1 in e]) for e in input_all_points]

        new_resized_all_points = []
        new_resized_all_points_384 = []
        for tnum in range(len(resized_all_points)):
            new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
            new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))

        resized_all_points = np.array(new_resized_all_points)
        resized_all_points_384 = np.array(new_resized_all_points_384)

        motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)

        resized_all_points_384_inmask, resized_all_points_384_outmask = \
            divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)

        in_mask_flag = False
        out_mask_flag = False

        if resized_all_points_384_inmask.shape[0] != 0:
            in_mask_flag = True
            input_drag_384_inmask, input_mask_384_inmask = \
                get_sparseflow_and_mask_forward(
                    resized_all_points_384_inmask,
                    flow_div - 1, 384, 384
                )
        else:
            input_drag_384_inmask, input_mask_384_inmask = \
                np.zeros((flow_div - 1, 384, 384, 2)), \
                np.zeros((flow_div - 1, 384, 384))

        if resized_all_points_384_outmask.shape[0] != 0:
            out_mask_flag = True
            input_drag_384_outmask, input_mask_384_outmask = \
                get_sparseflow_and_mask_forward(
                    resized_all_points_384_outmask,
                    flow_div - 1, 384, 384
                )
        else:
            input_drag_384_outmask, input_mask_384_outmask = \
                np.zeros((flow_div - 1, 384, 384, 2)), \
                np.zeros((flow_div - 1, 384, 384))

        input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device)  # [1, 13, h, w, 2]
        input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device)  # [1, 13, h, w]
        input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device)  # [1, 13, h, w, 2]
        input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device)  # [1, 13, h, w]

        first_frames_transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.ToTensor(),
        ])

        input_first_frame = image2arr(first_frame_path)
        input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device)

        seed = 42
        num_frames = flow_div

        set_seed(seed)

        input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
        input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)

        input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3)  # [1, 13, 2, 384, 384]
        mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1)  # [1, 13, 2, 384, 384]
        input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3)  # [1, 13, 2, 384, 384]
        mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1)  # [1, 13, 2, 384, 384]

        input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
        mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
        input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
        mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)

        input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)

        if in_mask_flag:
            flow_inmask = self.get_flow(
                input_first_frame_384,
                input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
            )
        else:
            fb, fl = mask_384_inmask.shape[:2]
            flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)

        if out_mask_flag:
            flow_outmask = self.get_flow(
                input_first_frame_384,
                input_drag_384_outmask, mask_384_outmask
            )
        else:
            fb, fl = mask_384_outmask.shape[:2]
            flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)

        inmask_no_zero = (flow_inmask != 0).all(dim=2)
        inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)

        controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
        print(controlnet_flow.shape)

        controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0)
        viz_esti_flows = flow_to_image(controlnet_flow)  # [h, w, c]

        return viz_esti_flows

    @torch.no_grad()
    def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask):
        dtype = frames.dtype
        b, t, c, h, w = sparse_optical_flow.shape
        assert h == 384 and w == 384
        frames = frames.flatten(0, 1)  # [b*t, 3, 384, 384]
        sparse_optical_flow = sparse_optical_flow.flatten(0, 1)  # [b*t, 2, 384, 384]
        mask = mask.flatten(0, 1)  # [b*t, 2, 384, 384]
        cmp_flow = []
        for i in range(b * t):
            tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float())  # [1, 2, 384, 384]
            cmp_flow.append(tmp_flow)
        cmp_flow = torch.cat(cmp_flow, dim=0)
        cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
        return cmp_flow.to(dtype=dtype)

    def audio2landmark(self, audio_path, img_path, ldmk_result_dir, ldmk_render=0):
        if ldmk_render == 'sadtalker':
            return_code = os.system(
                f'''
                python sadtalker_audio2pose/inference.py \
                --preprocess full \
                --size 256 \
                --driven_audio {audio_path} \
                --source_image {img_path} \
                --result_dir {ldmk_result_dir} \
                --facerender pirender \
                --verbose \
                --face3dvis
                ''')
            assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report."
        elif ldmk_render == 'aniportrait':
            return_code = os.system(
                f'''
                python aniportrait/audio2ldmk.py \
                --ref_image_path {img_path} \
                --audio_path {audio_path} \
                --save_dir {ldmk_result_dir} \
                '''
            )
            assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report."
        else:
            assert False

        return os.path.join(ldmk_result_dir, 'landmarks.npy')

    def get_landmarks(self, save_root, first_frame_path, audio_path, first_frame, num_frames=25, ldmk_render='sadtalker'):

        ldmk_dir = os.path.join(save_root, 'landmarks')
        ldmknpy_dir = self.audio2landmark(audio_path, first_frame_path, ldmk_dir, ldmk_render)

        landmarks = np.load(ldmknpy_dir)
        landmarks = landmarks[:num_frames]  # [25, 68, 2]

        flow_len = landmarks.shape[0]

        ldmk_clip = landmarks.copy()
        assert ldmk_clip.ndim == 3
        ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320
        ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320

        pose_imgs = []
        for i in range(ldmk_clip.shape[0]):
            pose_img = draw_landmarks(ldmk_clip[i], 320, 320)
            pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST)
            pose_imgs.append(pose_img)
        pose_imgs = np.array(pose_imgs)
        pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255.
        pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device)

        landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device)

        val_controlnet_image, val_sparse_optical_flow, \
            val_mask, val_first_frame_384, \
            val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks)

        fb, fl, fc, fh, fw = val_sparse_optical_flow.shape

        val_controlnet_flow = self.get_cmp_flow_landmarks(
            val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1),
            val_sparse_optical_flow_384,
            val_mask_384
        )

        if fh != 384 or fw != 384:
            scales = [fh / 384, fw / 384]
            val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw)
            val_controlnet_flow[:, :, 0] *= scales[1]
            val_controlnet_flow[:, :, 1] *= scales[0]

        val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1)

        return val_controlnet_flow, pose_imgs, landmarks, flow_len

    def run(self, first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render):

        timestamp = str(time.time()).split('.')[0]
        save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}"
        save_root = os.path.join(os.path.dirname(audio_path), save_name)
        os.makedirs(save_root, exist_ok=True)

        original_width, original_height = self.width, self.height

        flow_div = self.model_length

        input_all_points = tracking_points.constructor_args['value']

        # print(input_all_points)

        resized_all_points = [tuple([tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)]) for e1 in e]) for e in input_all_points]
        resized_all_points_384 = [tuple([tuple([int(e1[0] * 384 / original_width), int(e1[1] * 384 / original_height)]) for e1 in e]) for e in input_all_points]

        new_resized_all_points = []
        new_resized_all_points_384 = []
        for tnum in range(len(resized_all_points)):
            new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div))
            new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div))

        resized_all_points = np.array(new_resized_all_points)
        resized_all_points_384 = np.array(new_resized_all_points_384)

        motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST)
        # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST)

        # motion_brush_mask = torch.from_numpy(motion_brush_mask) / 255.
        # motion_brush_mask = motion_brush_mask.to(self.device)

        ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255.
        ldmk_mask_mask = ldmk_mask_mask.to(self.device)

        if resized_all_points_384.shape[0] != 0:
            resized_all_points_384_inmask, resized_all_points_384_outmask = \
                divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
        else:
            resized_all_points_384_inmask = np.array([])
            resized_all_points_384_outmask = np.array([])

        in_mask_flag = False
        out_mask_flag = False

        if resized_all_points_384_inmask.shape[0] != 0:
            in_mask_flag = True
            input_drag_384_inmask, input_mask_384_inmask = \
                get_sparseflow_and_mask_forward(
                    resized_all_points_384_inmask,
                    flow_div - 1, 384, 384
                )
        else:
            input_drag_384_inmask, input_mask_384_inmask = \
                np.zeros((flow_div - 1, 384, 384, 2)), \
                np.zeros((flow_div - 1, 384, 384))

        if resized_all_points_384_outmask.shape[0] != 0:
            out_mask_flag = True
            input_drag_384_outmask, input_mask_384_outmask = \
                get_sparseflow_and_mask_forward(
                    resized_all_points_384_outmask,
                    flow_div - 1, 384, 384
                )
        else:
            input_drag_384_outmask, input_mask_384_outmask = \
                np.zeros((flow_div - 1, 384, 384, 2)), \
                np.zeros((flow_div - 1, 384, 384))

        input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0)  # [1, 13, h, w, 2]
        input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0)  # [1, 13, h, w]
        input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0)  # [1, 13, h, w, 2]
        input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0)  # [1, 13, h, w]

        dir, base, ext = split_filename(first_frame_path)
        id = base.split('_')[0]

        image_pil = image2pil(first_frame_path)
        image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA')

        visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)

        motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA')
        visualized_drag = visualized_drag[0].convert('RGBA')
        ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA')

        drag_input = Image.alpha_composite(image_pil, visualized_drag)
        motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil)
        visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask)

        first_frames_transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.ToTensor(),
        ])

        hint_path = os.path.join(save_root, f'hint.png')
        visualized_drag_brush_ldmk_mask.save(hint_path)

        first_frames = image2arr(first_frame_path)
        first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device)

        outputs = self.forward_sample(
            save_root,
            first_frame_path,
            audio_path,
            hint_path,
            input_drag_384_inmask.to(self.device),
            input_drag_384_outmask.to(self.device),
            first_frames.to(self.device),
            input_mask_384_inmask.to(self.device),
            input_mask_384_outmask.to(self.device),
            in_mask_flag,
            out_mask_flag,
            motion_brush_mask_384,
            ldmk_mask_mask,
            ctrl_scale_traj,
            ctrl_scale_ldmk,
            ldmk_render=ldmk_render)

        traj_flow_tensor = outputs['traj_flows'][0]  # [25, 3, h, w]
        ldmk_flow_tensor = outputs['ldmk_flows'][0]  # [25, 3, h, w]
        viz_ldmk_tensor = outputs['viz_ldmk'][0]  # [25, 3, h, w]
        out_with_ldmk_tensor = outputs['out_with_ldmk'][0]  # [25, 3, h, w]
        output_tensor = outputs['logits_imgs'][0]  # [25, 3, h, w]
        total_tensor = outputs['total'][0]  # [25, 3, h, w]

        traj_flows_path = os.path.join(save_root, f'traj_flow.gif')
        ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif')
        viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif')
        out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif')
        outputs_path = os.path.join(save_root, f'output.gif')
        total_path = os.path.join(save_root, f'total.gif')

        traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4')
        ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4')
        viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4')
        out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4')
        outputs_path_mp4 = os.path.join(save_root, f'output.mp4')
        total_path_mp4 = os.path.join(save_root, f'total.mp4')

        # print(output_tensor.shape)

        traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
        ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
        viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
        out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
        output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()
        total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy()

        torchvision.io.write_video(
            traj_flows_path_mp4,
            traj_flow_np,
            fps=20, video_codec='h264', options={'crf': '10'}
        )
        torchvision.io.write_video(
            ldmk_flows_path_mp4,
            ldmk_flow_np,
            fps=20, video_codec='h264', options={'crf': '10'}
        )
        torchvision.io.write_video(
            viz_ldmk_path_mp4,
            viz_ldmk_np,
            fps=20, video_codec='h264', options={'crf': '10'}
        )
        torchvision.io.write_video(
            out_with_ldmk_path_mp4,
            out_with_ldmk_np,
            fps=20, video_codec='h264', options={'crf': '10'}
        )
        torchvision.io.write_video(
            outputs_path_mp4,
            output_np,
            fps=20, video_codec='h264', options={'crf': '10'}
        )

        imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0)
        imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0)
        imageio.mimsave(viz_ldmk_path, np.uint8(viz_ldmk_np), fps=20, loop=0)
        imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0)
        imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0)

        torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'})
        imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0)

        return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4


with gr.Blocks() as demo:
    gr.Markdown("""