# ChatAnything / utils.py
# ermu2001
# init
# 08720f3
import os
from PIL import Image
import random
import shutil
import datetime
import torchvision.transforms.functional as f
import torch
from typing import Optional, Tuple
from threading import Lock
from langchain import ConversationChain
from chat_anything.tts_talker.tts_edge import TTSTalker
from chat_anything.sad_talker.sad_talker import SadTalker
from chat_anything.chatbot.chat import load_chain
from chat_anything.chatbot.select import model_selection_chain
from chat_anything.chatbot.voice_select import voice_selection_chain
import gradio as gr
# Checkpoint directory and config directory handed to the SadTalker constructor.
config_path = "chat_anything/sad_talker/config"
sadtalker_checkpoint_path = "MODELS/SadTalker"

# Width value forwarded to create_html_video(); kept as a string.
TALKING_HEAD_WIDTH = "350"
class ChatWrapper:
    """Callable chat handler that also produces speech / talking-head media.

    A single lock serialises chat turns. The SadTalker instance is created
    once per wrapper (with ``lazy_load=True``) and reused for every video.
    """

    def __init__(self):
        self.lock = Lock()
        self.sad_talker = SadTalker(
            sadtalker_checkpoint_path, config_path, lazy_load=True)

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[Tuple[str, str]],
        chain: Optional["ConversationChain"],
        speak_text: bool, talking_head: bool,
        uid: str,
        talker: None,
        fullbody: str,
    ):
        """Execute one chat turn.

        Args:
            api_key: Not used by this method; kept for interface compatibility.
            inp: The user's message.
            history: Previous ``(user, bot)`` pairs, or ``None``/empty.
            chain: Conversation chain; ``None`` means nothing was registered yet.
            speak_text: Whether the answer should be synthesised to speech.
            talking_head: Whether a talking-head video should be shown.
            uid: Session id used to build the ``tmp/<uid>/...`` paths.
            talker: TTS object passed through to the speak helpers.
            fullbody: Truthy to run the video step with ``'full'`` preprocessing.

        Returns:
            The 7-tuple ``(history, history, html_video, temp_file,
            html_audio, temp_aud_file, "")``. Media entries that were not
            produced are ``None``.
        """
        # `with` guarantees the lock is released on every path, including the
        # "no chain" path, which previously left it held forever.
        with self.lock:
            history = history or []
            # Bind the media results up front so every return path has them.
            html_video = temp_file = html_audio = temp_aud_file = None

            if chain is None:
                history.append((inp, "Please register with your API key first!"))
                return history, history, html_video, temp_file, html_audio, temp_aud_file, ""

            print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
            print("inp: " + inp)
            print("speak_text: ", speak_text)
            print("talking_head: ", talking_head)

            output = chain.predict(input=inp).strip()
            # Double the newlines before the text is stored and spoken.
            output = output.replace("\n", "\n\n")
            history.append((inp, output))

            if speak_text:
                if talking_head:
                    html_video, temp_file = self.do_html_video_speak(
                        talker, output, fullbody, uid)
                else:
                    html_audio, temp_aud_file = self.do_html_audio_speak(
                        talker, output, uid)
            elif talking_head:
                # No speech requested: point at the session's video directory.
                temp_file = os.path.join('tmp', uid, 'videos')
                html_video = create_html_video(
                    temp_file, TALKING_HEAD_WIDTH)

        return history, history, html_video, temp_file, html_audio, temp_aud_file, ""

    def do_html_audio_speak(self, talker, words_to_speak, uid):
        """Synthesise ``words_to_speak`` and build an autoplay ``<audio>`` tag.

        Returns:
            ``(html_audio, audio_file_path)`` on success, or ``(None, None)``
            when wrapping the file raises ``IOError``.
        """
        audio_path = os.path.join('tmp', uid, 'audios')
        print('uid:', uid, ":", words_to_speak)
        audio_file_path = talker.test(text=words_to_speak, audio_path=audio_path)
        try:
            temp_aud_file = gr.File(audio_file_path)
            print("audio-----------------------------------------------------success")
            temp_aud_file_url = "/file=" + temp_aud_file.value['name']
            html_audio = f'<audio autoplay><source src={temp_aud_file_url} type="audio/mp3"></audio>'
        except IOError as error:
            # Could not access the file, exit gracefully.
            print(error)
            return None, None
        return html_audio, audio_file_path

    def do_html_video_speak(self, talker, words_to_speak, fullbody, uid):
        """Render a talking-head video that speaks ``words_to_speak``.

        Synthesises the audio, runs SadTalker on ``tmp/<uid>/images/test.jpg``
        and moves the result to ``tmp/<uid>/videos/tempfile.mp4``.

        Returns:
            ``(video_file_path, video_file_path)``.
        """
        preprocess = 'full' if fullbody else 'crop'
        print("success")
        video_path = os.path.join('tmp', uid, 'videos')
        # exist_ok avoids the check-then-create race between sessions.
        os.makedirs(video_path, exist_ok=True)
        video_file_path = os.path.join(video_path, 'tempfile.mp4')
        _, audio_path = self.do_html_audio_speak(
            talker, words_to_speak, uid)
        face_file_path = os.path.join('tmp', uid, 'images', 'test.jpg')
        video = self.sad_talker.test(face_file_path, audio_path, preprocess, uid=uid)
        print("---------------------------------------------------------success")
        print(f"moving {video} -> {video_file_path}")
        shutil.move(video, video_file_path)
        return video_file_path, video_file_path

    def generate_init_face_video(self, class_concept="clock", llm=None, uid=None, fullbody=None, ref_image=None, seed=None):
        """Build the chain, voice, face image and greeting video for a concept.

        Args:
            class_concept: Concept used to load the chain and pick a model.
            llm: Language model handed to the chain/selection helpers.
            uid: Session id used to build the ``tmp/<uid>/...`` paths.
            fullbody: Forwarded to :meth:`do_html_video_speak`.
            ref_image: Reference face image; a random file from ``FACE_DIR``
                is opened when ``None``.
            seed: Forwarded to ``generate_face_image``.

        Returns:
            ``(chain, memory, htm_video, talker)``.

        Raises:
            RuntimeError: If every generation attempt failed; the last
                underlying exception is chained as the cause.
        """
        print('generate concept of', class_concept)
        print("=================================================")
        print('fullbody:', fullbody)
        print('uid:', uid)
        print("==================================================")
        chain, memory, personality_text = load_chain(llm, class_concept)

        # Words prepended to the concept on each retry (safer face generation).
        augment_word_list = ["Female ", "female ", "beautiful ", "small ", "cute "]
        first_sentence = "Hello, how are you doing today?"

        # Each selection is run exactly once; previously both were invoked
        # twice and the first results were thrown away.
        # NOTE(review): the voice is picked with model_selection_chain although
        # voice_selection_chain is imported by this file - confirm intent.
        voice_conf, selected_voice = model_selection_chain(llm, personality_text, conf_file='resources/voices_edge.yaml')
        talker = TTSTalker(selected_voice=selected_voice, gender=voice_conf['gender'], language=voice_conf['language'])
        model_conf, _ = model_selection_chain(llm, class_concept, conf_file='resources/models.yaml')

        if ref_image is None:
            face_files = os.listdir(FACE_DIR)
            face_img_path = os.path.join(FACE_DIR, random.choice(face_files))
            ref_image = Image.open(face_img_path)

        print('loading face generating model')
        anything_facemaker = load_face_generator(
            model_dir=model_conf['model_dir'],
            lora_path=model_conf['lora_path'],
            prompt_template=model_conf['prompt_template'],
            negative_prompt=model_conf['negative_prompt'],
        )

        has_face = anything_facemaker.has_face(ref_image)
        init_strength = 1.0 if has_face else 0.85
        strength_retry_step = -0.04 if has_face else 0.04
        max_attempts = 8
        htm_video = None
        last_error = None
        for attempt in range(max_attempts):
            try:
                generate_face_image(
                    anything_facemaker,
                    class_concept,
                    ref_image,
                    uid=uid,
                    # NOTE(review): without a detected face this grows past 1.0
                    # from the fifth attempt on - confirm the generator accepts it.
                    strength=init_strength + attempt * strength_retry_step,
                    # The original `0.5 if retry_cnt == 8` branch could never
                    # fire inside the loop, so the effective value is 0.3.
                    controlnet_conditioning_scale=0.3,
                    seed=seed,
                )
                self.do_html_video_speak(talker, first_sentence, fullbody, uid=uid)
                video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4')
                htm_video = create_html_video(
                    video_file_path, TALKING_HEAD_WIDTH)
                break
            except Exception as e:
                # Best-effort retry: tweak the concept and try again.
                last_error = e
                class_concept = random.choice(augment_word_list) + class_concept
                print(e)
        else:
            # Every attempt failed; previously this surfaced as an
            # UnboundLocalError on `htm_video` at the return statement.
            raise RuntimeError(
                f"face/video generation failed after {max_attempts} attempts"
            ) from last_error
        return chain, memory, htm_video, talker

    def update_talking_head(self, widget, uid, state):
        """Return the new widget state and the session's video location.

        A falsy ``widget`` yields ``(None, "<pre></pre>")``.
        """
        print("success----------------")
        if not widget:
            return None, "<pre></pre>"
        temp_file = os.path.join('tmp', uid, 'videos')
        video_html_talking_head = create_html_video(
            temp_file, TALKING_HEAD_WIDTH)
        return widget, video_html_talking_head
def reset_memory(history, memory):
    """Clear ``memory`` in place and hand back a fresh, empty history.

    The incoming ``history`` is ignored; the same new list object is returned
    twice, followed by the cleared ``memory``.
    """
    memory.clear()
    fresh_history = []
    return fresh_history, fresh_history, memory
def create_html_video(file_name, width):
    """Return ``file_name`` unchanged; ``width`` is accepted but not used."""
    video_ref = file_name
    return video_ref
def create_html_audio(file_name):
    """Build an ``<audio>`` tag for ``file_name``.

    Returns an empty string when the path does not exist; otherwise the file
    is wrapped in a hidden ``gr.File`` and its ``/file=`` URL is embedded.
    """
    if not os.path.exists(file_name):
        return ''
    hidden_file = gr.File(file_name, visible=False)
    source_url = "/file=" + hidden_file.value['name']
    return f'<audio><source src={source_url} type="audio/mp3"></audio>'
def update_foo(widget, state):
    """Return ``widget`` when it is truthy, otherwise the existing ``state``."""
    return widget if widget else state
# Pertains to question answering functionality
def update_use_embeddings(widget, state):
    """Take the widget value as the new state unless it is falsy."""
    if not widget:
        return state
    return widget
# Code for face-image generation.
def load_face_generator(model_dir, lora_path, prompt_template, negative_prompt):
    """Create a ``LongPromptControlGenerator`` and load its model.

    The ControlNet and face-landmark files are resolved under the local
    ``MODELS`` directory. ``model_dir`` has to be a model derived from
    stable diffusion v1-5. The model is loaded with ``safety_checker=None``.

    Returns:
        The loaded generator instance.
    """
    # Imported lazily, at call time, exactly as before.
    from chat_anything.face_generator.long_prompt_control_generator import LongPromptControlGenerator

    model_zoo = "MODELS"
    generator_kwargs = {
        "model_dir": model_dir,
        "lora_path": lora_path,
        "prompt_template": prompt_template,
        "negative_prompt": negative_prompt,
        "face_control_dir": os.path.join(
            model_zoo, "Face-Landmark-ControlNet", "models_for_diffusers"),
        "face_detect_path": os.path.join(
            model_zoo, "SadTalker", "shape_predictor_68_face_landmarks.dat"),
    }
    facemaker = LongPromptControlGenerator(**generator_kwargs)
    facemaker.load_model(safety_checker=None)
    return facemaker
# Directory listed for a random reference face when no ref_image is supplied.
FACE_DIR = "resources/images/faces"
def generate_face_image(
    anything_facemaker,
    class_concept,
    face_img_pil,
    uid=None,
    controlnet_conditioning_scale=1.0,
    strength=0.95,
    seed=42,
):
    """Generate a face image for ``class_concept`` and save it for the session.

    The reference image is resized and center-cropped to 512 and converted to
    RGB, then passed to ``anything_facemaker.face_control_generate``. The
    result is written to ``tmp/<uid>/images/test.jpg``.

    Args:
        anything_facemaker: Generator exposing ``prompt_template``,
            ``face_control_pipe`` and ``face_control_generate``.
        class_concept: Value formatted into the generator's prompt template.
        face_img_pil: Reference face image.
        uid: Session id used for the output directory; required.
        controlnet_conditioning_scale: Forwarded to the generator.
        strength: Inversion strength; ``None`` disables inversion.
        seed: Seed for the torch generator; ``None`` draws a random seed.

    Returns:
        The generated image.

    Raises:
        TypeError: If ``uid`` is ``None``.
    """
    if uid is None:
        # Fail before the expensive generation; the path join further down
        # would otherwise raise the same exception type only afterwards.
        raise TypeError("uid must be a string, not None")

    face_img_pil = f.center_crop(
        f.resize(face_img_pil, 512), 512).convert('RGB')
    prompt = anything_facemaker.prompt_template.format(class_concept)

    generator = torch.Generator(device=anything_facemaker.face_control_pipe.device)
    if seed is None:
        # manual_seed(None) raises; let torch pick a seed and report it.
        seed = generator.seed()
    else:
        generator.manual_seed(seed)
    print('USING SEED:', seed)

    if strength is None:
        # Pure control generation, no inversion of the reference image.
        init_face_pil = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=False,
            controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator)
    else:
        # Control generation combined with inversion at the given strength.
        init_face_pil = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=True,
            strength=strength,
            controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator)
    print('succeeded generating face image')

    face_path = os.path.join('tmp', uid, 'images')
    # exist_ok avoids the check-then-create race between sessions.
    os.makedirs(face_path, exist_ok=True)
    # TODO: reproduce the images for return, shouldn't use the filesystem
    face_file_path = os.path.join(face_path, 'test.jpg')
    init_face_pil.save(face_file_path)
    return init_face_pil