import gradio as gr import numpy as np import cv2 import os from PIL import Image from scipy.interpolate import PchipInterpolator import torchvision import time from tqdm import tqdm import imageio import torch import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from einops import repeat from pydub import AudioSegment from packaging import version from accelerate.utils import set_seed from transformers import CLIPVisionModelWithProjection from diffusers import AutoencoderKLTemporalDecoder from diffusers.utils.import_utils import is_xformers_available from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel from pipeline.pipeline import FlowControlNetPipeline from models.traj_ctrlnet import FlowControlNet as DragControlNet, CMP_demo from models.ldmk_ctrlnet import FlowControlNet as FaceControlNet from utils.flow_viz import flow_to_image from utils.utils import split_filename, image2arr, image2pil, ensure_dirname output_dir = "Output_audio_driven" ensure_dirname(output_dir) def draw_landmarks_cv2(image, landmarks): for i, point in enumerate(landmarks): cv2.circle(image, (int(point[0]), int(point[1])), 2, (0, 0, 255), -1) return image def sample_optical_flow(A, B, h, w): b, l, k, _ = A.shape sparse_optical_flow = torch.zeros((b, l, h, w, 2), dtype=B.dtype, device=B.device) mask = torch.zeros((b, l, h, w), dtype=torch.uint8, device=B.device) x_coords = A[..., 0].long() y_coords = A[..., 1].long() x_coords = torch.clip(x_coords, 0, h - 1) y_coords = torch.clip(y_coords, 0, w - 1) b_idx = torch.arange(b)[:, None, None].repeat(1, l, k) l_idx = torch.arange(l)[None, :, None].repeat(b, 1, k) sparse_optical_flow[b_idx, l_idx, x_coords, y_coords] = B mask[b_idx, l_idx, x_coords, y_coords] = 1 mask = mask.unsqueeze(-1).repeat(1, 1, 1, 1, 2) return sparse_optical_flow, mask @torch.no_grad() def get_sparse_flow(landmarks, h, w, t): landmarks = torch.flip(landmarks, dims=[3]) pose_flow = 
(landmarks - landmarks[:, 0:1].repeat(1, t, 1, 1))[:, 1:] # 前向光流 according_poses = landmarks[:, 0:1].repeat(1, t - 1, 1, 1) pose_flow = torch.flip(pose_flow, dims=[3]) b, t, K, _ = pose_flow.shape sparse_optical_flow, mask = sample_optical_flow(according_poses, pose_flow, h, w) return sparse_optical_flow.permute(0, 1, 4, 2, 3), mask.permute(0, 1, 4, 2, 3) def sample_inputs_face(first_frame, landmarks): pc, ph, pw = first_frame.shape landmarks = landmarks.unsqueeze(0) pl = landmarks.shape[1] sparse_optical_flow, mask = get_sparse_flow(landmarks, ph, pw, pl) if ph != 384 or pw != 384: first_frame_384 = F.interpolate(first_frame.unsqueeze(0), (384, 384)) # [3, 384, 384] landmarks_384 = torch.zeros_like(landmarks) landmarks_384[:, :, :, 0] = landmarks[:, :, :, 0] / pw * 384 landmarks_384[:, :, :, 1] = landmarks[:, :, :, 1] / ph * 384 sparse_optical_flow_384, mask_384 = get_sparse_flow(landmarks_384, 384, 384, pl) else: first_frame_384, landmarks_384 = first_frame, landmarks sparse_optical_flow_384, mask_384 = sparse_optical_flow, mask controlnet_image = first_frame.unsqueeze(0) return controlnet_image, sparse_optical_flow, mask, first_frame_384, sparse_optical_flow_384, mask_384 PARTS = [ ('FACE', [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], (10, 200, 10)), ('LEFT_EYE', [43, 44, 45, 46, 47, 48, 43], (180, 200, 10)), ('LEFT_EYEBROW', [23, 24, 25, 26, 27], (180, 220, 10)), ('RIGHT_EYE', [37, 38, 39, 40, 41, 42, 37], (10, 200, 180)), ('RIGHT_EYEBROW', [18, 19, 20, 21, 22], (10, 220, 180)), ('NOSE_UP', [28, 29, 30, 31], (10, 200, 250)), ('NOSE_DOWN', [32, 33, 34, 35, 36], (250, 200, 10)), ('LIPS_OUTER_BOTTOM_LEFT', [55, 56, 57, 58], (10, 180, 20)), ('LIPS_OUTER_BOTTOM_RIGHT', [49, 60, 59, 58], (20, 10, 180)), ('LIPS_INNER_BOTTOM_LEFT', [65, 66, 67], (100, 100, 30)), ('LIPS_INNER_BOTTOM_RIGHT', [61, 68, 67], (100, 150, 50)), ('LIPS_OUTER_TOP_LEFT', [52, 53, 54, 55], (20, 80, 100)), ('LIPS_OUTER_TOP_RIGHT', [52, 51, 50, 49], (80, 100, 20)), 
('LIPS_INNER_TOP_LEFT', [63, 64, 65], (120, 100, 200)), ('LIPS_INNER_TOP_RIGHT', [63, 62, 61], (150, 120, 100)), ] def draw_landmarks(keypoints, h, w): image = np.zeros((h, w, 3)) for name, indices, color in PARTS: indices = np.array(indices) - 1 current_part_keypoints = keypoints[indices] for i in range(len(indices) - 1): x1, y1 = current_part_keypoints[i] x2, y2 = current_part_keypoints[i + 1] cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness=2) return image def divide_points_afterinterpolate(resized_all_points, motion_brush_mask): k = resized_all_points.shape[0] starts = resized_all_points[:, 0] # [K, 2] in_masks = [] out_masks = [] for i in range(k): x, y = int(starts[i][1]), int(starts[i][0]) if motion_brush_mask[x][y] == 255: in_masks.append(resized_all_points[i]) else: out_masks.append(resized_all_points[i]) in_masks = np.array(in_masks) out_masks = np.array(out_masks) return in_masks, out_masks def get_sparseflow_and_mask_forward( resized_all_points, n_steps, H, W, is_backward_flow=False ): K = resized_all_points.shape[0] starts = resized_all_points[:, 0] interpolated_ends = resized_all_points[:, 1:] s_flow = np.zeros((K, n_steps, H, W, 2)) mask = np.zeros((K, n_steps, H, W)) for k in range(K): for i in range(n_steps): start, end = starts[k], interpolated_ends[k][i] flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1) s_flow[k][i][int(start[1]), int(start[0])] = flow mask[k][i][int(start[1]), int(start[0])] = 1 s_flow = np.sum(s_flow, axis=0) mask = np.sum(mask, axis=0) return s_flow, mask def init_models(pretrained_model_name_or_path, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False): drag_ckpt = "./ckpts/mofa/traj_controlnet" face_ckpt = "./ckpts/mofa/ldmk_controlnet" print('start loading models...') image_encoder = CLIPVisionModelWithProjection.from_pretrained( pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16" ) vae = 
AutoencoderKLTemporalDecoder.from_pretrained( pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16") unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained( pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, variant="fp16", ) drag_controlnet = DragControlNet.from_pretrained(drag_ckpt) face_controlnet = FaceControlNet.from_pretrained(face_ckpt) cmp = CMP_demo( './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml', 42000 ).to(device) cmp.requires_grad_(False) # Freeze vae and image_encoder vae.requires_grad_(False) image_encoder.requires_grad_(False) unet.requires_grad_(False) drag_controlnet.requires_grad_(False) face_controlnet.requires_grad_(False) # Move image_encoder and vae to gpu and cast to weight_dtype image_encoder.to(device, dtype=weight_dtype) vae.to(device, dtype=weight_dtype) unet.to(device, dtype=weight_dtype) drag_controlnet.to(device, dtype=weight_dtype) face_controlnet.to(device, dtype=weight_dtype) if enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): print( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError( "xformers is not available. 
Make sure it is installed correctly") if allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True pipeline = FlowControlNetPipeline.from_pretrained( pretrained_model_name_or_path, unet=unet, face_controlnet=face_controlnet, drag_controlnet=drag_controlnet, image_encoder=image_encoder, vae=vae, torch_dtype=weight_dtype, ) pipeline = pipeline.to(device) print('models loaded.') return pipeline, cmp def interpolate_trajectory(points, n_points): x = [point[0] for point in points] y = [point[1] for point in points] t = np.linspace(0, 1, len(points)) fx = PchipInterpolator(t, x) fy = PchipInterpolator(t, y) new_t = np.linspace(0, 1, n_points) new_x = fx(new_t) new_y = fy(new_t) new_points = list(zip(new_x, new_y)) return new_points def visualize_drag_v2(background_image_path, splited_tracks, width, height): trajectory_maps = [] background_image = Image.open(background_image_path).convert('RGBA') background_image = background_image.resize((width, height)) w, h = background_image.size transparent_background = np.array(background_image) transparent_background[:, :, -1] = 128 transparent_background = Image.fromarray(transparent_background) # Create a transparent layer with the same size as the background image transparent_layer = np.zeros((h, w, 4)) for splited_track in splited_tracks: if len(splited_track) > 1: splited_track = interpolate_trajectory(splited_track, 16) splited_track = splited_track[:16] for i in range(len(splited_track)-1): start_point = (int(splited_track[i][0]), int(splited_track[i][1])) end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(splited_track)-2: cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length) else: cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2) else: cv2.circle(transparent_layer, (int(splited_track[0][0]), 
int(splited_track[0][1])), 2, (255, 0, 0, 192), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) trajectory_maps.append(trajectory_map) return trajectory_maps, transparent_layer class Drag: def __init__(self, device, height, width, model_length): self.device = device pretrained_model_name_or_path = "/apdcephfs/share_1290939/vg_zoo/huggingface/stable-video-diffusion-img2vid-xt-1-1" self.device = 'cuda' self.weight_dtype = torch.float16 self.pipeline, self.cmp = init_models( pretrained_model_name_or_path, weight_dtype=self.weight_dtype, device=self.device, ) self.height = height self.width = width self.model_length = model_length def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None): b, t, c, h, w = frames.shape assert h == 384 and w == 384 frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] cmp_flow = [] for i in range(b*t): tmp_flow = self.cmp.run(frames[i:i+1], sparse_optical_flow[i:i+1], mask[i:i+1]) # [1, 2, 256, 256] cmp_flow.append(tmp_flow) cmp_flow = torch.cat(cmp_flow, dim=0) # [b*13, 2, 256, 256] if brush_mask is not None: brush_mask = torch.from_numpy(brush_mask) / 255. 
brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype) brush_mask = brush_mask.unsqueeze(0).unsqueeze(0) cmp_flow = cmp_flow * brush_mask cmp_flow = cmp_flow.reshape(b, t, 2, h, w) return cmp_flow def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None): fb, fl, fc, _, _ = pixel_values_384.shape controlnet_flow = self.get_cmp_flow( pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1), sparse_optical_flow_384, mask_384, motion_brush_mask ) if self.height != 384 or self.width != 384: scales = [self.height / 384, self.width / 384] controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width) controlnet_flow[:, :, 0] *= scales[1] controlnet_flow[:, :, 1] *= scales[0] return controlnet_flow @torch.no_grad() def forward_sample(self, save_root, first_frame_path, audio_path, hint_path, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask_384=None, ldmk_mask_mask_origin=None, ctrl_scale_traj=1., ctrl_scale_ldmk=1., ldmk_render='sadtalker'): seed = 42 num_frames = self.model_length set_seed(seed) input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255)) height, width = input_first_frame.shape[-2:] input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] input_drag_384_inmask = 
input_drag_384_inmask.to(self.device, dtype=self.weight_dtype) mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) if in_mask_flag: flow_inmask = self.get_flow( input_first_frame_384, input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 ) else: fb, fl = mask_384_inmask.shape[:2] flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) if out_mask_flag: flow_outmask = self.get_flow( input_first_frame_384, input_drag_384_outmask, mask_384_outmask ) else: fb, fl = mask_384_outmask.shape[:2] flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) inmask_no_zero = (flow_inmask != 0).all(dim=2) inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) ldmk_controlnet_flow, ldmk_pose_imgs, landmarks, num_frames = self.get_landmarks(save_root, first_frame_path, audio_path, input_first_frame[0], self.model_length, ldmk_render=ldmk_render) ldmk_flow_len = ldmk_controlnet_flow.shape[1] drag_flow_len = controlnet_flow.shape[1] repeat_num = ldmk_flow_len // drag_flow_len + 1 drag_controlnet_flow = controlnet_flow.repeat(1, repeat_num, 1, 1, 1) drag_controlnet_flow = drag_controlnet_flow[:, :ldmk_flow_len] ldmk_mask_mask_origin = ldmk_mask_mask_origin.unsqueeze(0).unsqueeze(0) # [1, 1, h, w] val_output = self.pipeline( input_first_frame_pil, input_first_frame_pil, ldmk_controlnet_flow, ldmk_pose_imgs, drag_controlnet_flow, ldmk_mask_mask_origin, height=height, width=width, num_frames=num_frames, decode_chunk_size=8, motion_bucket_id=127, fps=7, noise_aug_strength=0.02, ctrl_scale_traj=ctrl_scale_traj, 
ctrl_scale_ldmk=ctrl_scale_ldmk, ) video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow for i in range(num_frames): img = video_frames[i] video_frames[i] = np.array(img) video_frames = np.array(video_frames) outputs = self.save_video(ldmk_pose_imgs, first_frame_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow) return outputs def save_video(self, pose_imgs, image_path, hint_path, landmarks, video_frames, estimated_flow, drag_controlnet_flow, outputs=dict()): pose_img_nps = (pose_imgs[0].permute(0, 2, 3, 1).cpu().numpy()*255).astype(np.uint8) cv2_firstframe = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) cv2_hint = cv2.cvtColor(cv2.imread(hint_path), cv2.COLOR_BGR2RGB) viz_landmarks = [] for k in tqdm(range(len(landmarks))): im = draw_landmarks_cv2(video_frames[k].copy(), landmarks[k]) viz_landmarks.append(im) viz_landmarks = np.stack(viz_landmarks) viz_esti_flows = [] for i in range(estimated_flow.shape[1]): temp_flow = estimated_flow[0][i].permute(1, 2, 0) viz_esti_flows.append(flow_to_image(temp_flow)) viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c] viz_drag_flows = [] for i in range(drag_controlnet_flow.shape[1]): temp_flow = drag_controlnet_flow[0][i].permute(1, 2, 0) viz_drag_flows.append(flow_to_image(temp_flow)) viz_drag_flows = [np.uint8(np.ones_like(viz_drag_flows[-1]) * 255)] + viz_drag_flows viz_drag_flows = np.stack(viz_drag_flows) # [t-1, h, w, c] out_nps = [] for plen in range(video_frames.shape[0]): out_nps.append(video_frames[plen]) out_nps = np.stack(out_nps) first_frames = np.stack([cv2_firstframe] * out_nps.shape[0]) hints = np.stack([cv2_hint] * out_nps.shape[0]) total_nps = np.concatenate([ first_frames, hints, viz_drag_flows, viz_esti_flows, pose_img_nps, viz_landmarks, out_nps ], axis=2) video_frames_tensor = torch.from_numpy(video_frames).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 
255. outputs['logits_imgs'] = video_frames_tensor outputs['traj_flows'] = torch.from_numpy(viz_drag_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. outputs['ldmk_flows'] = torch.from_numpy(viz_esti_flows).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. outputs['viz_ldmk'] = torch.from_numpy(pose_img_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. outputs['out_with_ldmk'] = torch.from_numpy(viz_landmarks).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. outputs['total'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255. return outputs @torch.no_grad() def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path): original_width, original_height = self.width, self.height flow_div = self.model_length input_all_points = tracking_points.constructor_args['value'] if len(input_all_points) == 0 or len(input_all_points[-1]) == 1: return np.uint8(np.ones((original_width, original_height, 3))*255) resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] new_resized_all_points = [] new_resized_all_points_384 = [] for tnum in range(len(resized_all_points)): new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) resized_all_points = np.array(new_resized_all_points) resized_all_points_384 = np.array(new_resized_all_points_384) motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) resized_all_points_384_inmask, resized_all_points_384_outmask = \ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) in_mask_flag = False out_mask_flag = False if resized_all_points_384_inmask.shape[0] != 0: 
in_mask_flag = True input_drag_384_inmask, input_mask_384_inmask = \ get_sparseflow_and_mask_forward( resized_all_points_384_inmask, flow_div - 1, 384, 384 ) else: input_drag_384_inmask, input_mask_384_inmask = \ np.zeros((flow_div - 1, 384, 384, 2)), \ np.zeros((flow_div - 1, 384, 384)) if resized_all_points_384_outmask.shape[0] != 0: out_mask_flag = True input_drag_384_outmask, input_mask_384_outmask = \ get_sparseflow_and_mask_forward( resized_all_points_384_outmask, flow_div - 1, 384, 384 ) else: input_drag_384_outmask, input_mask_384_outmask = \ np.zeros((flow_div - 1, 384, 384, 2)), \ np.zeros((flow_div - 1, 384, 384)) input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w] input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2] input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w] first_frames_transform = transforms.Compose([ lambda x: Image.fromarray(x), transforms.ToTensor(), ]) input_first_frame = image2arr(first_frame_path) input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device) seed = 42 num_frames = flow_div set_seed(seed) input_first_frame_384 = F.interpolate(input_first_frame, (384, 384)) input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0) input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384] mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384] input_drag_384_inmask = input_drag_384_inmask.to(self.device, 
dtype=self.weight_dtype) mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype) input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype) mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype) input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype) if in_mask_flag: flow_inmask = self.get_flow( input_first_frame_384, input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384 ) else: fb, fl = mask_384_inmask.shape[:2] flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) if out_mask_flag: flow_outmask = self.get_flow( input_first_frame_384, input_drag_384_outmask, mask_384_outmask ) else: fb, fl = mask_384_outmask.shape[:2] flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype) inmask_no_zero = (flow_inmask != 0).all(dim=2) inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask) controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask) print(controlnet_flow.shape) controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0) viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c] return viz_esti_flows @torch.no_grad() def get_cmp_flow_landmarks(self, frames, sparse_optical_flow, mask): dtype = frames.dtype b, t, c, h, w = sparse_optical_flow.shape assert h == 384 and w == 384 frames = frames.flatten(0, 1) # [b*13, 3, 256, 256] sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 256, 256] mask = mask.flatten(0, 1) # [b*13, 2, 256, 256] cmp_flow = [] for i in range(b*t): tmp_flow = self.cmp.run(frames[i:i+1].float(), sparse_optical_flow[i:i+1].float(), mask[i:i+1].float()) # [b*13, 2, 256, 256] cmp_flow.append(tmp_flow) cmp_flow = torch.cat(cmp_flow, dim=0) cmp_flow = cmp_flow.reshape(b, t, 2, h, w) return cmp_flow.to(dtype=dtype) def audio2landmark(self, audio_path, img_path, ldmk_result_dir, ldmk_render=0): if ldmk_render 
== 'sadtalker': return_code = os.system( f''' python sadtalker_audio2pose/inference.py \ --preprocess full \ --size 256 \ --driven_audio {audio_path} \ --source_image {img_path} \ --result_dir {ldmk_result_dir} \ --facerender pirender \ --verbose \ --face3dvis ''') assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report." elif ldmk_render == 'aniportrait': return_code = os.system( f''' python aniportrait/audio2ldmk.py \ --ref_image_path {img_path} \ --audio_path {audio_path} \ --save_dir {ldmk_result_dir} \ ''' ) assert return_code == 0, "Errors in generating landmarks! Please trace back up for detailed error report." else: assert False return os.path.join(ldmk_result_dir, 'landmarks.npy') def get_landmarks(self, save_root, first_frame_path, audio_path, first_frame, num_frames=25, ldmk_render='sadtalker'): ldmk_dir = os.path.join(save_root, 'landmarks') ldmknpy_dir = self.audio2landmark(audio_path, first_frame_path, ldmk_dir, ldmk_render) landmarks = np.load(ldmknpy_dir) landmarks = landmarks[:num_frames] # [25, 68, 2] flow_len = landmarks.shape[0] ldmk_clip = landmarks.copy() assert ldmk_clip.ndim == 3 ldmk_clip[:, :, 0] = ldmk_clip[:, :, 0] / self.width * 320 ldmk_clip[:, :, 1] = ldmk_clip[:, :, 1] / self.height * 320 pose_imgs = [] for i in range(ldmk_clip.shape[0]): pose_img = draw_landmarks(ldmk_clip[i], 320, 320) pose_img = cv2.resize(pose_img, (self.width, self.height), cv2.INTER_NEAREST) pose_imgs.append(pose_img) pose_imgs = np.array(pose_imgs) pose_imgs = torch.from_numpy(pose_imgs).permute(0, 3, 1, 2).float() / 255. 
pose_imgs = pose_imgs.unsqueeze(0).to(self.weight_dtype).to(self.device) landmarks = torch.from_numpy(landmarks).to(self.weight_dtype).to(self.device) val_controlnet_image, val_sparse_optical_flow, \ val_mask, val_first_frame_384, \ val_sparse_optical_flow_384, val_mask_384 = sample_inputs_face(first_frame, landmarks) fb, fl, fc, fh, fw = val_sparse_optical_flow.shape val_controlnet_flow = self.get_cmp_flow_landmarks( val_first_frame_384.unsqueeze(0).repeat(1, fl, 1, 1, 1), val_sparse_optical_flow_384, val_mask_384 ) if fh != 384 or fw != 384: scales = [fh / 384, fw / 384] val_controlnet_flow = F.interpolate(val_controlnet_flow.flatten(0, 1), (fh, fw), mode='nearest').reshape(fb, fl, 2, fh, fw) val_controlnet_flow[:, :, 0] *= scales[1] val_controlnet_flow[:, :, 1] *= scales[0] val_controlnet_image = val_controlnet_image.unsqueeze(0).repeat(1, fl, 1, 1, 1) return val_controlnet_flow, pose_imgs, landmarks, flow_len def run(self, first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render): timestamp = str(time.time()).split('.')[0] save_name = f"trajscale{ctrl_scale_traj}_ldmkscale{ctrl_scale_ldmk}_{ldmk_render}_ts{timestamp}" save_root = os.path.join(os.path.dirname(audio_path), save_name) os.makedirs(save_root, exist_ok=True) original_width, original_height = self.width, self.height flow_div = self.model_length input_all_points = tracking_points.constructor_args['value'] # print(input_all_points) resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points] resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points] new_resized_all_points = [] new_resized_all_points_384 = [] for tnum in range(len(resized_all_points)): 
new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], flow_div)) new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], flow_div)) resized_all_points = np.array(new_resized_all_points) resized_all_points_384 = np.array(new_resized_all_points_384) motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), cv2.INTER_NEAREST) # ldmk_mask_mask_384 = cv2.resize(ldmk_mask_mask, (384, 384), cv2.INTER_NEAREST) # motion_brush_mask = torch.from_numpy(motion_brush_mask) / 255. # motion_brush_mask = motion_brush_mask.to(self.device) ldmk_mask_mask = torch.from_numpy(ldmk_mask_mask) / 255. ldmk_mask_mask = ldmk_mask_mask.to(self.device) if resized_all_points_384.shape[0] != 0: resized_all_points_384_inmask, resized_all_points_384_outmask = \ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384) else: resized_all_points_384_inmask = np.array([]) resized_all_points_384_outmask = np.array([]) in_mask_flag = False out_mask_flag = False if resized_all_points_384_inmask.shape[0] != 0: in_mask_flag = True input_drag_384_inmask, input_mask_384_inmask = \ get_sparseflow_and_mask_forward( resized_all_points_384_inmask, flow_div - 1, 384, 384 ) else: input_drag_384_inmask, input_mask_384_inmask = \ np.zeros((flow_div - 1, 384, 384, 2)), \ np.zeros((flow_div - 1, 384, 384)) if resized_all_points_384_outmask.shape[0] != 0: out_mask_flag = True input_drag_384_outmask, input_mask_384_outmask = \ get_sparseflow_and_mask_forward( resized_all_points_384_outmask, flow_div - 1, 384, 384 ) else: input_drag_384_outmask, input_mask_384_outmask = \ np.zeros((flow_div - 1, 384, 384, 2)), \ np.zeros((flow_div - 1, 384, 384)) input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2] input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w] input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2] 
input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w] dir, base, ext = split_filename(first_frame_path) id = base.split('_')[0] image_pil = image2pil(first_frame_path) image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGBA') visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height) motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA') visualized_drag = visualized_drag[0].convert('RGBA') ldmk_mask_viz_pil = Image.fromarray(ldmk_mask_viz.astype(np.uint8)).convert('RGBA') drag_input = Image.alpha_composite(image_pil, visualized_drag) motionbrush_ldmkmask = Image.alpha_composite(motion_brush_viz_pil, ldmk_mask_viz_pil) visualized_drag_brush_ldmk_mask = Image.alpha_composite(drag_input, motionbrush_ldmkmask) first_frames_transform = transforms.Compose([ lambda x: Image.fromarray(x), transforms.ToTensor(), ]) hint_path = os.path.join(save_root, f'hint.png') visualized_drag_brush_ldmk_mask.save(hint_path) first_frames = image2arr(first_frame_path) first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=1).to(self.device) outputs = self.forward_sample( save_root, first_frame_path, audio_path, hint_path, input_drag_384_inmask.to(self.device), input_drag_384_outmask.to(self.device), first_frames.to(self.device), input_mask_384_inmask.to(self.device), input_mask_384_outmask.to(self.device), in_mask_flag, out_mask_flag, motion_brush_mask_384, ldmk_mask_mask, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render=ldmk_render) traj_flow_tensor = outputs['traj_flows'][0] # [25, 3, h, w] ldmk_flow_tensor = outputs['ldmk_flows'][0] # [25, 3, h, w] viz_ldmk_tensor = outputs['viz_ldmk'][0] # [25, 3, h, w] out_with_ldmk_tensor = outputs['out_with_ldmk'][0] # [25, 3, h, w] output_tensor = outputs['logits_imgs'][0] # [25, 3, h, w] total_tensor = outputs['total'][0] # [25, 3, h, w] traj_flows_path = 
os.path.join(save_root, f'traj_flow.gif') ldmk_flows_path = os.path.join(save_root, f'ldmk_flow.gif') viz_ldmk_path = os.path.join(save_root, f'viz_ldmk.gif') out_with_ldmk_path = os.path.join(save_root, f'output_w_ldmk.gif') outputs_path = os.path.join(save_root, f'output.gif') total_path = os.path.join(save_root, f'total.gif') traj_flows_path_mp4 = os.path.join(save_root, f'traj_flow.mp4') ldmk_flows_path_mp4 = os.path.join(save_root, f'ldmk_flow.mp4') viz_ldmk_path_mp4 = os.path.join(save_root, f'viz_ldmk.mp4') out_with_ldmk_path_mp4 = os.path.join(save_root, f'output_w_ldmk.mp4') outputs_path_mp4 = os.path.join(save_root, f'output.mp4') total_path_mp4 = os.path.join(save_root, f'total.mp4') # print(output_tensor.shape) traj_flow_np = traj_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() ldmk_flow_np = ldmk_flow_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() viz_ldmk_np = viz_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() out_with_ldmk_np = out_with_ldmk_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() output_np = output_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() total_np = total_tensor.permute(0, 2, 3, 1).clamp(0, 1).mul(255).cpu().numpy() torchvision.io.write_video( traj_flows_path_mp4, traj_flow_np, fps=20, video_codec='h264', options={'crf': '10'} ) torchvision.io.write_video( ldmk_flows_path_mp4, ldmk_flow_np, fps=20, video_codec='h264', options={'crf': '10'} ) torchvision.io.write_video( viz_ldmk_path_mp4, viz_ldmk_np, fps=20, video_codec='h264', options={'crf': '10'} ) torchvision.io.write_video( out_with_ldmk_path_mp4, out_with_ldmk_np, fps=20, video_codec='h264', options={'crf': '10'} ) torchvision.io.write_video( outputs_path_mp4, output_np, fps=20, video_codec='h264', options={'crf': '10'} ) imageio.mimsave(traj_flows_path, np.uint8(traj_flow_np), fps=20, loop=0) imageio.mimsave(ldmk_flows_path, np.uint8(ldmk_flow_np), fps=20, loop=0) imageio.mimsave(viz_ldmk_path, 
np.uint8(viz_ldmk_np), fps=20, loop=0) imageio.mimsave(out_with_ldmk_path, np.uint8(out_with_ldmk_np), fps=20, loop=0) imageio.mimsave(outputs_path, np.uint8(output_np), fps=20, loop=0) torchvision.io.write_video(total_path_mp4, total_np, fps=20, video_codec='h264', options={'crf': '10'}) imageio.mimsave(total_path, np.uint8(total_np), fps=20, loop=0) return hint_path, traj_flows_path, ldmk_flows_path, viz_ldmk_path, outputs_path, traj_flows_path_mp4, ldmk_flows_path_mp4, viz_ldmk_path_mp4, outputs_path_mp4 with gr.Blocks() as demo: gr.Markdown("""

MOFA-Video


""") gr.Markdown("""

Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.

""") gr.Markdown( """

1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.

2. Proceed to trajectory control:

2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image.
2.2. After adding each trajectory, an optical flow image will be displayed automatically in "Temporary Trajectory Flow Visualization". Use it as a reference to adjust the trajectory for desired effects (e.g., area, intensity).
2.3. To delete the latest trajectory, click "Delete Last Trajectory."
2.4. To restrict the area controlled by a trajectory, click on the "Add Motion Brush Here" image to paint motion-brush masks. The motion brush restricts the optical flow derived from any trajectory whose starting point lies inside the brushed area, and the displayed optical flow image updates accordingly. Adjust the motion brush radius using the "Motion Brush Radius" slider.
2.5. Set the control scale for trajectories using the "Control Scale for Trajectory" slider. This determines the control intensity of the trajectories. Setting it to 0 means no control (the pure generation result of SVD itself), while setting it to 1 gives the strongest control (which usually does not lead to good results because of twisting artifacts). A preset value of 0.6 is recommended for most cases.

3. Proceed to landmark control from audio:

3.1. Use the "Upload Audio" button to upload an audio file (.wav and .mp3 files are currently supported).
3.2. Click on the "Add Landmark Mask Here" image to add masks. This mask restricts the optical flow derived from the landmarks; it should usually cover the person's head and, if desired, body parts, so that the body moves naturally instead of staying stationary. Adjust the landmark brush radius using the "Landmark Brush Radius" slider.
3.3. Set the control scale for landmarks using the "Control Scale for Landmark" slider. This determines the control intensity of the landmarks. Unlike trajectory control, a preset value of 1 is recommended for most cases.
3.4. Choose the landmark renderer used to generate landmark sequences from the input audio. The landmark generation code is based on either SadTalker or AniPortrait. We empirically find that SadTalker produces landmarks that follow the audio more precisely around the lips, while AniPortrait produces more pronounced lip movement. Note that while pure landmark-based control of MOFA-Video supports long video generation via the periodic sampling strategy, the current version of hybrid control only supports short video generation (25 frames), so only the first 25 frames of the generated landmark sequence are used to obtain the result.

4. Click the "Run" button to animate the image according to the trajectory and the landmark.

""" ) target_size = 512 # NOTICE: changing to lower resolution may impair the performance of the model. DragNUWA_net = Drag("cuda:0", target_size, target_size, 25) first_frame_path = gr.State() audio_path = gr.State() tracking_points = gr.State([]) motion_brush_points = gr.State([]) motion_brush_mask = gr.State() motion_brush_viz = gr.State() ldmk_mask_mask = gr.State() ldmk_mask_viz = gr.State() def preprocess_image(image): image_pil = image2pil(image.name) raw_w, raw_h = image_pil.size max_edge = min(raw_w, raw_h) resize_ratio = target_size / max_edge image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR) new_w, new_h = image_pil.size crop_w = new_w - (new_w % 64) crop_h = new_h - (new_h % 64) image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB')) DragNUWA_net.width = crop_w DragNUWA_net.height = crop_h id = str(time.time()).split('.')[0] os.makedirs(os.path.join(output_dir, str(id)), exist_ok=True) first_frame_path = os.path.join(output_dir, str(id), f"input.png") image_pil.save(first_frame_path) return first_frame_path, first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4)) def convert_audio_to_wav(input_audio_file, output_wav_file): extension = os.path.splitext(os.path.basename(input_audio_file))[-1] if extension.lower() == ".mp3": audio = AudioSegment.from_mp3(input_audio_file) elif extension.lower() == ".wav": audio = AudioSegment.from_wav(input_audio_file) elif extension.lower() == ".ogg": audio = AudioSegment.from_ogg(input_audio_file) elif extension.lower() == ".flac": audio = AudioSegment.from_file(input_audio_file, "flac") else: raise ValueError(f"Not supported extension: {extension}") audio.export(output_wav_file, format="wav") def save_audio(audio, first_frame_path): assert first_frame_path is not None, "First upload image, then audio!" 
img_basedir = os.path.dirname(first_frame_path) id = str(time.time()).split('.')[0] audio_path = os.path.join(img_basedir, f'audio_{str(id)}', 'audio.wav') os.makedirs(os.path.dirname(audio_path), exist_ok=True) # os.system(f'cp -r {audio.name} {audio_path}') convert_audio_to_wav(audio.name, audio_path) return audio_path, audio_path def add_drag(tracking_points): if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []: return tracking_points tracking_points.constructor_args['value'].append([]) return tracking_points def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask): if len(tracking_points.constructor_args['value']) > 0: tracking_points.constructor_args['value'].pop() transparent_background = Image.open(first_frame_path).convert('RGBA') w, h = transparent_background.size transparent_layer = np.zeros((h, w, 4)) for track in tracking_points.constructor_args['value']: if len(track) > 1: for i in range(len(track)-1): start_point = track[i] end_point = track[i+1] vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(track)-2: cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) else: cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) else: cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) return tracking_points, trajectory_map, viz_flow def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData): transparent_background = 
Image.open(first_frame_path).convert('RGBA') w, h = transparent_background.size motion_points = motion_brush_points.constructor_args['value'] motion_points.append(evt.index) x, y = evt.index cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) cv2.circle(transparent_layer, (x, y), radius, (128, 0, 128, 127), -1) transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) return motion_brush_mask, transparent_layer, motion_map, viz_flow def add_ldmk_mask(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, evt: gr.SelectData): transparent_background = Image.open(first_frame_path).convert('RGBA') w, h = transparent_background.size motion_points = motion_brush_points.constructor_args['value'] motion_points.append(evt.index) x, y = evt.index cv2.circle(motion_brush_mask, (x, y), radius, 255, -1) cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 127), -1) transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8)) motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil) return motion_brush_mask, transparent_layer, motion_map def add_tracking_points(tracking_points, first_frame_path, motion_brush_mask, evt: gr.SelectData): # SelectData is a subclass of EventData print(f"You selected {evt.value} at {evt.index} from {evt.target}") if len(tracking_points.constructor_args['value']) == 0: tracking_points.constructor_args['value'].append([]) tracking_points.constructor_args['value'][-1].append(evt.index) print(tracking_points.constructor_args['value']) transparent_background = Image.open(first_frame_path).convert('RGBA') w, h = transparent_background.size transparent_layer = np.zeros((h, w, 4)) for track in tracking_points.constructor_args['value']: if len(track) > 1: for i in 
range(len(track)-1): start_point = track[i] end_point = track[i+1] vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(track)-2: cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length) else: cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,) else: cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path) return tracking_points, trajectory_map, viz_flow with gr.Row(): with gr.Column(scale=3): image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"]) audio_upload_button = gr.UploadButton(label="Upload Audio", file_types=["audio"]) input_audio = gr.Audio(label="Audio") with gr.Column(scale=3): add_drag_button = gr.Button(value="Add Trajectory") delete_last_drag_button = gr.Button(value="Delete Last Trajectory") run_button = gr.Button(value="Run") with gr.Column(scale=3): motion_brush_radius = gr.Slider(label='Motion Brush Radius', minimum=1, maximum=200, step=1, value=10) ldmk_mask_radius = gr.Slider(label='Landmark Brush Radius', minimum=1, maximum=200, step=1, value=10) with gr.Column(scale=3): ctrl_scale_traj = gr.Slider(label='Control Scale for Trajectory', minimum=0, maximum=1., step=0.01, value=0.6) ctrl_scale_ldmk = gr.Slider(label='Control Scale for Landmark', minimum=0, maximum=1., step=0.01, value=1.) 
ldmk_render = gr.Radio(label='Landmark Renderer', choices=['sadtalker', 'aniportrait'], value='aniportrait') with gr.Column(scale=4): input_image = gr.Image(label="Add Trajectory Here", interactive=True) with gr.Column(scale=4): motion_brush_image = gr.Image(label="Add Motion Brush Here", interactive=True) with gr.Column(scale=4): ldmk_mask_image = gr.Image(label="Add Landmark Mask Here", interactive=True) with gr.Row(): with gr.Column(scale=6): viz_flow = gr.Image(label="Temporary Trajectory Flow Visualization") with gr.Column(scale=6): hint_image = gr.Image(label="Final Hint Image") with gr.Row(): with gr.Column(scale=6): traj_flows_gif = gr.Image(label="Trajectory Flow GIF") with gr.Column(scale=6): ldmk_flows_gif = gr.Image(label="Landmark Flow GIF") with gr.Row(): with gr.Column(scale=6): viz_ldmk_gif = gr.Image(label="Landmark Visualization GIF") with gr.Column(scale=6): outputs_gif = gr.Image(label="Output GIF") with gr.Row(): with gr.Column(scale=6): traj_flows_mp4 = gr.Video(label="Trajectory Flow MP4") with gr.Column(scale=6): ldmk_flows_mp4 = gr.Video(label="Landmark Flow MP4") with gr.Row(): with gr.Column(scale=6): viz_ldmk_mp4 = gr.Video(label="Landmark Visualization MP4") with gr.Column(scale=6): outputs_mp4 = gr.Video(label="Output MP4") image_upload_button.upload(preprocess_image, image_upload_button, [input_image, motion_brush_image, ldmk_mask_image, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz]) audio_upload_button.upload(save_audio, [audio_upload_button, first_frame_path], [input_audio, audio_path]) add_drag_button.click(add_drag, tracking_points, tracking_points) delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow]) 
motion_brush_image.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, motion_brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, motion_brush_image, viz_flow]) ldmk_mask_image.select(add_ldmk_mask, [motion_brush_points, ldmk_mask_mask, ldmk_mask_viz, first_frame_path, ldmk_mask_radius], [ldmk_mask_mask, ldmk_mask_viz, ldmk_mask_image]) run_button.click(DragNUWA_net.run, [first_frame_path, audio_path, tracking_points, motion_brush_mask, motion_brush_viz, ldmk_mask_mask, ldmk_mask_viz, ctrl_scale_traj, ctrl_scale_ldmk, ldmk_render], [hint_image, traj_flows_gif, ldmk_flows_gif, viz_ldmk_gif, outputs_gif, traj_flows_mp4, ldmk_flows_mp4, viz_ldmk_mp4, outputs_mp4]) # demo.launch(server_name="0.0.0.0", debug=True, server_port=80) demo.launch(server_name="127.0.0.1", debug=True, server_port=9080)