import sys
import shutil
import tempfile
import torch
import pickle
import cv2
import gradio as gr
import numpy as np
from PIL import Image
from collections import defaultdict
from glob import glob
from matplotlib import pyplot as plt
from matplotlib import animation
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download

sys.path.append("./rome/")
sys.path.append('./DECA')

from rome.infer import Infer
from rome.src.utils.processing import process_black_shape, tensor2image
from rome.src.utils.visuals import mask_errosion

# Download the ROME and MODNet checkpoints from the Hugging Face Hub.
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')

# Configuration passed to rome.infer.Infer (stands in for the CLI parser options).
args = edict({
    "save_dir": ".",
    "save_render": True,
    "model_checkpoint": default_model_path,
    "modnet_path": default_modnet_path,
    "random_seed": 0,
    "debug": False,
    "verbose": False,
    "model_image_size": 256,
    "align_source": True,
    "align_target": False,
    "align_scale": 1.25,
    "use_mesh_deformations": False,
    "subdivide_mesh": False,
    "renderer_sigma": 1e-08,
    "renderer_zfar": 100.0,
    "renderer_type": "soft_mesh",
    "renderer_texture_type": "texture_uv",
    "renderer_normalized_alphas": False,
    "deca_path": "DECA",
    "rome_data_dir": "rome/data",
    "autoenc_cat_alphas": False,
    "autoenc_align_inputs": False,
    "autoenc_use_warp": False,
    "autoenc_num_channels": 64,
    "autoenc_max_channels": 512,
    "autoenc_num_groups": 4,
    "autoenc_num_bottleneck_groups": 0,
    "autoenc_num_blocks": 2,
    "autoenc_num_layers": 4,
    "autoenc_block_type": "bottleneck",
    "neural_texture_channels": 8,
    "num_harmonic_encoding_funcs": 6,
    "unet_num_channels": 64,
    "unet_max_channels": 512,
    "unet_num_groups": 4,
    "unet_num_blocks": 1,
    "unet_num_layers": 2,
    "unet_block_type": "conv",
    "unet_skip_connection_type": "cat",
    "unet_use_normals_cond": True,
    "unet_use_vertex_cond": False,
    "unet_use_uvs_cond": False,
    "unet_pred_mask": False,
    "use_separate_seg_unet": True,
    "norm_layer_type": "gn",
    "activation_type": "relu",
    "conv_layer_type": "ws_conv",
    "deform_norm_layer_type": "gn",
    "deform_activation_type": "relu",
    "deform_conv_layer_type": "ws_conv",
    "unet_seg_weight": 0.0,
    "unet_seg_type": "bce_with_logits",
    "deform_face_tightness": 0.0001,
    "use_whole_segmentation": False,
    "mask_hair_for_neck": False,
    "use_hair_from_avatar": False,
    "use_scalp_deforms": True,
    "use_neck_deforms": True,
    "use_basis_deformer": False,
    "use_unet_deformer": True,
    "pretrained_encoder_basis_path": "",
    "pretrained_vertex_basis_path": "",
    "num_basis": 50,
    "basis_init": "pca",
    "num_vertex": 5023,
    "train_basis": True,
    "path_to_deca": "DECA",
    "path_to_linear_hair_model": "data/linear_hair.pth",  # N/A
    "path_to_mobile_model": "data/disp_model.pth",  # N/A
    "n_scalp": 60,
    "use_distill": False,
    "use_mobile_version": False,
    "deformer_path": "data/rome.pth",
    "output_unet_deformer_feats": 32,
    "use_deca_details": False,
    "use_flametex": False,
    "upsample_type": "nearest",
    "num_frequencies": 6,
    "deform_face_scale_coef": 0.0,
    "device": "cuda"
})

# Download the FLAME and DECA pretrained weights and place them where DECA expects them.
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

# Re-serialise the FLAME model with the current pickle protocol. encoding='latin1'
# presumably lets an old (Python 2) pickle be read.
# NOTE(review): pickle.load runs code embedded in the file; this is acceptable only
# because the file comes from a trusted Hub repo.
with open(generic_model_path, 'rb') as model_file:
    flame_model = pickle.load(model_file, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as model_out:
    pickle.dump(flame_model, model_out)

# Plain byte-for-byte copy of the DECA checkpoint.
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')

# Load the ROME inference model once at startup.
infer = Infer(args)


def image_inference(
        source_img: Image.Image = None,
        driver_img: Image.Image = None
):
    """Animate `source_img` with the pose/expression of `driver_img`.

    Returns one image built by concatenating (along tensor dim 2) the source
    crop, the driver crop, the masked render and the predicted shape image.

    Raises:
        gr.Error: if either input image is missing.
    """
    if source_img is None or driver_img is None:
        raise gr.Error("Please provide both a source image and a driver image.")

    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = out['source_information']['data_dict']
    res = tensor2image(torch.cat([
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ], dim=2))
    # tensor2image output is channel-reversed relative to what Gradio shows
    # (presumably BGR), so flip the last axis.
    return res[..., ::-1]


def extract_frames(driver_vid, step=1):
    """Decode a video file into a list of RGB PIL images.

    Args:
        driver_vid: path to the video file (e.g. an mp4).
        step: keep every `step`-th frame, starting with the first (default: all frames).

    Raises:
        ValueError: if `step` is smaller than 1.
    """
    if step < 1:
        raise ValueError("step must be >= 1")

    image_frames = []
    vid = cv2.VideoCapture(driver_vid)
    try:
        frame_idx = 0
        while True:
            success, img = vid.read()
            if not success:
                break
            if frame_idx % step == 0:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                image_frames.append(Image.fromarray(img))
            frame_idx += 1
    finally:
        # Always release the capture handle, even if decoding fails midway.
        vid.release()
    return image_frames


def video_inference(source_img, driver_vid, frame_step=4):
    """Animate `source_img` with every `frame_step`-th frame of `driver_vid`.

    The source is processed once: its `source_information` is reused for all
    following driver frames. Each predicted frame is masked with the predicted
    U-Net mask, and the frames are written to a temporary mp4.

    Returns:
        Path to the generated mp4 file.

    Raises:
        gr.Error: if an input is missing, the video yields no frames, or the
            output video cannot be created.
    """
    if source_img is None or driver_vid is None:
        raise gr.Error("Please provide both a source image and a driver video.")

    driver_frames = extract_frames(driver_vid, step=frame_step)
    if not driver_frames:
        raise gr.Error("Could not read any frames from the driver video.")

    mask_hard_threshold = 0.5
    renders = []
    source_information = None
    for frame in driver_frames:
        if source_information is None:
            new_out = infer.evaluate(source_img, frame)
        else:
            new_out = infer.evaluate(source_img, frame,
                                     source_information_for_reuse=source_information)
        source_information = new_out.get('source_information')

        mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
        mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
        # Keep the prediction inside the mask and fill everything outside with 1
        # (presumably white).
        render = new_out['pred_target_img'].cpu() * mask_pred + (1 - mask_pred)
        renders.append(tensor2image(render[0]))

    return _write_video(renders, fps=1000 / 30)  # ~30 ms per frame, as in the original animation


def _write_video(frames, fps):
    """Write tensor2image frames to a temporary mp4 and return its path."""
    height, width = frames[0].shape[:2]
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        output_path = tmp.name

    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    if not writer.isOpened():
        raise gr.Error("Could not create the output video.")
    try:
        for frame in frames:
            # Frames are written without a channel flip. image_inference flips
            # tensor2image output to display it, which suggests it is already in
            # OpenCV's BGR order.
            writer.write(np.ascontiguousarray(frame, dtype=np.uint8))
    finally:
        writer.release()
    return output_path


with gr.Blocks() as demo:
    gr.Markdown("# **ROME: Realistic one-shot mesh-based head avatars**")
    gr.Markdown(
        """

Create a personal avatar from just a single image using ROME.
Paper | Project Page | Github

"""
    )

    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil", label="source image", show_label=True)
            driver_img = gr.Image(type="pil", label="driver image", show_label=True)
            image_output = gr.Image()
        image_button = gr.Button("Predict")

    with gr.Tab("Video Inference"):
        with gr.Row():
            source_img2 = gr.Image(type="pil", label="source image", show_label=True)
            driver_vid = gr.Video(label="driver video")
            # video_inference returns an mp4 file path, so the output must be a Video component.
            video_output = gr.Video()
        video_button = gr.Button("Predict")

    gr.Examples(
        examples=[
            ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
            ["./examples/lincoln.jpg", "./examples/taras1.jpg"]
        ],
        inputs=[source_img, driver_img],
        outputs=[image_output],
        fn=image_inference,
        cache_examples=True
    )

    image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
    video_button.click(video_inference, inputs=[source_img2, driver_vid], outputs=video_output)

demo.launch()