import sys
import torch
import pickle
import cv2
import gradio as gr
import numpy as np
from PIL import Image
from collections import defaultdict
from glob import glob
from matplotlib import pyplot as plt
from matplotlib import animation
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download

# Make the bundled ROME and DECA sources importable
sys.path.append("./rome/")
sys.path.append('./DECA')

from rome.infer import Infer
from rome.src.utils.processing import process_black_shape, tensor2image
from rome.src.utils.visuals import mask_errosion

# Download the pretrained MODNet (portrait matting) and ROME checkpoints from the Hub
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')

# Inference configuration (stands in for ROME's command-line argument parser)
args = edict({
    "save_dir": ".",
    "save_render": True,
    "model_checkpoint": default_model_path,
    "modnet_path": default_modnet_path,
    "random_seed": 0,
    "debug": False,
    "verbose": False,
    "model_image_size": 256,
    "align_source": True,
    "align_target": False,
    "align_scale": 1.25,
    "use_mesh_deformations": False,
    "subdivide_mesh": False,
    "renderer_sigma": 1e-08,
    "renderer_zfar": 100.0,
    "renderer_type": "soft_mesh",
    "renderer_texture_type": "texture_uv",
    "renderer_normalized_alphas": False,
    "deca_path": "DECA",
    "rome_data_dir": "rome/data",
    "autoenc_cat_alphas": False,
    "autoenc_align_inputs": False,
    "autoenc_use_warp": False,
    "autoenc_num_channels": 64,
    "autoenc_max_channels": 512,
    "autoenc_num_groups": 4,
    "autoenc_num_bottleneck_groups": 0,
    "autoenc_num_blocks": 2,
    "autoenc_num_layers": 4,
    "autoenc_block_type": "bottleneck",
    "neural_texture_channels": 8,
    "num_harmonic_encoding_funcs": 6,
    "unet_num_channels": 64,
    "unet_max_channels": 512,
    "unet_num_groups": 4,
    "unet_num_blocks": 1,
    "unet_num_layers": 2,
    "unet_block_type": "conv",
    "unet_skip_connection_type": "cat",
    "unet_use_normals_cond": True,
    "unet_use_vertex_cond": False,
    "unet_use_uvs_cond": False,
    "unet_pred_mask": False,
    "use_separate_seg_unet": True,
    "norm_layer_type": "gn",
    "activation_type": "relu",
    "conv_layer_type": "ws_conv",
    "deform_norm_layer_type": "gn",
    "deform_activation_type": "relu",
    "deform_conv_layer_type": "ws_conv",
    "unet_seg_weight": 0.0,
    "unet_seg_type": "bce_with_logits",
    "deform_face_tightness": 0.0001,
    "use_whole_segmentation": False,
    "mask_hair_for_neck": False,
    "use_hair_from_avatar": False,
    "use_scalp_deforms": True,
    "use_neck_deforms": True,
    "use_basis_deformer": False,
    "use_unet_deformer": True,
    "pretrained_encoder_basis_path": "",
    "pretrained_vertex_basis_path": "",
    "num_basis": 50,
    "basis_init": "pca",
    "num_vertex": 5023,
    "train_basis": True,
    "path_to_deca": "DECA",
    "path_to_linear_hair_model": "data/linear_hair.pth",  # N/A
    "path_to_mobile_model": "data/disp_model.pth",  # N/A
    "n_scalp": 60,
    "use_distill": False,
    "use_mobile_version": False,
    "deformer_path": "data/rome.pth",
    "output_unet_deformer_feats": 32,
    "use_deca_details": False,
    "use_flametex": False,
    "upsample_type": "nearest",
    "num_frequencies": 6,
    "deform_face_scale_coef": 0.0,
    "device": "cuda"
})

# Download the FLAME model and DECA weights and place them where DECA expects them
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

with open(generic_model_path, 'rb') as f:
    ss = pickle.load(f, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as out:
    pickle.dump(ss, out)

# Copy the DECA checkpoint into DECA's data directory
with open(deca_model_path, "rb") as src, open('./DECA/data/deca_model.tar', "wb") as dst:
    for chunk in src:
        dst.write(chunk)

# Load the ROME inference model
infer = Infer(args)


def image_inference(
    source_img: Image.Image = None,
    driver_img: Image.Image = None
):
    """Animate the source image using the animation parameters of the driver image."""
    out = infer.evaluate(source_img, driver_img, crop_center=False)
    # Place source, driver, masked render and predicted mesh side by side (concatenated along width)
    res = tensor2image(torch.cat([
        out['source_information']['data_dict']['source_img'][0].cpu(),
        out['source_information']['data_dict']['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ], dim=2))
    # Flip the channel order for display
    return res[..., ::-1]


def extract_frames(
    driver_vid: str = None
):
    """Read every frame of the driver video (a file path) as an RGB PIL image."""
    image_frames = []
    vid = cv2.VideoCapture(driver_vid)  # path to the uploaded or recorded video
    while True:
        success, img = vid.read()
        if not success:
            break
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(img)
        image_frames.append(pil_img)
    vid.release()
    return image_frames


def video_inference(
    source_img: Image.Image = None,
    driver_vid: str = None
):
    image_frames = extract_frames(driver_vid)
    if not image_frames:
        raise gr.Error("Could not read any frames from the driver video.")

    resulted_imgs = defaultdict(list)
    mask_hard_threshold = 0.5
    N = len(image_frames)
    for i in range(0, N, 4):  # process every 4th frame to keep inference time down
        new_out = infer.evaluate(source_img, image_frames[i])

        # Binarize the predicted segmentation mask, then erode it to trim the edges
        mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
        mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)

        # Composite the rendered avatar and its normals onto a white background
        render = new_out['pred_target_img'].cpu() * (mask_pred) + (1 - mask_pred)
        normals = process_black_shape(((new_out['pred_target_normal'][0].cpu() + 1) / 2 * mask_pred + (1 - mask_pred)))
        normals[normals == 0.5] = 1.

        resulted_imgs['res_normal'].append(tensor2image(normals))
        resulted_imgs['res_mesh_images'].append(tensor2image(new_out['pred_target_shape_img'][0]))
        resulted_imgs['res_renders'].append(tensor2image(render[0]))

    video = np.array(resulted_imgs['res_renders'])

    fig = plt.figure()
    im = plt.imshow(video[0, :, :, ::-1])
    plt.axis('off')
    plt.close()  # the figure is only used for saving the GIF, so do not display it

    def init():
        im.set_data(video[0, :, :, ::-1])
        return (im,)

    def animate(i):
        im.set_data(video[i, :, :, ::-1])
        return (im,)

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=30)
    anim.save("avatar.gif", dpi=300, writer=animation.PillowWriter(fps=24))
    return "avatar.gif"


description = """
Create a personal avatar from just a single image using ROME.
Paper | Project Page | GitHub
[The] system creates realistic mesh-based avatars from a single source photo. These avatars are rigged, i.e., they can be driven by the animation parameters from a different driving frame.
"""

with gr.Blocks() as demo:
    gr.Markdown("# **ROME: Realistic one-shot mesh-based head avatars**")
    gr.HTML(value="")
    gr.Markdown(description)

    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil", label="Source image", show_label=True)
            driver_img = gr.Image(type="pil", label="Driver image", show_label=True)
            image_output = gr.Image(label="Rendered avatar")
        image_button = gr.Button("Predict")

    with gr.Tab("Video Inference"):
        with gr.Row():
            source_img2 = gr.Image(type="pil", label="Source image", show_label=True)
            driver_vid = gr.Video(label="Driver video", source="upload")
            video_output = gr.Image(label="Rendered GIF avatar")
        video_button = gr.Button("Predict")

    with gr.Tab("Webcam Inference"):
        with gr.Row():
            source_img3 = gr.Image(type="pil", label="Source image", show_label=True)
            driver_cam = gr.Video(label="Driver video", source="webcam")
            cam_output = gr.Image(label="Rendered GIF avatar")
        cam_button = gr.Button("Predict")

    gr.Examples(
        examples=[
            ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
            ["./examples/lincoln.jpg", "./examples/taras1.jpg"],
        ],
        inputs=[source_img, driver_img],
        outputs=[image_output],
        fn=image_inference,
        cache_examples=True,
    )

    image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
    video_button.click(video_inference, inputs=[source_img2, driver_vid], outputs=video_output)
    cam_button.click(video_inference, inputs=[source_img3, driver_cam], outputs=cam_output)

demo.launch()