"""Gradio demo for ROME: realistic one-shot mesh-based head avatars.

Downloads the ROME, MODNet, FLAME and DECA files from the ``Pie31415/rome``
Hugging Face repo, builds the ROME ``Infer`` model on CPU and serves a
source/driver image reenactment UI.
"""
import pickle
import shutil
import sys

import gradio as gr
import torch
from easydict import EasyDict as edict
from huggingface_hub import hf_hub_download

# The ROME and DECA checkouts are not installed packages; put them on the
# import path before importing from ``rome`` below.
sys.path.append("./rome/")
sys.path.append('./DECA')

from rome.infer import Infer  # noqa: E402
from rome.src.utils.processing import process_black_shape, tensor2image  # noqa: E402

# loading models ---- create model repo
default_modnet_path = hf_hub_download('Pie31415/rome', 'modnet_photographic_portrait_matting.ckpt')
default_model_path = hf_hub_download('Pie31415/rome', 'rome.pth')

# parser configurations (mirrors the ROME command-line options; passed to Infer)
args = edict({
    "save_dir": ".",
    "save_render": True,
    "model_checkpoint": default_model_path,
    "modnet_path": default_modnet_path,
    "random_seed": 0,
    "debug": False,
    "verbose": False,
    "model_image_size": 256,
    "align_source": True,
    "align_target": False,
    "align_scale": 1.25,
    "use_mesh_deformations": False,
    "subdivide_mesh": False,
    "renderer_sigma": 1e-08,
    "renderer_zfar": 100.0,
    "renderer_type": "soft_mesh",
    "renderer_texture_type": "texture_uv",
    "renderer_normalized_alphas": False,
    "deca_path": "DECA",
    "rome_data_dir": "rome/data",
    "autoenc_cat_alphas": False,
    "autoenc_align_inputs": False,
    "autoenc_use_warp": False,
    "autoenc_num_channels": 64,
    "autoenc_max_channels": 512,
    "autoenc_num_groups": 4,
    "autoenc_num_bottleneck_groups": 0,
    "autoenc_num_blocks": 2,
    "autoenc_num_layers": 4,
    "autoenc_block_type": "bottleneck",
    "neural_texture_channels": 8,
    "num_harmonic_encoding_funcs": 6,
    "unet_num_channels": 64,
    "unet_max_channels": 512,
    "unet_num_groups": 4,
    "unet_num_blocks": 1,
    "unet_num_layers": 2,
    "unet_block_type": "conv",
    "unet_skip_connection_type": "cat",
    "unet_use_normals_cond": True,
    "unet_use_vertex_cond": False,
    "unet_use_uvs_cond": False,
    "unet_pred_mask": False,
    "use_separate_seg_unet": True,
    "norm_layer_type": "gn",
    "activation_type": "relu",
    "conv_layer_type": "ws_conv",
    "deform_norm_layer_type": "gn",
    "deform_activation_type": "relu",
    "deform_conv_layer_type": "ws_conv",
    "unet_seg_weight": 0.0,
    "unet_seg_type": "bce_with_logits",
    "deform_face_tightness": 0.0001,
    "use_whole_segmentation": False,
    "mask_hair_for_neck": False,
    "use_hair_from_avatar": False,
    "use_scalp_deforms": True,
    "use_neck_deforms": True,
    "use_basis_deformer": False,
    "use_unet_deformer": True,
    "pretrained_encoder_basis_path": "",
    "pretrained_vertex_basis_path": "",
    "num_basis": 50,
    "basis_init": "pca",
    "num_vertex": 5023,
    "train_basis": True,
    "path_to_deca": "DECA",
    "path_to_linear_hair_model": "data/linear_hair.pth",  # N/A
    "path_to_mobile_model": "data/disp_model.pth",  # N/A
    "n_scalp": 60,
    "use_distill": False,
    "use_mobile_version": False,
    "deformer_path": "data/rome.pth",
    "output_unet_deformer_feats": 32,
    "use_deca_details": False,
    "use_flametex": False,
    "upsample_type": "nearest",
    "num_frequencies": 6,
    "deform_face_scale_coef": 0.0,
    "device": "cpu",
})

# download FLAME and DECA pretrained
generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

# Load the FLAME model pickle and re-dump it under DECA/data.
# encoding='latin1' is how Python 3 reads pickles written by Python 2;
# re-dumping writes it back out with this interpreter's pickle protocol.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# this trusts the contents of the Pie31415/rome repo.
with open(generic_model_path, 'rb') as src:
    flame_model = pickle.load(src, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as dst:
    pickle.dump(flame_model, dst)

# Byte-for-byte copy of the DECA checkpoint into DECA/data (presumably the
# location the DECA code loads it from — confirm against DECA's config).
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')

# load ROME inference model
infer = Infer(args)


def image_inference(source_img=None, driver_img=None):
    """Reenact ``source_img`` using ``driver_img`` and return a comparison strip.

    Args:
        source_img: Source portrait. When called from the UI below this is a
            PIL image (the input component is ``gr.Image(type="pil")``).
        driver_img: Driving portrait, same format as ``source_img``.

    Returns:
        A single image built by ``tensor2image`` from four panels joined with
        ``torch.cat(..., dim=2)``: the source image, the target image, the
        masked render and the predicted target shape. The last axis is then
        reversed (presumably BGR -> RGB; confirm against ``tensor2image``).

    Raises:
        ValueError: if either image is missing.
    """
    if source_img is None or driver_img is None:
        raise ValueError("Both a source image and a driver image are required.")

    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = out['source_information']['data_dict']
    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ]
    res = tensor2image(torch.cat(panels, dim=2))
    return res[..., ::-1]


def video_inference():
    """Placeholder for video reenactment; not implemented yet."""
    pass


with gr.Blocks() as demo:
    gr.Markdown("# **ROME: Realistic one-shot mesh-based head avatars**")

    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil")
            driver_img = gr.Image(type="pil")
            image_output = gr.Image()
        image_button = gr.Button("Predict")

    with gr.Tab("Video Inference"):
        # TODO: wire these to video_inference once it is implemented.
        video_inputs = [gr.Video(), gr.Image()]

    # Each example row supplies one value per input component:
    # [source image, driver image].
    gr.Examples(
        examples=[
            ["./examples/lincoln.jpg", "./examples/lincoln.jpg"],
        ],
        inputs=[source_img, driver_img],
        fn=image_inference,
        cache_examples=True,
    )

    image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)

demo.launch()