File size: 5,134 Bytes
a22eb82
 
 
 
 
 
 
 
 
8a44c18
a22eb82
03836e1
 
 
 
 
 
a22eb82
 
 
 
 
 
 
 
 
 
61a3d7c
8a44c18
 
a22eb82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4c1fff
a22eb82
8a44c18
a22eb82
 
 
 
 
 
 
 
 
 
 
 
03836e1
 
 
 
 
 
 
a22eb82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03836e1
61a3d7c
 
a22eb82
 
03836e1
3d20599
 
8a44c18
c17e787
3d20599
61a3d7c
 
 
 
 
03836e1
 
a22eb82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from time import gmtime, strftime
import os, sys, shutil
from argparse import ArgumentParser
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff  
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
import uuid

from pydub import AudioSegment

def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
    """Decode ``mp3_filename`` with pydub, resample it and write it as WAV.

    Args:
        mp3_filename: path of the audio file to decode.
        wav_filename: destination path for the WAV output.
        frame_rate: target sample rate in Hz for the written file.
    """
    audio = AudioSegment.from_file(file=mp3_filename)
    exported = audio.set_frame_rate(frame_rate).export(wav_filename, format="wav")
    # export() opens the destination path itself and returns that file object
    # without closing it; close it here so the descriptor is not leaked.
    exported.close()

from modules.text2speech import text2speech

class SadTalker():
    """Talking-head generator wrapping face preprocessing, audio-to-coefficient and face-render models."""

    def __init__(self, checkpoint_path='checkpoints'):
        """Load all models onto CUDA when available, otherwise onto the CPU.

        Args:
            checkpoint_path: directory containing the model weights. A relative
                path is resolved against the current working directory; an
                absolute path is used as-is. The default resolves to
                ``./checkpoints``. Model YAML configs are always read from
                ``./config``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        current_root_path = './'
        # Previously this argument was ignored and './checkpoints' was hard-coded;
        # the default still yields exactly those paths. os.path.join discards the
        # './' prefix when checkpoint_path is absolute.
        checkpoint_dir = os.path.join(current_root_path, checkpoint_path)
        config_dir = os.path.join(current_root_path, 'config')

        # TORCH_HOME is where torch.hub caches downloaded weights; keep them
        # alongside the other checkpoints.
        os.environ['TORCH_HOME'] = checkpoint_dir

        path_of_lm_croper = os.path.join(checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
        path_of_net_recon_model = os.path.join(checkpoint_dir, 'epoch_20.pth')
        dir_of_BFM_fitting = os.path.join(checkpoint_dir, 'BFM_Fitting')
        wav2lip_checkpoint = os.path.join(checkpoint_dir, 'wav2lip.pth')

        # 'auido' (sic) is the spelling used in these file names; presumably it
        # matches the released files on disk, so do not "correct" it here alone.
        audio2pose_checkpoint = os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth')
        audio2pose_yaml_path = os.path.join(config_dir, 'auido2pose.yaml')

        audio2exp_checkpoint = os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth')
        audio2exp_yaml_path = os.path.join(config_dir, 'auido2exp.yaml')

        free_view_checkpoint = os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
        mapping_checkpoint = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(config_dir, 'facerender.yaml')

        # Init models.
        print(path_of_lm_croper)
        self.preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

        print(audio2pose_checkpoint)
        self.audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                          audio2exp_checkpoint, audio2exp_yaml_path, wav2lip_checkpoint, device)
        print(free_view_checkpoint)
        self.animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                                   facerender_yaml_path, device)
        self.device = device

    @staticmethod
    def _stage_audio(driven_audio, input_dir):
        """Place ``driven_audio`` into ``input_dir`` and return the path to use downstream.

        ``.mp3`` input (case-insensitive) is converted to a 16 kHz ``.wav`` written
        into ``input_dir``; the original mp3 is left where it is. Any other file is
        moved into ``input_dir`` unchanged.
        """
        stem, ext = os.path.splitext(os.path.basename(driven_audio))
        if ext.lower() == '.mp3':
            # Operate on the extension only: a str.replace over the whole path
            # would also rewrite '.mp3' inside directory names or the stem.
            wav_path = os.path.join(input_dir, stem + '.wav')
            mp3_to_wav(driven_audio, wav_path, 16000)
            return wav_path
        shutil.move(driven_audio, input_dir)
        return os.path.join(input_dir, stem + ext)

    def test(self, source_image, driven_audio, still_mode, use_enhancer, result_dir='./'):
        """Generate a talking-head video from a face image and a driving audio file.

        Inputs are staged into ``<result_dir>/<uuid4>/input``. Note that
        ``source_image`` (and a non-mp3 ``driven_audio``) are *moved* there,
        not copied.

        Args:
            source_image: path to the source face image.
            driven_audio: path to the driving audio file.
            still_mode: forwarded to ``get_facerender_data``.
            use_enhancer: when truthy, render with the ``'gfpgan'`` enhancer and
                return the ``_enhanced.mp4`` output.
            result_dir: parent directory for the per-call output folder.

        Returns:
            A 2-tuple containing the same output video path twice.

        Raises:
            FileNotFoundError: if ``driven_audio`` is not an existing file.
            AttributeError: if no face is detected in ``source_image``.
        """
        import gc

        if not os.path.isfile(driven_audio):
            # Text-to-speech input is not wired up: the former else-branch only
            # evaluated the bare name `text2speech` and later crashed with
            # UnboundLocalError on `audio_path`. Fail fast, before moving the
            # caller's image or running face preprocessing.
            raise FileNotFoundError(f'driven_audio is not an existing file: {driven_audio}')

        # A fresh UUID per call gives every run its own output directory.
        save_dir = os.path.join(result_dir, str(uuid.uuid4()))
        input_dir = os.path.join(save_dir, 'input')
        os.makedirs(input_dir, exist_ok=True)  # also creates save_dir

        print(source_image)
        pic_path = os.path.join(input_dir, os.path.basename(source_image))
        shutil.move(source_image, input_dir)

        audio_path = self._stage_audio(driven_audio, input_dir)

        pose_style = 0
        # Crop the image and extract 3DMM coefficients from it.
        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
        os.makedirs(first_frame_dir, exist_ok=True)
        first_coeff_path, crop_pic_path = self.preprocess_model.generate(pic_path, first_frame_dir)
        if first_coeff_path is None:
            raise AttributeError("No face is detected")

        # Audio -> coefficients.
        batch = get_data(first_coeff_path, audio_path, self.device)
        coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)

        # Coefficients -> video.
        batch_size = 4
        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode)
        self.animate_from_coeff.generate(data, save_dir, enhancer='gfpgan' if use_enhancer else None)
        video_name = data['video_name']
        print(f'The generated video is named {video_name} in {save_dir}')

        # Release cached GPU memory between calls. torch.cuda.synchronize()
        # raises when CUDA is unavailable, so guard it for the CPU device path.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

        suffix = '_enhanced.mp4' if use_enhancer else '.mp4'
        video_path = os.path.join(save_dir, video_name + suffix)
        # Same path returned twice — presumably feeds two UI outputs; TODO confirm with caller.
        return video_path, video_path