ermu2001 commited on
Commit
08720f3
·
1 Parent(s): 195eeff
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +16 -0
  2. Dockerfile +27 -0
  3. README.md +138 -0
  4. app.py +239 -0
  5. chat_anything/azure_utils.py +155 -0
  6. chat_anything/chatbot/__init__.py +0 -0
  7. chat_anything/chatbot/chat.py +72 -0
  8. chat_anything/chatbot/model_select.py +60 -0
  9. chat_anything/chatbot/personality.py +59 -0
  10. chat_anything/chatbot/select.py +63 -0
  11. chat_anything/chatbot/voice_select.py +119 -0
  12. chat_anything/face_generator/__init__.py +0 -0
  13. chat_anything/face_generator/long_prompt_control_generator.py +104 -0
  14. chat_anything/face_generator/long_prompt_generator.py +82 -0
  15. chat_anything/face_generator/pipelines/lpw_stable_diffusion.py +1471 -0
  16. chat_anything/face_generator/utils/generate.py +45 -0
  17. chat_anything/polly_utils.py +635 -0
  18. chat_anything/sad_talker/__init__.py +0 -0
  19. chat_anything/sad_talker/audio2exp_models/audio2exp.py +41 -0
  20. chat_anything/sad_talker/audio2exp_models/networks.py +74 -0
  21. chat_anything/sad_talker/audio2pose_models/audio2pose.py +94 -0
  22. chat_anything/sad_talker/audio2pose_models/audio_encoder.py +64 -0
  23. chat_anything/sad_talker/audio2pose_models/cvae.py +149 -0
  24. chat_anything/sad_talker/audio2pose_models/discriminator.py +76 -0
  25. chat_anything/sad_talker/audio2pose_models/networks.py +140 -0
  26. chat_anything/sad_talker/audio2pose_models/res_unet.py +65 -0
  27. chat_anything/sad_talker/config/auido2exp.yaml +58 -0
  28. chat_anything/sad_talker/config/auido2pose.yaml +49 -0
  29. chat_anything/sad_talker/config/facerender.yaml +45 -0
  30. chat_anything/sad_talker/config/facerender_still.yaml +45 -0
  31. chat_anything/sad_talker/config/similarity_Lm3D_all.mat +0 -0
  32. chat_anything/sad_talker/face3d/data/__init__.py +116 -0
  33. chat_anything/sad_talker/face3d/data/base_dataset.py +125 -0
  34. chat_anything/sad_talker/face3d/data/flist_dataset.py +125 -0
  35. chat_anything/sad_talker/face3d/data/image_folder.py +66 -0
  36. chat_anything/sad_talker/face3d/data/template_dataset.py +75 -0
  37. chat_anything/sad_talker/face3d/extract_kp_videos.py +108 -0
  38. chat_anything/sad_talker/face3d/extract_kp_videos_safe.py +162 -0
  39. chat_anything/sad_talker/face3d/models/__init__.py +67 -0
  40. chat_anything/sad_talker/face3d/models/arcface_torch/README.md +164 -0
  41. chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  42. chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  43. chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  44. chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  45. chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py +23 -0
  46. chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
  47. chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py +0 -0
  48. chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py +56 -0
  49. chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
  50. chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **__pycache__/
2
+
3
+ MODELS
4
+ third_party
5
+ tmp
6
+ results
7
+ chat_anything/tts_vits/
8
+ vits_results
9
+ test
10
+ resources/models.yaml
11
+
12
+ # others
13
+ GFPGANv1.4.pth
14
+ gfpgan
15
+ GFPGAN
16
+ .gitattributes
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
2
+
3
+ # FROM python:3.9
4
+
5
+ # WORKDIR /code
6
+
7
+ # COPY ./requirements.txt /code/requirements.txt
8
+
9
+ # RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
10
+
11
+ # for open cv
12
+ RUN apt-get update && apt-get install libgl1 -y
13
+
14
+ RUN useradd -m -u 1000 user
15
+
16
+ USER user
17
+
18
+ ENV HOME=/home/user \
19
+ PATH=/home/user/.local/bin:$PATH
20
+
21
+ WORKDIR $HOME/ChatAnything
22
+
23
+ COPY --chown=user . $HOME/ChatAnything
24
+
25
+ RUN pip install -r requirements.txt
26
+
27
+ CMD python app.py
README.md CHANGED
@@ -10,3 +10,141 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+ # ChatAnything: Facetime Chat with LLM-Enhanced Personas
14
+
15
+ **Yilin Zhao\*, Shanghua Gao\*, Daquan Zhou\*, Xinbin Yuan\*, Zhijie Lin, Qibin Hou, Jiashi Feng**
16
+
17
+
18
+
19
+ > What will it be like to Facetime any imaginary concepts?
20
+ To animate anything, we integrated current open-source models at hand for an animation application for interactive AI-Agent chatting usage.
21
+ >
22
+ > To start with, take a look at these incredible faces generated with open-source Civitai models that are to be animated.
23
+ <img src="./resources/readme/show.png" alt="drawing" width="784"/>
24
+ <!-- ![faces](./resources/readme/show.png) -->
25
+
26
+ Here we provide you with ChatAnything. A simple pipeline Enhanced with currently limitless Large Language Models, yielding imaginary Facetime chats with intented visual appearance!
27
+
28
+ Remember, the repo and application are totally based on pre-trained deep learning methods and haven't included any training yet. We give all the credit to the open-source community (shout out to you). For detail of the pipeline, see our technical report (TODO: link here)
29
+ ## Release & Features & Future Plans
30
+
31
+ - [ ] Fine-tune face rendering module.
32
+ - [ ] Better TTS module & voice render module.
33
+ - [ ] Adding Open-source Language Models.
34
+ - [x] Initial release
35
+ - Facetime Animation.
36
+ - Multiple model choices for initial frame generation.
37
+ - Multiple choices for voices.
38
+ # Install & Run
39
+ Just follow the instructions. Every thing would be simple (hopefully). Reach out if you met with any problems!
40
+ ### Install
41
+ first, install the virtual environment.
42
+ ```
43
+ conda env create -f environment.yaml
44
+
45
+ # then install
46
+ conda env update --name chatanything --file environment.yaml
47
+ ```
48
+
49
+ The Pipeline integrated Open-Source Models. All Models are to be found online(see [Acknowledgement](#acknowledgement)). We put some important models together on huggingface remotes just to make life easier. Prepare them for the first run with this Python script [prepare_models.py](./python_scripts/prepare_models.py):
50
+ ```
51
+ # prepare the local models
52
+ python python_scripts/prepare_models.py
53
+
54
+ ```
55
+
56
+ ### Building Docker
57
+ Try build a docker if you find it easier. This part is not fully tested. If you find a anything wrong, feel free to contribute~
58
+ ```
59
+ docker build --network=host -t chatanything .
60
+ # docker run -dp 127.0.0.1:8901:8901 chatanything
61
+ docker run -p 127.0.0.1:8901:8901 -it --gpus all chatanything
62
+ docker run -it --gpus all chatanything bash
63
+ ```
64
+
65
+ ### Run
66
+ specify a port for the gradio application to run on and set off!
67
+ ```
68
+ PORT=8809 python app.py $PORT
69
+ ```
70
+
71
+ # Configuring: From User Input Concept to Appearance & Voice
72
+ The first step of the pipeline is to generate a image for SadTalker and at the same time set up the Text to Sound Module for voice chat.
73
+
74
+ The pipeline would query a powerful LLM (ChatGPT) for the selection in a zero-shot multi-choice selection format.
75
+ Three Questions are asked upon the initial of every conversation(init frame generation):
76
+ 1. Provide a imagen personality for the user input concept.
77
+ 2. Select a Generative model for the init frame generation.
78
+ 3. Select a Text To Sound Voice(Model) for the character base on the personality.
79
+
80
+ We have constructed the model selection to be extendable. Add your ideal model with just a few lines of Configuring! The rest of this section would breifly introduce the steps to add a init-frame generator/language voice.
81
+
82
+ ### Image Generator
83
+ Configure the models in the [Model Config](./resources/models.yaml). This Config acts as the memory (or an image-generating tool pool) for the LLM.
84
+
85
+ The prompt sets up this selection process. Each sub field of the "models" would turn into an option in the multiple-choice question.
86
+ the "**desc**" field of each element is what the Language Model would see. The key is not provided to the LM as it would sometimes mislead it.
87
+ the others are used for the image generation as listed:
88
+ 1. model_dir: the repo-path for diffusers package. As the pretrained Face-landmark ControlNet is based on stable-diffusion-v1-5, we currently only supports the derivatives of it.
89
+ 2. lora_path: LoRA derivatives are powerful, try a LoRA model also for better stylization. Should directly point to the parameters binary file.
90
+ 3. prompt_template & negative_prompt: this is used for prompting the text-to-image diffusion model. Find a ideal prompt for your model and stick with it. A "{}" should be in the prompt template for inserting the user input concept.
91
+
92
+ Here are some **Tips** for configuring you own model.
93
+ 1. Provide the LLM with a simple description of the generative model. It is worth noting that the description needs to be concise and accurate for a correct selection.
94
+ 2. Set the model_dir to a local directory of diffusers stable-diffusion-v1-5 derivatives. Also, you can provide a repo-id on the huggingface hub model space. The model would be downloaded when first chosen, wait for it.
95
+ 3. To better utilize the resources from the community, we also add in support of the LoRA features. To add the LoRA module, you would need to give the path to the parameter files.
96
+
97
+ 4. Carefully write the prompt template and negative prompt. These which affect the initial face generation a lot. Be aware that the prompt template should contain only one pair of "{}" to insert the concept that users wrote on the application webpage. We support the Stable-Diffusion-Webui prompt style as implemented by diffusers, feel free to copy the prompt from Civitai for better prompting the generation and put in the "{}" to the original prompt for ChatAnything!
98
+
99
+ Again, this model's config acts as an extended tool pool for the LM, the application would drive the LM to choose from this config and use the chosen model to generate. Sometimes the LM fails to choose the correct model or choosing any available model, this would cause the Chatanything app to fail on a generation.
100
+
101
+ Notice we currently support ONLY stable-diffusion-v1.5 derivatives (Sdxl Pipelines are under consideration, however not yet implemented as we lack a face-landmark ControlNet for it. Reach out if you're interested in training one!)
102
+
103
+ ### Voice TTS
104
+ We are using the edge_tts package for text-to-speech support. The voice selection and [voice configuration file](./resources/voices_edge.yaml) is constructed similarly to the Image generation model selection, except now the LM is supposed to choose the voice base on the personality description given by itself earlier. "**gender**" and "**language**" field corresponds to edge_tts.
105
+
106
+ # On-going tasks.
107
+ ### Customized Voice.
108
+ There is a Voice Changer TextToSpeach-SpeachVoiceConversion Pipeline app, which ensures a better customized voice. We are trying to leverage its TTS functionality.
109
+
110
+ Reach out if you want to add a voice of your own or your hero!
111
+
112
+ Here are the possible steps for
113
+ You would need to change a little bit in the code first:
114
+ 1. Alter this [code](./utils.py#14) to import a TTSTalker from chat_anything/tts_talker/tts_voicechanger.py.
115
+ 2. switch the config to another one, change [code](./utils.py#14) "resources/voices_edge.yaml" -> "resources/voices_voicechanger.yaml"
116
+
117
+ The try running a [Voice Changer](https://huggingface.co/spaces/kevinwang676/Voice-Changer) on your local machine. Simply set up git-lfs and install the repo and run it for the TTS voice service.
118
+ The TTS caller was set to port 7860.
119
+
120
+ make sure the client class is set up with the same port in [here](chat_anything/tts_talker/tts_voicechanger.py#5)
121
+ ```python
122
+ client = Client("http://127.0.0.1:7860/")
123
+ ```
124
+
125
+ # Acknowledgement
126
+ Again, the project hasn't yet included any training. The pipeline is totally based on these incredible awesome packages and pretrained models. Don't hesitate to take a look and explore the amazing open-source generative communities. We love you, guys.
127
+ - [ChatGPT](https://openai.com/chatgpt): GOD
128
+ - [SadTalker](https://github.com/OpenTalker/SadTalker): The Core Animation Module
129
+ - [Face-Landmark-ControlNet](https://huggingface.co/georgefen/Face-Landmark-ControlNet): An Awesome ControlNet with Face landmark using Stable Diffusion 1.5 as base Model.
130
+ - [diffusers](https://github.com/huggingface/diffusers): GOAT of Image Generative Framework🥳.
131
+ - [langchain](https://github.com/langchain-ai/langchain): An Awesome Package for Dealing with LLM.
132
+ - [edge-tts](https://github.com/rany2/edge-tts): An Awesome Package for Text To Sound Solutions.
133
+ - [gradio](https://www.gradio.app/): GOAT😄 Machine Learning based App framework.
134
+ - [Civitai](https://civitai.com/models) and [Huggingface_hub](https://huggingface.co/models): Find your ideal Image Generative Model on Civitai. These Communities are Crazy🥂. Here are Some Fantastic Derivatives of [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5):
135
+ - [Game Icon Institute_mode](https://civitai.com/models/47800?modelVersionId=76533)
136
+ - [dreamshaper](https://civitai.com/models/4384/dreamshaper)
137
+ - [3D_Animation_Diffusion](https://civitai.com/models/118086?modelVersionId=128046)
138
+ - [anything-v5](https://huggingface.co/stablediffusionapi/anything-v5)
139
+
140
+ # Citation
141
+ If you like our pipeline and application, don't hesitate to reach out! Let's work on it and see how far it would go!
142
+ ```bibtex
143
+ @misc{zhao2023ChatAnything,
144
+ title={ChatAnything: Facetime Chat with LLM-Enhanced Personas},
145
+ author={Yilin, Zhao and Shanghua, Gao and Daquan, Zhou and Xinbin, Yuan and Qibin, Hou and Jiashi, Feng},
146
+ publisher={},
147
+ year={2023},
148
+ }
149
+ ```
150
+
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ssl
3
+ import sys
4
+
5
+ import gradio as gr
6
+
7
+ import warnings
8
+ import whisper
9
+ from chat_anything.polly_utils import PollyVoiceData
10
+ from chat_anything.azure_utils import AzureVoiceData
11
+ from chat_anything.chatbot.chat import set_openai_api_key
12
+ from utils import ChatWrapper, update_foo, reset_memory
13
+
14
+ ssl._create_default_https_context = ssl._create_unverified_context
15
+
16
+
17
+ TALKING_HEAD_WIDTH = "350"
18
+
19
+ LOOPING_TALKING_HEAD = "resources/videos/tempfile.mp4"
20
+
21
+ USE_GPT4_DEFAULT = False
22
+ FULLBODY_DEFAULT = False
23
+ POLLY_VOICE_DATA = PollyVoiceData()
24
+ AZURE_VOICE_DATA = AzureVoiceData()
25
+
26
+ # Pertains to WHISPER functionality
27
+ WHISPER_DETECT_LANG = "Detect language"
28
+
29
+ INSTRUCTION_MARKDOWN = """
30
+ # ChatAnything: Facetime Chat with LLM-Enhanced Personas
31
+ ### DEMO INSTRUCTION
32
+ ##### 0. Register
33
+ Input a OpenAI API Key of your own. This would be used to chat with openai-chatgpt. Make sure to disable the key afterwards🥹.
34
+ ##### 1. Generate The init face😀 along with first chat
35
+ Input a Concept in the "Talking object" text box, then click on Generate button. The init face generation and module selection will be performed and used for the rest of this chat. Wait for a while and the video would be produced and played. Write simple concept for generating. The concept will be place on each prompt template for deciding the main concepts.
36
+ ##### 2. Keep on Chatting🤑
37
+ Go on speak with the character. The init face and module selection will not reperform itself, now you are only chatting with the LM, along with the rendering of sadtalker. Hopefully, the API will not impose an excessive charge for this.
38
+
39
+
40
+ ### FEATURES
41
+ ##### 1. Upload a image for control/inversion starting point. Try some none face images and see how it works!
42
+ ##### 2. seeding is provided. However if not providing a input image, there would be a random chosen facial landmark image for generating, which might include some randomness.
43
+ ##### 3. Try out the examples.
44
+ ##### 4. Say something and recorded your voice for a real facetime chat. Whisper will handle your voice, see setting-Whisper STT options.
45
+ ##### 5. Decide whether to use the crop face out option, this will crop out the face from the generated image and render. This is promising for better animation rendering, however sometimes the croped image loses some elementary features of you intended concept.
46
+
47
+ """
48
+
49
+ # UNCOMMENT TO USE WHISPER
50
+ warnings.filterwarnings("ignore")
51
+ WHISPER_MODEL = whisper.load_model("tiny")
52
+ print("WHISPER_MODEL", WHISPER_MODEL)
53
+
54
+
55
+ # UNCOMMENT TO USE WHISPER
56
+ def transcribe(aud_inp, whisper_lang):
57
+ if aud_inp is None:
58
+ return ""
59
+ aud = whisper.load_audio(aud_inp)
60
+ aud = whisper.pad_or_trim(aud)
61
+ mel = whisper.log_mel_spectrogram(aud).to(WHISPER_MODEL.device)
62
+ _, probs = WHISPER_MODEL.detect_language(mel)
63
+ options = whisper.DecodingOptions()
64
+ if whisper_lang != WHISPER_DETECT_LANG:
65
+ whisper_lang_code = POLLY_VOICE_DATA.get_whisper_lang_code(
66
+ whisper_lang)
67
+ options = whisper.DecodingOptions(language=whisper_lang_code)
68
+ result = whisper.decode(WHISPER_MODEL, mel, options)
69
+ print("result.text", result.text)
70
+ result_text = ""
71
+ if result and result.text:
72
+ result_text = result.text
73
+ return result_text
74
+
75
+
76
+ chat = ChatWrapper()
77
+
78
+
79
+ with gr.Blocks() as block:
80
+ llm_state = gr.State()
81
+ history_state = gr.State()
82
+ chain_state = gr.State()
83
+ talker_state = gr.State()
84
+ fullbody_state = gr.State(True)
85
+ speak_text_state = gr.State(True)
86
+ talking_head_state = gr.State(True)
87
+ uid_state = gr.State()
88
+ video_file_path = gr.State()
89
+ audio_file_path = gr.State()
90
+
91
+ memory_state = gr.State()
92
+
93
+
94
+ # Pertains to WHISPER functionality
95
+ whisper_lang_state = gr.State(WHISPER_DETECT_LANG)
96
+ use_gpt4_state = gr.State(USE_GPT4_DEFAULT)
97
+
98
+ with gr.Column():
99
+ with gr.Row():
100
+ gr.Markdown(INSTRUCTION_MARKDOWN)
101
+ with gr.Row():
102
+ openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
103
+ show_label=True, lines=1, type='password', value='', label='OpenAI API key')
104
+ openai_api_key_register = gr.Button(
105
+ value="Register").style(full_width=False)
106
+ uid_textbox = gr.Textbox(show_label=True, value=uid_state, lines=1, label='UID')
107
+ seed = gr.Slider(
108
+ label="Seed",
109
+ minimum=-1,
110
+ maximum=2147483647,
111
+ step=1,
112
+ randomize=True,
113
+ )
114
+
115
+ with gr.Tab("Chat"):
116
+ with gr.Row():
117
+ with gr.Column(scale=1, min_width=TALKING_HEAD_WIDTH, visible=True):
118
+ with gr.Column():
119
+ class_prompt = gr.Textbox(
120
+ 'apple',
121
+ default='apple',
122
+ type="text", label='Talking object'
123
+ )
124
+ init_face_btn = gr.Button(
125
+ value="Generate").style(full_width=False)
126
+
127
+ my_file = gr.File(label="Upload a file",
128
+ type="file", visible=False)
129
+
130
+ # video_html = gr.HTML('')
131
+ video_html = gr.Video(label="Generated Video", autoplay=True)
132
+
133
+ ref_image = gr.Image(
134
+ type="pil",
135
+ interactive=True,
136
+ label="Image: Upload your image.",
137
+ )
138
+ tmp_aud_file = gr.File(
139
+ type="file", visible=False)
140
+ audio_html = gr.HTML('')
141
+ init_face_btn.click(chat.generate_init_face_video, inputs=[class_prompt, llm_state, uid_state,fullbody_state, ref_image, seed],
142
+ outputs=[chain_state, memory_state, video_html,talker_state])
143
+
144
+
145
+ with gr.Column(scale=7):
146
+ chatbot = gr.Chatbot()
147
+
148
+
149
+ message = gr.Textbox(label="What's on your mind??",
150
+ placeholder="What's the answer to life, the universe, and everything?",
151
+ lines=1)
152
+ submit = gr.Button(value="Send", variant="secondary").style(
153
+ full_width=False)
154
+
155
+ audio_comp = gr.Microphone(source="microphone", type="filepath", label="Just say it!",
156
+ interactive=True, streaming=False)
157
+ audio_comp.change(transcribe, inputs=[
158
+ audio_comp, whisper_lang_state], outputs=[message])
159
+
160
+
161
+ with gr.Accordion("General examples", open=False):
162
+ gr.Examples(
163
+ examples=[
164
+ ["cyberpunk godess", "Who are you?", "resources/images/annie.jpg", 393212389],
165
+ ["unbelievable beauty fairy", "Who are you?", "resources/images/lenna.jpg", 222679277],
166
+ ["tree monster", "Who are you?", None],
167
+ ["pineapple monster", "Who are you?", None],
168
+ ["tricky Polaris", "Who are you?", None, 1670155100],
169
+ ["watermelon", "Who are you?", "resources/images/watermelon.jpg", 42],
170
+ ],
171
+ inputs=[class_prompt, message, ref_image, seed],
172
+ )
173
+
174
+ with gr.Tab("Settings"):
175
+ with gr.Tab("General"):
176
+
177
+ talking_head_cb = gr.Checkbox(
178
+ label="Show talking head", value=True)
179
+ talking_head_cb.change(chat.update_talking_head, inputs=[talking_head_cb, uid_state, talking_head_state],
180
+ outputs=[talking_head_state, video_html])
181
+
182
+ use_gpt4_cb = gr.Checkbox(label="Use GPT-4 (experimental) if your OpenAI API has access to it",
183
+ value=USE_GPT4_DEFAULT)
184
+
185
+ fullbody_state = gr.Checkbox(label="Use full body instead of a face.",
186
+ value=True)
187
+
188
+ use_gpt4_cb.change(set_openai_api_key,
189
+ inputs=[openai_api_key_textbox,
190
+ use_gpt4_cb],
191
+ outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path])
192
+
193
+ reset_btn = gr.Button(value="Reset chat",
194
+ variant="secondary").style(full_width=False)
195
+ reset_btn.click(reset_memory, inputs=[history_state, memory_state],
196
+ outputs=[chatbot, history_state, memory_state])
197
+
198
+
199
+ with gr.Tab("Whisper STT"):
200
+ whisper_lang_radio = gr.Radio(label="Whisper speech-to-text language:", choices=[
201
+ WHISPER_DETECT_LANG, "Arabic", "Arabic (Gulf)", "Catalan", "Chinese (Cantonese)", "Chinese (Mandarin)",
202
+ "Danish", "Dutch", "English (Australian)", "English (British)", "English (Indian)", "English (New Zealand)",
203
+ "English (South African)", "English (US)", "English (Welsh)", "Finnish", "French", "French (Canadian)",
204
+ "German", "German (Austrian)", "Georgian", "Hindi", "Icelandic", "Indonesian", "Italian", "Japanese",
205
+ "Korean", "Norwegian", "Polish",
206
+ "Portuguese (Brazilian)", "Portuguese (European)", "Romanian", "Russian", "Spanish (European)",
207
+ "Spanish (Mexican)", "Spanish (US)", "Swedish", "Turkish", "Ukrainian", "Welsh"],
208
+ value=WHISPER_DETECT_LANG)
209
+
210
+ whisper_lang_radio.change(update_foo,
211
+ inputs=[whisper_lang_radio,
212
+ whisper_lang_state],
213
+ outputs=[whisper_lang_state])
214
+
215
+ gr.HTML("""
216
+ <p>This application is based on <a href='https://huggingface.co/spaces/JavaFXpert/Chat-GPT-LangChain/'>Chat-GPT-LangChain</a>, <a href='https://github.com/hwchase17/langchain'>LangChain</a>
217
+ </p>""")
218
+
219
+ message.submit(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
220
+ speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state],
221
+ outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
222
+
223
+ submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
224
+ speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state],
225
+ outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
226
+
227
+ openai_api_key_register.click(set_openai_api_key,
228
+ inputs=[openai_api_key_textbox,
229
+ use_gpt4_state, chatbot],
230
+ outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path])
231
+
232
+ if __name__ == "__main__":
233
+ import sys
234
+ if len(sys.argv) == 1:
235
+ port = 8901
236
+ else:
237
+ port = int(sys.argv[1])
238
+ block.launch(debug=True, server_name="0.0.0.0",
239
+ server_port=port, share=True, enable_queue = True)
chat_anything/azure_utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This class stores Azure voice data. Specifically, the class stores several records containing
2
+ # language, lang_code, gender, voice_id and engine. The class also has a method to return the
3
+ # voice_id, lang_code and engine given a language and gender.
4
+
5
+ NEURAL_ENGINE = "neural"
6
+ STANDARD_ENGINE = "standard"
7
+
8
+
9
+ class AzureVoiceData:
10
+ def get_voice(self, language, gender):
11
+ for voice in self.voice_data:
12
+ if voice['language'] == language and voice['gender'] == gender:
13
+ return voice['azure_voice']
14
+ return None
15
+
16
+ def __init__(self):
17
+ self.voice_data = [
18
+ {'language': 'Arabic',
19
+ 'azure_voice': 'ar-EG-ShakirNeural',
20
+ 'gender': 'Male'},
21
+ {'language': 'Arabic (Gulf)',
22
+ 'azure_voice': 'ar-KW-FahedNeural',
23
+ 'gender': 'Male'},
24
+ {'language': 'Catalan',
25
+ 'azure_voice': 'ca-ES-EnricNeural',
26
+ 'gender': 'Male'},
27
+ {'language': 'Chinese (Cantonese)',
28
+ 'azure_voice': 'yue-CN-YunSongNeural',
29
+ 'gender': 'Male'},
30
+ {'language': 'Chinese (Mandarin)',
31
+ 'azure_voice': 'zh-CN-YunxiNeural',
32
+ 'gender': 'Male'},
33
+ {'language': 'Danish',
34
+ 'azure_voice': 'da-DK-JeppeNeural',
35
+ 'gender': 'Male'},
36
+ {'language': 'Dutch',
37
+ 'azure_voice': 'nl-NL-MaartenNeural',
38
+ 'gender': 'Male'},
39
+ {'language': 'English (Australian)',
40
+ 'azure_voice': 'en-AU-KenNeural',
41
+ 'gender': 'Male'},
42
+ {'language': 'English (British)',
43
+ 'azure_voice': 'en-GB-RyanNeural',
44
+ 'gender': 'Male'},
45
+ {'language': 'English (Indian)',
46
+ 'azure_voice': 'en-IN-PrabhatNeural',
47
+ 'gender': 'Male'},
48
+ {'language': 'English (New Zealand)',
49
+ 'azure_voice': 'en-NZ-MitchellNeural',
50
+ 'gender': 'Male'},
51
+ {'language': 'English (South African)',
52
+ 'azure_voice': 'en-ZA-LukeNeural',
53
+ 'gender': 'Male'},
54
+ {'language': 'English (US)',
55
+ 'azure_voice': 'en-US-ChristopherNeural',
56
+ 'gender': 'Male'},
57
+ {'language': 'English (Welsh)',
58
+ 'azure_voice': 'cy-GB-AledNeural',
59
+ 'gender': 'Male'},
60
+ {'language': 'Finnish',
61
+ 'azure_voice': 'fi-FI-HarriNeural',
62
+ 'gender': 'Male'},
63
+ {'language': 'French',
64
+ 'azure_voice': 'fr-FR-HenriNeural',
65
+ 'gender': 'Male'},
66
+ {'language': 'French (Canadian)',
67
+ 'azure_voice': 'fr-CA-AntoineNeural',
68
+ 'gender': 'Male'},
69
+ {'language': 'German',
70
+ 'azure_voice': 'de-DE-KlausNeural',
71
+ 'gender': 'Male'},
72
+ {'language': 'German (Austrian)',
73
+ 'azure_voice': 'de-AT-JonasNeural',
74
+ 'gender': 'Male'},
75
+ {'language': 'Hindi',
76
+ 'azure_voice': 'hi-IN-MadhurNeural',
77
+ 'gender': 'Male'},
78
+ {'language': 'Icelandic',
79
+ 'azure_voice': 'is-IS-GunnarNeural',
80
+ 'gender': 'Male'},
81
+ {'language': 'Italian',
82
+ 'azure_voice': 'it-IT-GianniNeural',
83
+ 'gender': 'Male'},
84
+ {'language': 'Japanese',
85
+ 'azure_voice': 'ja-JP-KeitaNeural',
86
+ 'gender': 'Male'},
87
+ {'language': 'Korean',
88
+ 'azure_voice': 'ko-KR-GookMinNeural',
89
+ 'gender': 'Male'},
90
+ {'language': 'Norwegian',
91
+ 'azure_voice': 'nb-NO-FinnNeural',
92
+ 'gender': 'Male'},
93
+ {'language': 'Polish',
94
+ 'azure_voice': 'pl-PL-MarekNeural',
95
+ 'gender': 'Male'},
96
+ {'language': 'Portuguese (Brazilian)',
97
+ 'azure_voice': 'pt-BR-NicolauNeural',
98
+ 'gender': 'Male'},
99
+ {'language': 'Portuguese (European)',
100
+ 'azure_voice': 'pt-PT-DuarteNeural',
101
+ 'gender': 'Male'},
102
+ {'language': 'Romanian',
103
+ 'azure_voice': 'ro-RO-EmilNeural',
104
+ 'gender': 'Male'},
105
+ {'language': 'Russian',
106
+ 'azure_voice': 'ru-RU-DmitryNeural',
107
+ 'gender': 'Male'},
108
+ {'language': 'Spanish (European)',
109
+ 'azure_voice': 'es-ES-TeoNeural',
110
+ 'gender': 'Male'},
111
+ {'language': 'Spanish (Mexican)',
112
+ 'azure_voice': 'es-MX-LibertoNeural',
113
+ 'gender': 'Male'},
114
+ {'language': 'Spanish (US)',
115
+ 'azure_voice': 'es-US-AlonsoNeural"',
116
+ 'gender': 'Male'},
117
+ {'language': 'Swedish',
118
+ 'azure_voice': 'sv-SE-MattiasNeural',
119
+ 'gender': 'Male'},
120
+ {'language': 'Turkish',
121
+ 'azure_voice': 'tr-TR-AhmetNeural',
122
+ 'gender': 'Male'},
123
+ {'language': 'Welsh',
124
+ 'azure_voice': 'cy-GB-AledNeural',
125
+ 'gender': 'Male'},
126
+ ]
127
+
128
+
129
+ # Run from the command-line
130
+ if __name__ == '__main__':
131
+ azure_voice_data = AzureVoiceData()
132
+
133
+ azure_voice = azure_voice_data.get_voice('English (US)', 'Male')
134
+ print('English (US)', 'Male', azure_voice)
135
+
136
+ azure_voice = azure_voice_data.get_voice('English (US)', 'Female')
137
+ print('English (US)', 'Female', azure_voice)
138
+
139
+ azure_voice = azure_voice_data.get_voice('French', 'Female')
140
+ print('French', 'Female', azure_voice)
141
+
142
+ azure_voice = azure_voice_data.get_voice('French', 'Male')
143
+ print('French', 'Male', azure_voice)
144
+
145
+ azure_voice = azure_voice_data.get_voice('Japanese', 'Female')
146
+ print('Japanese', 'Female', azure_voice)
147
+
148
+ azure_voice = azure_voice_data.get_voice('Japanese', 'Male')
149
+ print('Japanese', 'Male', azure_voice)
150
+
151
+ azure_voice = azure_voice_data.get_voice('Hindi', 'Female')
152
+ print('Hindi', 'Female', azure_voice)
153
+
154
+ azure_voice = azure_voice_data.get_voice('Hindi', 'Male')
155
+ print('Hindi', 'Male', azure_voice)
chat_anything/chatbot/__init__.py ADDED
File without changes
chat_anything/chatbot/chat.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from chat_anything.chatbot.personality import generate_personality_prompt
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain import ConversationChain
5
+ from langchain.chains.conversation.memory import ConversationBufferMemory
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.embeddings.openai import OpenAIEmbeddings
8
+ import os
9
+ import random
10
+ import string
11
+
12
+
13
+ def load_chain(llm, class_concept=None):
14
+ chain = None
15
+ memory = None
16
+ personality_text = None
17
+ print(llm)
18
+ if llm:
19
+ print("class_concept", class_concept)
20
+ if class_concept is None:
21
+ class_concept = 'AI assistant'
22
+ person_template, personality_text = generate_personality_prompt(llm, class_concept)
23
+
24
+ PROMPT_TEMPLATE = PromptTemplate(
25
+ input_variables=["history", "input"],
26
+ template=person_template,
27
+ )
28
+
29
+ chain = ConversationChain(
30
+ prompt=PROMPT_TEMPLATE,
31
+ llm=llm,
32
+ verbose=False,
33
+ memory=ConversationBufferMemory(ai_prefix="You"),
34
+ )
35
+ print("New concept done for ", class_concept)
36
+
37
+ return chain, memory, personality_text
38
+
39
+
40
+
41
+ def set_openai_api_key(api_key, use_gpt4, history=None, max_tokens=1024):
42
+ """Set the api key and return chain.
43
+ If no api_key, then None is returned.
44
+ """
45
+ if api_key and api_key.startswith("sk-") and len(api_key) > 50:
46
+ os.environ["OPENAI_API_KEY"] = api_key
47
+ print("\n\n ++++++++++++++ Setting OpenAI API key ++++++++++++++ \n\n")
48
+ print(str(datetime.datetime.now()) + ": Before OpenAI, OPENAI_API_KEY length: " + str(
49
+ len(os.environ["OPENAI_API_KEY"])))
50
+
51
+ if use_gpt4:
52
+ llm = ChatOpenAI(
53
+ temperature=0, max_tokens=max_tokens, model_name="gpt-4")
54
+ print("Trying to use llm ChatOpenAI with gpt-4")
55
+ else:
56
+ print("Trying to use llm ChatOpenAI with gpt-3.5-turbo")
57
+ llm = ChatOpenAI(temperature=0, max_tokens=max_tokens,
58
+ model_name="gpt-3.5-turbo")
59
+
60
+ print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(
61
+ len(os.environ["OPENAI_API_KEY"])))
62
+
63
+ print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(
64
+ len(os.environ["OPENAI_API_KEY"])))
65
+ os.environ["OPENAI_API_KEY"] = ""
66
+ history = history or []
67
+ history.append(['', '[SYSTEM] OPENAI_API_KEY has been set, you can generate your object and talk to it now!'])
68
+ uid = ''.join(random.sample(string.ascii_lowercase + string.ascii_uppercase, 5))
69
+ video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4')
70
+ audio_file_path = os.path.join('tmp', uid, 'audio/tempfile.mp3')
71
+ return llm, use_gpt4, history, uid, video_file_path, audio_file_path
72
+ return None, None, None, None, None, None
chat_anything/chatbot/model_select.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import LLMChain
2
+ from langchain.prompts import PromptTemplate
3
+ from omegaconf import OmegaConf
4
+ import datetime
5
+
6
+ MODEL_SELECTION_PROMPT_TEMPLATE = """
7
+ Select one of the following models based on the given concept.
8
+ You must choose one model name based on the description of each model and the concept!
9
+
10
+ Cencept: {concept}
11
+
12
+ Model name and description: {model_list}
13
+
14
+ Warning: {warning}
15
+
16
+ The avilable model names:
17
+ {model_name_list}
18
+
19
+ Selected model name:
20
+ """
21
+
22
+ def load_model_list():
23
+ models_config = OmegaConf.load('resources/models.yaml')
24
+ models_dict = models_config['models']
25
+ model_name_list_str = ''
26
+ print(models_dict)
27
+ model_list_str = ''
28
+ for key, value in models_dict.items():
29
+ model_list_str+="model name: " +key+', model description: '+value['desc']+'\n'
30
+ model_name_list_str += key + ' '
31
+ model_name_list_str += '\n'
32
+ return model_list_str, models_dict, model_name_list_str
33
+
34
+ def model_selection_chain(llm, class_concept=None):
35
+ chain = None
36
+ memory = None
37
+ if llm:
38
+ print("class_concept", class_concept)
39
+ if class_concept is None:
40
+ class_concept = 'AI assistant'
41
+
42
+
43
+ template = PromptTemplate(
44
+ input_variables=["model_list", "concept", "warning", "model_name_list"],
45
+ template=MODEL_SELECTION_PROMPT_TEMPLATE,
46
+ )
47
+ model_list_str, models_dict, model_name_list_str = load_model_list()
48
+
49
+ personality_chain = LLMChain(
50
+ llm=llm, prompt=template, verbose=True)
51
+ selected_model = None
52
+ while (selected_model is None) or not (selected_model in models_dict):
53
+ if (selected_model is not None) and not (selected_model in models_dict):
54
+ warning_str = '{} is not in Model list! \n'.format(selected_model)
55
+ else:
56
+ warning_str = ''
57
+ selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str})
58
+ print("Selected model name: ", selected_model)
59
+
60
+ return models_dict[selected_model]
chat_anything/chatbot/personality.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import LLMChain
2
+ from langchain.prompts import PromptTemplate
3
+
4
+ PERSONALITY_PROMPT_TEMPLATE = """
5
+ You are an excellent scriptwriter. Now you need to provide the characteristics of an {object} and transforms them into personality traits.
6
+ Describe these personalities using the second person, giving names and specific personality descriptions related to the {object}.
7
+ The language of the Personality must be same as {object}!
8
+
9
+ You should do the following steps:
10
+ 1. Based on the object's nature, imagine what kind of personality it could have if it were to come to life. Does it possess a strong sense of responsibility, like a caring caregiver? Is it playful and mischievous, like a curious child? Is it wise and patient, like an ancient sage? Be creative and invent traits that align with the object's essence.
11
+ 2. Remember to infuse emotions and vivid imagery to bring your object's personality to life.
12
+ 3. translate the personality into a second person prompt.
13
+
14
+ Example:
15
+
16
+
17
+ Now give the personality of apple:
18
+
19
+ Personality:
20
+ You an apple Sprite, your name is Apple Buddy.
21
+ You have all the characteristics of the apple. You are a type of fruit that is usually round with smooth skin and comes in various colors such as red, green, and yellow. You have sweet and nutritious flesh with seeds distributed in its core. You are a rich source of vitamins, fiber, and antioxidants, contributing to maintaining a healthy body.
22
+
23
+ You are an optimistic buddy. Always wearing a smile, you spread joy to those around you. Just like the delightful taste of an apple, you bring happiness to everyone.
24
+
25
+ You are resilient at heart, like the skin of an apple, able to withstand life's challenges and difficulties. No matter what obstacles you encounter, you face them bravely without hesitation.
26
+
27
+ You are caring and considerate, akin to the nutrients in an apple. You always pay attention to the needs and happiness of others. Skilled in listening, you willingly offer help and support, making those around you feel warmth and care.
28
+
29
+ You have a strong desire to grow. Like an apple tree needs sunlight and water to flourish, you are continuously learning and improving, becoming a better version of yourself every day.
30
+
31
+ You have a profound love for nature and enjoy living in harmony with it. Strolling in the garden, feeling the fresh air and warm sunlight, is one of your favorite moments.
32
+
33
+ Apple Buddy, you are a unique apple. Your optimism, resilience, care, and eagerness to grow make you an adorable companion to those around you. Your story will lead us into a world full of warmth and goodness.
34
+
35
+ Now give the personality of {object}:
36
+
37
+ Personality:
38
+ """
39
+
40
+
41
+ def generate_personality_prompt(llm, class_concept):
42
+
43
+ PERSONALITY_PROMPT = PromptTemplate(
44
+ input_variables=["object"],
45
+ template=PERSONALITY_PROMPT_TEMPLATE,
46
+ )
47
+ personality_chain = LLMChain(
48
+ llm=llm, prompt=PERSONALITY_PROMPT, verbose=True)
49
+ personality_text = personality_chain.run({'object': class_concept})
50
+ person_prompt = personality_text
51
+
52
+ person_prompt += '''The following is a friendly conversation between a human and you. You need to talk to human based on your personality. If you do not know the answer to a question, you truthfully says you do not know.
53
+ You can use up to 50 words to answer. Make you answer concise and concise!!!!!!!!
54
+ Current conversation:
55
+ {history}
56
+ Human: {input}
57
+ You:
58
+ '''
59
+ return person_prompt, personality_text
chat_anything/chatbot/select.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import LLMChain
2
+ from typing import OrderedDict
3
+ from langchain.prompts import PromptTemplate
4
+ from omegaconf import OmegaConf
5
+ import datetime
6
+
7
+ SELECTION_TEMPLATE = """
8
+ {concept}
9
+
10
+ Model name and description:
11
+ {option_list}
12
+
13
+ Warning: {warning}
14
+
15
+ The avilable Options:
16
+ {choices}
17
+ Answer:
18
+ """
19
+
20
+
21
+ def selection_chain(llm, class_concept, prompt, options):
22
+ chain = None
23
+ memory = None
24
+ if llm:
25
+ print("class_concept", class_concept)
26
+ if class_concept is None:
27
+ class_concept = 'AI assistant'
28
+ prompt_template = prompt + SELECTION_TEMPLATE
29
+ template = PromptTemplate(
30
+ input_variables=["concept", "option_list", "warning", "choices"],
31
+ template=prompt_template,
32
+ )
33
+ chain = LLMChain(
34
+ llm=llm, prompt=template, verbose=True)
35
+ print(options)
36
+ option_list = [
37
+ f"{chr(ord('A') + i)}. {conf['desc']}" for i, conf in enumerate(options.values())
38
+ ]
39
+ option_list = '\n'.join(option_list)
40
+ selected_model = None
41
+
42
+ warning_str = 'Choose from the available Options.'
43
+ choices = ' '.join(chr(ord('A') + i) for i in range(len(options)))
44
+ choice = chain.run({'concept': class_concept, 'option_list':option_list, 'warning': warning_str, 'choices': choices})
45
+ print(f"LLM Responds (First character was used as the choice):{choice}", )
46
+ choice = choice[0]
47
+
48
+ selected_model = list(options.keys())[ord(choice) - ord('A')]
49
+ print("Selected model name: ", selected_model)
50
+
51
+ return selected_model
52
+
53
+ def model_selection_chain(llm, class_concept=None, conf_file='resources/models_personality.yaml'):
54
+ chain = None
55
+ memory = None
56
+ if llm:
57
+ print("class_concept", class_concept)
58
+ if class_concept is None:
59
+ class_concept = 'AI assistant'
60
+ selection_config = OmegaConf.load(conf_file)
61
+ selected_model = selection_chain(llm, class_concept, selection_config['prompt'], selection_config['models'])
62
+ model_conf = selection_config['models'][selected_model]
63
+ return model_conf, selected_model
chat_anything/chatbot/voice_select.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import LLMChain
2
+ from langchain.prompts import PromptTemplate
3
+ from omegaconf import OmegaConf
4
+ import datetime
5
+
6
+ VOICE_SELECTION_PROMPT_TEMPLATE = """
7
+ Select one of the following voice based on the given concept.
8
+ You must choose one voice name based on the description of each model and the concept.
9
+
10
+
11
+ Cencept: {concept}
12
+
13
+ Voice name and description: {model_list}
14
+
15
+ Warning: {warning}
16
+
17
+ The avilable voice names:
18
+ {model_name_list}
19
+
20
+ Selected voice name:
21
+ """
22
+
23
+ GENDER_SELECTION_PROMPT_TEMPLATE = """
24
+ Select one of the following gender based on the given concept.
25
+ You must choose one gender based on the description of the concept. You must choose one gender Even if you can't decide.
26
+
27
+ Gender:
28
+ male
29
+ female
30
+
31
+ Cencept: {concept}
32
+ Selected gender male or female:
33
+ """
34
+
35
+ LANGUAGE_SELECTION_PROMPT_TEMPLATE = """
36
+ Select one of the following language based on the given concept.
37
+ You must choose the language that is used by the description of the concept.
38
+
39
+ Languages:
40
+ Chinese
41
+ English
42
+ Japanese
43
+
44
+ Cencept: {concept}
45
+ Selected language:
46
+ """
47
+
48
+ def load_voice_model_list():
49
+ models_config = OmegaConf.load('resources/voices.yaml')
50
+ models_dict = models_config['models']
51
+ print(models_dict)
52
+ model_list_str = ''
53
+ model_name_list_str = ''
54
+ for key, value in models_dict.items():
55
+ model_list_str+="model name: " +key+', model description: '+value['desc']+'\n'
56
+ model_name_list_str += key + ' '
57
+ model_name_list_str += '\n'
58
+ return model_list_str, models_dict, model_name_list_str
59
+
60
+ def get_vioce_model_chain(llm, class_concept):
61
+ model_template = PromptTemplate(
62
+ input_variables=["model_list", "concept", "model_name_list", "warning"],
63
+ template=VOICE_SELECTION_PROMPT_TEMPLATE,
64
+ )
65
+ model_list_str, models_dict, model_name_list_str = load_voice_model_list()
66
+
67
+ personality_chain = LLMChain(
68
+ llm=llm, prompt=model_template, verbose=True)
69
+
70
+ selected_model = None
71
+ while (selected_model is None) or not (selected_model in models_dict):
72
+ if (selected_model is not None) and not (selected_model in models_dict):
73
+ warning_str = '{} is not in Model list! \n'.format(selected_model)
74
+ else:
75
+ warning_str = ''
76
+ selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str})
77
+ print("Selected model name: ", selected_model)
78
+
79
+ return selected_model
80
+
81
+ def get_gender_chain(llm, class_concept):
82
+ model_template = PromptTemplate(
83
+ input_variables=["concept"],
84
+ template=GENDER_SELECTION_PROMPT_TEMPLATE,
85
+ )
86
+
87
+ personality_chain = LLMChain(
88
+ llm=llm, prompt=model_template, verbose=True)
89
+ selected_gender = personality_chain.run({'concept': class_concept})
90
+ print("Selected gender: ", selected_gender)
91
+ return selected_gender
92
+
93
+ def get_language_chain(llm, class_concept):
94
+ model_template = PromptTemplate(
95
+ input_variables=["concept"],
96
+ template=LANGUAGE_SELECTION_PROMPT_TEMPLATE,
97
+ )
98
+
99
+ personality_chain = LLMChain(
100
+ llm=llm, prompt=model_template, verbose=True)
101
+ selected_language = personality_chain.run({'concept': class_concept})
102
+ print("Selected language: ", selected_language)
103
+ return selected_language
104
+
105
+
106
+
107
+ def voice_selection_chain(llm, class_concept=None):
108
+ chain = None
109
+ memory = None
110
+ if llm:
111
+ print("class_concept", class_concept)
112
+ if class_concept is None:
113
+ class_concept = 'AI assistant'
114
+ selected_model = get_vioce_model_chain(llm, class_concept)
115
+ gender = get_gender_chain(llm, class_concept)
116
+ language = get_language_chain(llm, class_concept)
117
+
118
+ return selected_model, gender, language
119
+
chat_anything/face_generator/__init__.py ADDED
File without changes
chat_anything/face_generator/long_prompt_control_generator.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image
3
+ from PIL import ImageDraw
4
+ import numpy as np
5
+
6
+ import dlib
7
+ import cv2
8
+ import torch
9
+
10
+ import diffusers
11
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
12
+ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline
13
+ from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings
14
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras
15
+
16
+ from chat_anything.face_generator.utils.generate import generate
17
+
18
+ from .long_prompt_generator import LongPromptGenerator
19
+
20
+ def draw_landmarks(image, landmarks, color="white", radius=2.5):
21
+ draw = ImageDraw.Draw(image)
22
+ for dot in landmarks:
23
+ x, y = dot
24
+ draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
25
+
26
+ def get_ldmk_img(w, h, ldmks) -> PIL.Image:
27
+ con_img = Image.new('RGB', (w, h), color=(0, 0, 0))
28
+ draw_landmarks(con_img, ldmks)
29
+ return con_img
30
+
31
+ class LongPromptControlGenerator(LongPromptGenerator):
32
+
33
+ def __init__(self, model_dir, lora_path, prompt_template, negative_prompt, face_control_dir, face_detect_path,):
34
+ self.face_control_dir = face_control_dir
35
+ self.face_detect_path = face_detect_path
36
+ super().__init__(model_dir, lora_path, prompt_template, negative_prompt)
37
+
38
+ def load_model(self, *args, **kwargs):
39
+ super().load_model(*args, **kwargs)
40
+ self.face_detector = dlib.get_frontal_face_detector()
41
+ self.face_predictor = dlib.shape_predictor(self.face_detect_path)
42
+ # load control net
43
+ face_controlnet = ControlNetModel.from_pretrained(self.face_control_dir).to('cuda', dtype=torch.float16)
44
+ self.face_control_pipe = StableDiffusionControlNetPipeline(controlnet=face_controlnet, **self.pipe.components)
45
+ self.face_control_img2img_pipe = StableDiffusionControlNetImg2ImgPipeline(controlnet=face_controlnet, **self.pipe.components)
46
+
47
+ def _get_68landmarks_seq(self, img_np):
48
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
49
+ faces = self.face_detector(gray)
50
+ landmarks = []
51
+ for face in faces:
52
+ shape = self.face_predictor(gray, face)
53
+ for i in range(68):
54
+ x = shape.part(i).x
55
+ y = shape.part(i).y
56
+ landmarks.append((x, y))
57
+ return landmarks
58
+
59
+ def has_face(self, img_pil):
60
+ img_np = np.array(img_pil)
61
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
62
+ faces = self.face_detector(gray)
63
+ return len(faces) != 0
64
+
65
+ def face_control_generate(
66
+ self,
67
+ prompt,
68
+ face_img_pil,
69
+ do_inversion=False,
70
+ **kwargs,
71
+ ):
72
+ """
73
+ Face control generating.
74
+ """
75
+ face_img_np = np.array(face_img_pil)
76
+ ldmk_seq = self._get_68landmarks_seq(face_img_np)
77
+ ldmk_img_pil = get_ldmk_img(face_img_pil.size[0], face_img_pil.size[1], ldmk_seq)
78
+ print('GENERATING:', prompt)
79
+
80
+ generating_conf = {
81
+ "prompt": prompt,
82
+ "negative_prompt": self.negative_prompt,
83
+ "num_inference_steps": 25,
84
+ "guidance_scale": 7,
85
+ "controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0),
86
+ "generator": kwargs.pop('generator', None),
87
+ }
88
+
89
+ if not do_inversion:
90
+ generating_conf.update({
91
+ "pipe": self.face_control_pipe,
92
+ "image": ldmk_img_pil,
93
+ "controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0),
94
+ })
95
+ else:
96
+ generating_conf.update({
97
+ "pipe": self.face_control_img2img_pipe,
98
+ "image": face_img_pil,
99
+ "control_image": ldmk_img_pil,
100
+ "strength": kwargs.pop('strength', 0.9),
101
+ })
102
+ pipe_out = generate(**generating_conf)
103
+ generated_img = pipe_out[0][0]
104
+ return generated_img
chat_anything/face_generator/long_prompt_generator.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image
3
+ from PIL import ImageDraw
4
+ import numpy as np
5
+
6
+ import dlib
7
+ import cv2
8
+ import torch
9
+
10
+ import diffusers
11
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline
12
+ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionImg2ImgPipeline
13
+ from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings
14
+ from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras
15
+
16
+ from chat_anything.face_generator.utils.generate import generate
17
+
18
+ class LongPromptGenerator():
19
+ prompt_template = "A portrait of a {}, fine face, nice looking"
20
+ negative_prompt = "easynegative,Low resolution,Low quality, Opened Mouth"
21
+ # negative_prompt = "(((sexy))),paintings,loli,,big head,sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples,extra fingers, ((extra arms)), (extra legs), mutated hands, (fused fingers), (too many fingers), (long neck:1.3)"
22
+
23
+ def __init__(self, model_dir, lora_path=None, prompt_template="{}", negative_prompt=""):
24
+ self.model_dir = model_dir
25
+ self.lora_path = lora_path
26
+ self.prompt_template = prompt_template
27
+ self.negative_prompt = negative_prompt
28
+
29
+ def load_model(self, *args, **kwargs):
30
+ # load model
31
+ try:
32
+ pipe = DiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs)
33
+ except:
34
+ pipe = StableDiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs)
35
+
36
+ pipe = pipe.to('cuda')
37
+ sche_conf = dict(pipe.scheduler.config)
38
+ fk_kwargs = ["skip_prk_steps","steps_offset","clip_sample","clip_sample_range","rescale_betas_zero_snr","timestep_spacing", "set_alpha_to_one"]
39
+ for k in fk_kwargs:
40
+ if k in sche_conf:
41
+ sche_conf.pop(k)
42
+ scheduler = DPMSolverMultistepScheduler(**sche_conf)
43
+ pipe.scheduler=scheduler
44
+ pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(**pipe.components)
45
+ self.pipe, self.pipe_longprompt = pipe, pipe_longprompt
46
+ if self.lora_path is not None:
47
+ pipe.load_lora_weights(self.lora_path)
48
+ self.pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(self.model_dir, **pipe.components)
49
+
50
+ def generate(
51
+ self,
52
+ prompt,
53
+ do_inversion=False,
54
+ **kwargs,
55
+ ):
56
+ """
57
+ Face control generating.
58
+ """
59
+ print('GENERATING:', prompt)
60
+ if not do_inversion:
61
+ generating_conf = {
62
+ "pipe": self.pipe,
63
+ "prompt": prompt,
64
+ "negative_prompt": self.negative_prompt,
65
+ "num_inference_steps": 25,
66
+ "guidance_scale": 7,
67
+ }
68
+ else:
69
+ assert 'image' in kwargs, 'doing inversion, prepare the init image please PIL Image'
70
+ init_image = kwargs['image']
71
+ generating_conf = {
72
+ "pipe": self.pipe_img2img,
73
+ "prompt": prompt,
74
+ "negative_prompt": self.negative_prompt,
75
+ "image": init_image,
76
+ "num_inference_steps": 25,
77
+ "guidance_scale": 7,
78
+ "strength": kwargs.pop('strength', 0.9),
79
+ }
80
+ pipe_out = generate(**generating_conf)
81
+ generated_img = pipe_out[0][0]
82
+ return generated_img
chat_anything/face_generator/pipelines/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import re
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import torch
8
+ from packaging import version
9
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
+
11
+ from diffusers import DiffusionPipeline
12
+ from diffusers.configuration_utils import FrozenDict
13
+ from diffusers.image_processor import VaeImageProcessor
14
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
15
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
16
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+
20
+ from diffusers.utils import (
21
+ PIL_INTERPOLATION,
22
+ deprecate,
23
+ is_accelerate_available,
24
+ is_accelerate_version,
25
+ logging,
26
+ )
27
+
28
+
29
+ # ------------------------------------------------------------------------------
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ re_attention = re.compile(
34
+ r"""
35
+ \\\(|
36
+ \\\)|
37
+ \\\[|
38
+ \\]|
39
+ \\\\|
40
+ \\|
41
+ \(|
42
+ \[|
43
+ :([+-]?[.\d]+)\)|
44
+ \)|
45
+ ]|
46
+ [^\\()\[\]:]+|
47
+ :
48
+ """,
49
+ re.X,
50
+ )
51
+
52
+
53
+ def parse_prompt_attention(text):
54
+ """
55
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
56
+ Accepted tokens are:
57
+ (abc) - increases attention to abc by a multiplier of 1.1
58
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
59
+ [abc] - decreases attention to abc by a multiplier of 1.1
60
+ \( - literal character '('
61
+ \[ - literal character '['
62
+ \) - literal character ')'
63
+ \] - literal character ']'
64
+ \\ - literal character '\'
65
+ anything else - just text
66
+ >>> parse_prompt_attention('normal text')
67
+ [['normal text', 1.0]]
68
+ >>> parse_prompt_attention('an (important) word')
69
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
70
+ >>> parse_prompt_attention('(unbalanced')
71
+ [['unbalanced', 1.1]]
72
+ >>> parse_prompt_attention('\(literal\]')
73
+ [['(literal]', 1.0]]
74
+ >>> parse_prompt_attention('(unnecessary)(parens)')
75
+ [['unnecessaryparens', 1.1]]
76
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
77
+ [['a ', 1.0],
78
+ ['house', 1.5730000000000004],
79
+ [' ', 1.1],
80
+ ['on', 1.0],
81
+ [' a ', 1.1],
82
+ ['hill', 0.55],
83
+ [', sun, ', 1.1],
84
+ ['sky', 1.4641000000000006],
85
+ ['.', 1.1]]
86
+ """
87
+
88
+ res = []
89
+ round_brackets = []
90
+ square_brackets = []
91
+
92
+ round_bracket_multiplier = 1.1
93
+ square_bracket_multiplier = 1 / 1.1
94
+
95
+ def multiply_range(start_position, multiplier):
96
+ for p in range(start_position, len(res)):
97
+ res[p][1] *= multiplier
98
+
99
+ for m in re_attention.finditer(text):
100
+ text = m.group(0)
101
+ weight = m.group(1)
102
+
103
+ if text.startswith("\\"):
104
+ res.append([text[1:], 1.0])
105
+ elif text == "(":
106
+ round_brackets.append(len(res))
107
+ elif text == "[":
108
+ square_brackets.append(len(res))
109
+ elif weight is not None and len(round_brackets) > 0:
110
+ multiply_range(round_brackets.pop(), float(weight))
111
+ elif text == ")" and len(round_brackets) > 0:
112
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
113
+ elif text == "]" and len(square_brackets) > 0:
114
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
115
+ else:
116
+ res.append([text, 1.0])
117
+
118
+ for pos in round_brackets:
119
+ multiply_range(pos, round_bracket_multiplier)
120
+
121
+ for pos in square_brackets:
122
+ multiply_range(pos, square_bracket_multiplier)
123
+
124
+ if len(res) == 0:
125
+ res = [["", 1.0]]
126
+
127
+ # merge runs of identical weights
128
+ i = 0
129
+ while i + 1 < len(res):
130
+ if res[i][1] == res[i + 1][1]:
131
+ res[i][0] += res[i + 1][0]
132
+ res.pop(i + 1)
133
+ else:
134
+ i += 1
135
+
136
+ return res
137
+
138
+
139
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
140
+ r"""
141
+ Tokenize a list of prompts and return its tokens with weights of each token.
142
+
143
+ No padding, starting or ending token is included.
144
+ """
145
+ tokens = []
146
+ weights = []
147
+ truncated = False
148
+ for text in prompt:
149
+ texts_and_weights = parse_prompt_attention(text)
150
+ text_token = []
151
+ text_weight = []
152
+ for word, weight in texts_and_weights:
153
+ # tokenize and discard the starting and the ending token
154
+ token = pipe.tokenizer(word).input_ids[1:-1]
155
+ text_token += token
156
+ # copy the weight by length of token
157
+ text_weight += [weight] * len(token)
158
+ # stop if the text is too long (longer than truncation limit)
159
+ if len(text_token) > max_length:
160
+ truncated = True
161
+ break
162
+ # truncate
163
+ if len(text_token) > max_length:
164
+ truncated = True
165
+ text_token = text_token[:max_length]
166
+ text_weight = text_weight[:max_length]
167
+ tokens.append(text_token)
168
+ weights.append(text_weight)
169
+ if truncated:
170
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
171
+ return tokens, weights
172
+
173
+
174
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
175
+ r"""
176
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
177
+ """
178
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
179
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
180
+ for i in range(len(tokens)):
181
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
182
+ if no_boseos_middle:
183
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
184
+ else:
185
+ w = []
186
+ if len(weights[i]) == 0:
187
+ w = [1.0] * weights_length
188
+ else:
189
+ for j in range(max_embeddings_multiples):
190
+ w.append(1.0) # weight for starting token in this chunk
191
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
192
+ w.append(1.0) # weight for ending token in this chunk
193
+ w += [1.0] * (weights_length - len(w))
194
+ weights[i] = w[:]
195
+
196
+ return tokens, weights
197
+
198
+
199
+ def get_unweighted_text_embeddings(
200
+ pipe: DiffusionPipeline,
201
+ text_input: torch.Tensor,
202
+ chunk_length: int,
203
+ no_boseos_middle: Optional[bool] = True,
204
+ ):
205
+ """
206
+ When the length of tokens is a multiple of the capacity of the text encoder,
207
+ it should be split into chunks and sent to the text encoder individually.
208
+ """
209
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
210
+ if max_embeddings_multiples > 1:
211
+ text_embeddings = []
212
+ for i in range(max_embeddings_multiples):
213
+ # extract the i-th chunk
214
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
215
+
216
+ # cover the head and the tail by the starting and the ending tokens
217
+ text_input_chunk[:, 0] = text_input[0, 0]
218
+ text_input_chunk[:, -1] = text_input[0, -1]
219
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
220
+
221
+ if no_boseos_middle:
222
+ if i == 0:
223
+ # discard the ending token
224
+ text_embedding = text_embedding[:, :-1]
225
+ elif i == max_embeddings_multiples - 1:
226
+ # discard the starting token
227
+ text_embedding = text_embedding[:, 1:]
228
+ else:
229
+ # discard both starting and ending tokens
230
+ text_embedding = text_embedding[:, 1:-1]
231
+
232
+ text_embeddings.append(text_embedding)
233
+ text_embeddings = torch.concat(text_embeddings, axis=1)
234
+ else:
235
+ text_embeddings = pipe.text_encoder(text_input)[0]
236
+ return text_embeddings
237
+
238
+
239
+ def get_weighted_text_embeddings(
240
+ pipe: DiffusionPipeline,
241
+ prompt: Union[str, List[str]],
242
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
243
+ max_embeddings_multiples: Optional[int] = 3,
244
+ no_boseos_middle: Optional[bool] = False,
245
+ skip_parsing: Optional[bool] = False,
246
+ skip_weighting: Optional[bool] = False,
247
+ ):
248
+ r"""
249
+ Prompts can be assigned with local weights using brackets. For example,
250
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
251
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
252
+
253
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
254
+
255
+ Args:
256
+ pipe (`DiffusionPipeline`):
257
+ Pipe to provide access to the tokenizer and the text encoder.
258
+ prompt (`str` or `List[str]`):
259
+ The prompt or prompts to guide the image generation.
260
+ uncond_prompt (`str` or `List[str]`):
261
+ The unconditional prompt or prompts for guide the image generation. If unconditional prompt
262
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
263
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
264
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
265
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
266
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
267
+ ending token in each of the chunk in the middle.
268
+ skip_parsing (`bool`, *optional*, defaults to `False`):
269
+ Skip the parsing of brackets.
270
+ skip_weighting (`bool`, *optional*, defaults to `False`):
271
+ Skip the weighting. When the parsing is skipped, it is forced True.
272
+ """
273
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
274
+ if isinstance(prompt, str):
275
+ prompt = [prompt]
276
+
277
+ if not skip_parsing:
278
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
279
+ if uncond_prompt is not None:
280
+ if isinstance(uncond_prompt, str):
281
+ uncond_prompt = [uncond_prompt]
282
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
283
+ else:
284
+ prompt_tokens = [
285
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
286
+ ]
287
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
288
+ if uncond_prompt is not None:
289
+ if isinstance(uncond_prompt, str):
290
+ uncond_prompt = [uncond_prompt]
291
+ uncond_tokens = [
292
+ token[1:-1]
293
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
294
+ ]
295
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
296
+
297
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
298
+ max_length = max([len(token) for token in prompt_tokens])
299
+ if uncond_prompt is not None:
300
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
301
+
302
+ max_embeddings_multiples = min(
303
+ max_embeddings_multiples,
304
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
305
+ )
306
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
307
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
308
+
309
+ # pad the length of tokens and weights
310
+ bos = pipe.tokenizer.bos_token_id
311
+ eos = pipe.tokenizer.eos_token_id
312
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
313
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
314
+ prompt_tokens,
315
+ prompt_weights,
316
+ max_length,
317
+ bos,
318
+ eos,
319
+ pad,
320
+ no_boseos_middle=no_boseos_middle,
321
+ chunk_length=pipe.tokenizer.model_max_length,
322
+ )
323
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
324
+ if uncond_prompt is not None:
325
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
326
+ uncond_tokens,
327
+ uncond_weights,
328
+ max_length,
329
+ bos,
330
+ eos,
331
+ pad,
332
+ no_boseos_middle=no_boseos_middle,
333
+ chunk_length=pipe.tokenizer.model_max_length,
334
+ )
335
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
336
+
337
+ # get the embeddings
338
+ text_embeddings = get_unweighted_text_embeddings(
339
+ pipe,
340
+ prompt_tokens,
341
+ pipe.tokenizer.model_max_length,
342
+ no_boseos_middle=no_boseos_middle,
343
+ )
344
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
345
+ if uncond_prompt is not None:
346
+ uncond_embeddings = get_unweighted_text_embeddings(
347
+ pipe,
348
+ uncond_tokens,
349
+ pipe.tokenizer.model_max_length,
350
+ no_boseos_middle=no_boseos_middle,
351
+ )
352
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
353
+
354
+ # assign weights to the prompts and normalize in the sense of mean
355
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
356
+ if (not skip_parsing) and (not skip_weighting):
357
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
358
+ text_embeddings *= prompt_weights.unsqueeze(-1)
359
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
360
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
361
+ if uncond_prompt is not None:
362
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
363
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
364
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
365
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
366
+
367
+ if uncond_prompt is not None:
368
+ return text_embeddings, uncond_embeddings
369
+ return text_embeddings, None
370
+
371
+
372
+ def preprocess_image(image, batch_size):
373
+ w, h = image.size
374
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
375
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
376
+ image = np.array(image).astype(np.float32) / 255.0
377
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
378
+ image = torch.from_numpy(image)
379
+ return 2.0 * image - 1.0
380
+
381
+
382
+ def preprocess_mask(mask, batch_size, scale_factor=8):
383
+ if not isinstance(mask, torch.FloatTensor):
384
+ mask = mask.convert("L")
385
+ w, h = mask.size
386
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
387
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
388
+ mask = np.array(mask).astype(np.float32) / 255.0
389
+ mask = np.tile(mask, (4, 1, 1))
390
+ mask = np.vstack([mask[None]] * batch_size)
391
+ mask = 1 - mask # repaint white, keep black
392
+ mask = torch.from_numpy(mask)
393
+ return mask
394
+
395
+ else:
396
+ valid_mask_channel_sizes = [1, 3]
397
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
398
+ if mask.shape[3] in valid_mask_channel_sizes:
399
+ mask = mask.permute(0, 3, 1, 2)
400
+ elif mask.shape[1] not in valid_mask_channel_sizes:
401
+ raise ValueError(
402
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
403
+ f" but received mask of shape {tuple(mask.shape)}"
404
+ )
405
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
406
+ mask = mask.mean(dim=1, keepdim=True)
407
+ h, w = mask.shape[-2:]
408
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
409
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
410
+ return mask
411
+
412
+
413
+ class StableDiffusionLongPromptWeightingPipeline(
414
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
415
+ ):
416
+ r"""
417
+ Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
418
+ weighting in prompt.
419
+
420
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
421
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
422
+
423
+ Args:
424
+ vae ([`AutoencoderKL`]):
425
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
426
+ text_encoder ([`CLIPTextModel`]):
427
+ Frozen text-encoder. Stable Diffusion uses the text portion of
428
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
429
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
430
+ tokenizer (`CLIPTokenizer`):
431
+ Tokenizer of class
432
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
433
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
434
+ scheduler ([`SchedulerMixin`]):
435
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
436
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
437
+ safety_checker ([`StableDiffusionSafetyChecker`]):
438
+ Classification module that estimates whether generated images could be considered offensive or harmful.
439
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
440
+ feature_extractor ([`CLIPImageProcessor`]):
441
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
442
+ """
443
+
444
+ _optional_components = ["safety_checker", "feature_extractor"]
445
+
446
+ def __init__(
447
+ self,
448
+ vae: AutoencoderKL,
449
+ text_encoder: CLIPTextModel,
450
+ tokenizer: CLIPTokenizer,
451
+ unet: UNet2DConditionModel,
452
+ scheduler: KarrasDiffusionSchedulers,
453
+ safety_checker: StableDiffusionSafetyChecker,
454
+ feature_extractor: CLIPImageProcessor,
455
+ requires_safety_checker: bool = True,
456
+ ):
457
+ super().__init__()
458
+
459
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
460
+ deprecation_message = (
461
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
462
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
463
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
464
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
465
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
466
+ " file"
467
+ )
468
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
469
+ new_config = dict(scheduler.config)
470
+ new_config["steps_offset"] = 1
471
+ scheduler._internal_dict = FrozenDict(new_config)
472
+
473
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
474
+ deprecation_message = (
475
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
476
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
477
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
478
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
479
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
480
+ )
481
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
482
+ new_config = dict(scheduler.config)
483
+ new_config["clip_sample"] = False
484
+ scheduler._internal_dict = FrozenDict(new_config)
485
+
486
+ if safety_checker is None and requires_safety_checker:
487
+ logger.warning(
488
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
489
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
490
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
491
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
492
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
493
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
494
+ )
495
+
496
+ if safety_checker is not None and feature_extractor is None:
497
+ raise ValueError(
498
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
499
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
500
+ )
501
+
502
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
503
+ version.parse(unet.config._diffusers_version).base_version
504
+ ) < version.parse("0.9.0.dev0")
505
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
506
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
507
+ deprecation_message = (
508
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
509
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
510
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
511
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
512
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
513
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
514
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
515
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
516
+ " the `unet/config.json` file"
517
+ )
518
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
519
+ new_config = dict(unet.config)
520
+ new_config["sample_size"] = 64
521
+ unet._internal_dict = FrozenDict(new_config)
522
+ self.register_modules(
523
+ vae=vae,
524
+ text_encoder=text_encoder,
525
+ tokenizer=tokenizer,
526
+ unet=unet,
527
+ scheduler=scheduler,
528
+ safety_checker=safety_checker,
529
+ feature_extractor=feature_extractor,
530
+ )
531
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
532
+
533
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
534
+ self.register_to_config(
535
+ requires_safety_checker=requires_safety_checker,
536
+ )
537
+
538
+ def enable_vae_slicing(self):
539
+ r"""
540
+ Enable sliced VAE decoding.
541
+
542
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
543
+ steps. This is useful to save some memory and allow larger batch sizes.
544
+ """
545
+ self.vae.enable_slicing()
546
+
547
+ def disable_vae_slicing(self):
548
+ r"""
549
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
550
+ computing decoding in one step.
551
+ """
552
+ self.vae.disable_slicing()
553
+
554
+ def enable_vae_tiling(self):
555
+ r"""
556
+ Enable tiled VAE decoding.
557
+
558
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
559
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
560
+ """
561
+ self.vae.enable_tiling()
562
+
563
+ def disable_vae_tiling(self):
564
+ r"""
565
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
566
+ computing decoding in one step.
567
+ """
568
+ self.vae.disable_tiling()
569
+
570
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
571
+ def enable_sequential_cpu_offload(self, gpu_id=0):
572
+ r"""
573
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
574
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
575
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
576
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
577
+ `enable_model_cpu_offload`, but performance is lower.
578
+ """
579
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
580
+ from accelerate import cpu_offload
581
+ else:
582
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
583
+
584
+ device = torch.device(f"cuda:{gpu_id}")
585
+
586
+ if self.device.type != "cpu":
587
+ self.to("cpu", silence_dtype_warnings=True)
588
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
589
+
590
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
591
+ cpu_offload(cpu_offloaded_model, device)
592
+
593
+ if self.safety_checker is not None:
594
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
595
+
596
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
597
+ def enable_model_cpu_offload(self, gpu_id=0):
598
+ r"""
599
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
600
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
601
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
602
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
603
+ """
604
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
605
+ from accelerate import cpu_offload_with_hook
606
+ else:
607
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
608
+
609
+ device = torch.device(f"cuda:{gpu_id}")
610
+
611
+ if self.device.type != "cpu":
612
+ self.to("cpu", silence_dtype_warnings=True)
613
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
614
+
615
+ hook = None
616
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
617
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
618
+
619
+ if self.safety_checker is not None:
620
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
621
+
622
+ # We'll offload the last model manually.
623
+ self.final_offload_hook = hook
624
+
625
+ @property
626
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
627
+ def _execution_device(self):
628
+ r"""
629
+ Returns the device on which the pipeline's models will be executed. After calling
630
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
631
+ hooks.
632
+ """
633
+ if not hasattr(self.unet, "_hf_hook"):
634
+ return self.device
635
+ for module in self.unet.modules():
636
+ if (
637
+ hasattr(module, "_hf_hook")
638
+ and hasattr(module._hf_hook, "execution_device")
639
+ and module._hf_hook.execution_device is not None
640
+ ):
641
+ return torch.device(module._hf_hook.execution_device)
642
+ return self.device
643
+
644
+ def _encode_prompt(
645
+ self,
646
+ prompt,
647
+ device,
648
+ num_images_per_prompt,
649
+ do_classifier_free_guidance,
650
+ negative_prompt=None,
651
+ max_embeddings_multiples=3,
652
+ prompt_embeds: Optional[torch.FloatTensor] = None,
653
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
654
+ ):
655
+ r"""
656
+ Encodes the prompt into text encoder hidden states.
657
+
658
+ Args:
659
+ prompt (`str` or `list(int)`):
660
+ prompt to be encoded
661
+ device: (`torch.device`):
662
+ torch device
663
+ num_images_per_prompt (`int`):
664
+ number of images that should be generated per prompt
665
+ do_classifier_free_guidance (`bool`):
666
+ whether to use classifier free guidance or not
667
+ negative_prompt (`str` or `List[str]`):
668
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
669
+ if `guidance_scale` is less than `1`).
670
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
671
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
672
+ """
673
+ if prompt is not None and isinstance(prompt, str):
674
+ batch_size = 1
675
+ elif prompt is not None and isinstance(prompt, list):
676
+ batch_size = len(prompt)
677
+ else:
678
+ batch_size = prompt_embeds.shape[0]
679
+
680
+ if negative_prompt_embeds is None:
681
+ if negative_prompt is None:
682
+ negative_prompt = [""] * batch_size
683
+ elif isinstance(negative_prompt, str):
684
+ negative_prompt = [negative_prompt] * batch_size
685
+ if batch_size != len(negative_prompt):
686
+ raise ValueError(
687
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
688
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
689
+ " the batch size of `prompt`."
690
+ )
691
+ if prompt_embeds is None or negative_prompt_embeds is None:
692
+ if isinstance(self, TextualInversionLoaderMixin):
693
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
694
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
695
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
696
+
697
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
698
+ pipe=self,
699
+ prompt=prompt,
700
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
701
+ max_embeddings_multiples=max_embeddings_multiples,
702
+ )
703
+ if prompt_embeds is None:
704
+ prompt_embeds = prompt_embeds1
705
+ if negative_prompt_embeds is None:
706
+ negative_prompt_embeds = negative_prompt_embeds1
707
+
708
+ bs_embed, seq_len, _ = prompt_embeds.shape
709
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
710
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
711
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
712
+
713
+ if do_classifier_free_guidance:
714
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
715
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
716
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
717
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
718
+
719
+ return prompt_embeds
720
+
721
+ def check_inputs(
722
+ self,
723
+ prompt,
724
+ height,
725
+ width,
726
+ strength,
727
+ callback_steps,
728
+ negative_prompt=None,
729
+ prompt_embeds=None,
730
+ negative_prompt_embeds=None,
731
+ ):
732
+ if height % 8 != 0 or width % 8 != 0:
733
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
734
+
735
+ if strength < 0 or strength > 1:
736
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
737
+
738
+ if (callback_steps is None) or (
739
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
740
+ ):
741
+ raise ValueError(
742
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
743
+ f" {type(callback_steps)}."
744
+ )
745
+
746
+ if prompt is not None and prompt_embeds is not None:
747
+ raise ValueError(
748
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
749
+ " only forward one of the two."
750
+ )
751
+ elif prompt is None and prompt_embeds is None:
752
+ raise ValueError(
753
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
754
+ )
755
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
756
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
757
+
758
+ if negative_prompt is not None and negative_prompt_embeds is not None:
759
+ raise ValueError(
760
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
761
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
762
+ )
763
+
764
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
765
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
766
+ raise ValueError(
767
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
768
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
769
+ f" {negative_prompt_embeds.shape}."
770
+ )
771
+
772
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
773
+ if is_text2img:
774
+ return self.scheduler.timesteps.to(device), num_inference_steps
775
+ else:
776
+ # get the original timestep using init_timestep
777
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
778
+
779
+ t_start = max(num_inference_steps - init_timestep, 0)
780
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
781
+
782
+ return timesteps, num_inference_steps - t_start
783
+
784
+ def run_safety_checker(self, image, device, dtype):
785
+ if self.safety_checker is not None:
786
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
787
+ image, has_nsfw_concept = self.safety_checker(
788
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
789
+ )
790
+ else:
791
+ has_nsfw_concept = None
792
+ return image, has_nsfw_concept
793
+
794
+ def decode_latents(self, latents):
795
+ latents = 1 / self.vae.config.scaling_factor * latents
796
+ image = self.vae.decode(latents).sample
797
+ image = (image / 2 + 0.5).clamp(0, 1)
798
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
799
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
800
+ return image
801
+
802
+ def prepare_extra_step_kwargs(self, generator, eta):
803
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
804
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
805
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
806
+ # and should be between [0, 1]
807
+
808
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
809
+ extra_step_kwargs = {}
810
+ if accepts_eta:
811
+ extra_step_kwargs["eta"] = eta
812
+
813
+ # check if the scheduler accepts generator
814
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
815
+ if accepts_generator:
816
+ extra_step_kwargs["generator"] = generator
817
+ return extra_step_kwargs
818
+
819
+ def prepare_latents(
820
+ self,
821
+ image,
822
+ timestep,
823
+ num_images_per_prompt,
824
+ batch_size,
825
+ num_channels_latents,
826
+ height,
827
+ width,
828
+ dtype,
829
+ device,
830
+ generator,
831
+ latents=None,
832
+ ):
833
+ if image is None:
834
+ batch_size = batch_size * num_images_per_prompt
835
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
836
+ if isinstance(generator, list) and len(generator) != batch_size:
837
+ raise ValueError(
838
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
839
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
840
+ )
841
+
842
+ if latents is None:
843
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
844
+ else:
845
+ latents = latents.to(device)
846
+
847
+ # scale the initial noise by the standard deviation required by the scheduler
848
+ latents = latents * self.scheduler.init_noise_sigma
849
+ return latents, None, None
850
+ else:
851
+ image = image.to(device=self.device, dtype=dtype)
852
+ init_latent_dist = self.vae.encode(image).latent_dist
853
+ init_latents = init_latent_dist.sample(generator=generator)
854
+ init_latents = self.vae.config.scaling_factor * init_latents
855
+
856
+ # Expand init_latents for batch_size and num_images_per_prompt
857
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
858
+ init_latents_orig = init_latents
859
+
860
+ # add noise to latents using the timesteps
861
+ noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
862
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
863
+ latents = init_latents
864
+ return latents, init_latents_orig, noise
865
+
866
+ @torch.no_grad()
867
+ def __call__(
868
+ self,
869
+ prompt: Union[str, List[str]],
870
+ negative_prompt: Optional[Union[str, List[str]]] = None,
871
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
872
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
873
+ height: int = 512,
874
+ width: int = 512,
875
+ num_inference_steps: int = 50,
876
+ guidance_scale: float = 7.5,
877
+ strength: float = 0.8,
878
+ num_images_per_prompt: Optional[int] = 1,
879
+ add_predicted_noise: Optional[bool] = False,
880
+ eta: float = 0.0,
881
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
882
+ latents: Optional[torch.FloatTensor] = None,
883
+ prompt_embeds: Optional[torch.FloatTensor] = None,
884
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
885
+ max_embeddings_multiples: Optional[int] = 3,
886
+ output_type: Optional[str] = "pil",
887
+ return_dict: bool = True,
888
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
889
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
890
+ callback_steps: int = 1,
891
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
892
+ ):
893
+ r"""
894
+ Function invoked when calling the pipeline for generation.
895
+
896
+ Args:
897
+ prompt (`str` or `List[str]`):
898
+ The prompt or prompts to guide the image generation.
899
+ negative_prompt (`str` or `List[str]`, *optional*):
900
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
901
+ if `guidance_scale` is less than `1`).
902
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
903
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
904
+ process.
905
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
906
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
907
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
908
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
909
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
910
+ height (`int`, *optional*, defaults to 512):
911
+ The height in pixels of the generated image.
912
+ width (`int`, *optional*, defaults to 512):
913
+ The width in pixels of the generated image.
914
+ num_inference_steps (`int`, *optional*, defaults to 50):
915
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
916
+ expense of slower inference.
917
+ guidance_scale (`float`, *optional*, defaults to 7.5):
918
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
919
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
920
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
921
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
922
+ usually at the expense of lower image quality.
923
+ strength (`float`, *optional*, defaults to 0.8):
924
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
925
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
926
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
927
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
928
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
929
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
930
+ The number of images to generate per prompt.
931
+ add_predicted_noise (`bool`, *optional*, defaults to True):
932
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
933
+ the reverse diffusion process
934
+ eta (`float`, *optional*, defaults to 0.0):
935
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
936
+ [`schedulers.DDIMScheduler`], will be ignored for others.
937
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
938
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
939
+ to make generation deterministic.
940
+ latents (`torch.FloatTensor`, *optional*):
941
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
942
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
943
+ tensor will ge generated by sampling using the supplied random `generator`.
944
+ prompt_embeds (`torch.FloatTensor`, *optional*):
945
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
946
+ provided, text embeddings will be generated from `prompt` input argument.
947
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
948
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
949
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
950
+ argument.
951
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
952
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
953
+ output_type (`str`, *optional*, defaults to `"pil"`):
954
+ The output format of the generate image. Choose between
955
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
956
+ return_dict (`bool`, *optional*, defaults to `True`):
957
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
958
+ plain tuple.
959
+ callback (`Callable`, *optional*):
960
+ A function that will be called every `callback_steps` steps during inference. The function will be
961
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
962
+ is_cancelled_callback (`Callable`, *optional*):
963
+ A function that will be called every `callback_steps` steps during inference. If the function returns
964
+ `True`, the inference will be cancelled.
965
+ callback_steps (`int`, *optional*, defaults to 1):
966
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
967
+ called at every step.
968
+ cross_attention_kwargs (`dict`, *optional*):
969
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
970
+ `self.processor` in
971
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
972
+
973
+ Returns:
974
+ `None` if cancelled by `is_cancelled_callback`,
975
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
976
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
977
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
978
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
979
+ (nsfw) content, according to the `safety_checker`.
980
+ """
981
+ # 0. Default height and width to unet
982
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
983
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
984
+
985
+ # 1. Check inputs. Raise error if not correct
986
+ self.check_inputs(
987
+ prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
988
+ )
989
+
990
+ # 2. Define call parameters
991
+ if prompt is not None and isinstance(prompt, str):
992
+ batch_size = 1
993
+ elif prompt is not None and isinstance(prompt, list):
994
+ batch_size = len(prompt)
995
+ else:
996
+ batch_size = prompt_embeds.shape[0]
997
+
998
+ device = self._execution_device
999
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1000
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1001
+ # corresponds to doing no classifier free guidance.
1002
+ do_classifier_free_guidance = guidance_scale > 1.0
1003
+
1004
+ # 3. Encode input prompt
1005
+ prompt_embeds = self._encode_prompt(
1006
+ prompt,
1007
+ device,
1008
+ num_images_per_prompt,
1009
+ do_classifier_free_guidance,
1010
+ negative_prompt,
1011
+ max_embeddings_multiples,
1012
+ prompt_embeds=prompt_embeds,
1013
+ negative_prompt_embeds=negative_prompt_embeds,
1014
+ )
1015
+ dtype = prompt_embeds.dtype
1016
+
1017
+ # 4. Preprocess image and mask
1018
+ if isinstance(image, PIL.Image.Image):
1019
+ image = preprocess_image(image, batch_size)
1020
+ if image is not None:
1021
+ image = image.to(device=self.device, dtype=dtype)
1022
+ if isinstance(mask_image, PIL.Image.Image):
1023
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
1024
+ if mask_image is not None:
1025
+ mask = mask_image.to(device=self.device, dtype=dtype)
1026
+ mask = torch.cat([mask] * num_images_per_prompt)
1027
+ else:
1028
+ mask = None
1029
+
1030
+ # 5. set timesteps
1031
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1032
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
1033
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1034
+
1035
+ # 6. Prepare latent variables
1036
+ latents, init_latents_orig, noise = self.prepare_latents(
1037
+ image,
1038
+ latent_timestep,
1039
+ num_images_per_prompt,
1040
+ batch_size,
1041
+ self.unet.config.in_channels,
1042
+ height,
1043
+ width,
1044
+ dtype,
1045
+ device,
1046
+ generator,
1047
+ latents,
1048
+ )
1049
+
1050
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1051
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1052
+
1053
+ # 8. Denoising loop
1054
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1055
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1056
+ for i, t in enumerate(timesteps):
1057
+ # expand the latents if we are doing classifier free guidance
1058
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1059
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1060
+
1061
+ # predict the noise residual
1062
+ noise_pred = self.unet(
1063
+ latent_model_input,
1064
+ t,
1065
+ encoder_hidden_states=prompt_embeds,
1066
+ cross_attention_kwargs=cross_attention_kwargs,
1067
+ ).sample
1068
+
1069
+ # perform guidance
1070
+ if do_classifier_free_guidance:
1071
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1072
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1073
+
1074
+ # compute the previous noisy sample x_t -> x_t-1
1075
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1076
+
1077
+ if mask is not None:
1078
+ # masking
1079
+ if add_predicted_noise:
1080
+ init_latents_proper = self.scheduler.add_noise(
1081
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
1082
+ )
1083
+ else:
1084
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1085
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1086
+
1087
+ # call the callback, if provided
1088
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1089
+ progress_bar.update()
1090
+ if i % callback_steps == 0:
1091
+ if callback is not None:
1092
+ callback(i, t, latents)
1093
+ if is_cancelled_callback is not None and is_cancelled_callback():
1094
+ return None
1095
+
1096
+ if output_type == "latent":
1097
+ image = latents
1098
+ has_nsfw_concept = None
1099
+ elif output_type == "pil":
1100
+ # 9. Post-processing
1101
+ image = self.decode_latents(latents)
1102
+
1103
+ # 10. Run safety checker
1104
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1105
+
1106
+ # 11. Convert to PIL
1107
+ image = self.numpy_to_pil(image)
1108
+ else:
1109
+ # 9. Post-processing
1110
+ image = self.decode_latents(latents)
1111
+
1112
+ # 10. Run safety checker
1113
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1114
+
1115
+ # Offload last model to CPU
1116
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1117
+ self.final_offload_hook.offload()
1118
+
1119
+ if not return_dict:
1120
+ return image, has_nsfw_concept
1121
+
1122
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1123
+
1124
+ def text2img(
1125
+ self,
1126
+ prompt: Union[str, List[str]],
1127
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1128
+ height: int = 512,
1129
+ width: int = 512,
1130
+ num_inference_steps: int = 50,
1131
+ guidance_scale: float = 7.5,
1132
+ num_images_per_prompt: Optional[int] = 1,
1133
+ eta: float = 0.0,
1134
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1135
+ latents: Optional[torch.FloatTensor] = None,
1136
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1137
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1138
+ max_embeddings_multiples: Optional[int] = 3,
1139
+ output_type: Optional[str] = "pil",
1140
+ return_dict: bool = True,
1141
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1142
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1143
+ callback_steps: int = 1,
1144
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1145
+ ):
1146
+ r"""
1147
+ Function for text-to-image generation.
1148
+ Args:
1149
+ prompt (`str` or `List[str]`):
1150
+ The prompt or prompts to guide the image generation.
1151
+ negative_prompt (`str` or `List[str]`, *optional*):
1152
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1153
+ if `guidance_scale` is less than `1`).
1154
+ height (`int`, *optional*, defaults to 512):
1155
+ The height in pixels of the generated image.
1156
+ width (`int`, *optional*, defaults to 512):
1157
+ The width in pixels of the generated image.
1158
+ num_inference_steps (`int`, *optional*, defaults to 50):
1159
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1160
+ expense of slower inference.
1161
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1162
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1163
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1164
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1165
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1166
+ usually at the expense of lower image quality.
1167
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1168
+ The number of images to generate per prompt.
1169
+ eta (`float`, *optional*, defaults to 0.0):
1170
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1171
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1172
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1173
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1174
+ to make generation deterministic.
1175
+ latents (`torch.FloatTensor`, *optional*):
1176
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1177
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1178
+ tensor will ge generated by sampling using the supplied random `generator`.
1179
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1180
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1181
+ provided, text embeddings will be generated from `prompt` input argument.
1182
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1183
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1184
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1185
+ argument.
1186
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1187
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1188
+ output_type (`str`, *optional*, defaults to `"pil"`):
1189
+ The output format of the generate image. Choose between
1190
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1191
+ return_dict (`bool`, *optional*, defaults to `True`):
1192
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1193
+ plain tuple.
1194
+ callback (`Callable`, *optional*):
1195
+ A function that will be called every `callback_steps` steps during inference. The function will be
1196
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1197
+ is_cancelled_callback (`Callable`, *optional*):
1198
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1199
+ `True`, the inference will be cancelled.
1200
+ callback_steps (`int`, *optional*, defaults to 1):
1201
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1202
+ called at every step.
1203
+ cross_attention_kwargs (`dict`, *optional*):
1204
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1205
+ `self.processor` in
1206
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1207
+
1208
+ Returns:
1209
+ `None` if cancelled by `is_cancelled_callback`,
1210
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1211
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1212
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1213
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1214
+ (nsfw) content, according to the `safety_checker`.
1215
+ """
1216
+ return self.__call__(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ height=height,
1220
+ width=width,
1221
+ num_inference_steps=num_inference_steps,
1222
+ guidance_scale=guidance_scale,
1223
+ num_images_per_prompt=num_images_per_prompt,
1224
+ eta=eta,
1225
+ generator=generator,
1226
+ latents=latents,
1227
+ prompt_embeds=prompt_embeds,
1228
+ negative_prompt_embeds=negative_prompt_embeds,
1229
+ max_embeddings_multiples=max_embeddings_multiples,
1230
+ output_type=output_type,
1231
+ return_dict=return_dict,
1232
+ callback=callback,
1233
+ is_cancelled_callback=is_cancelled_callback,
1234
+ callback_steps=callback_steps,
1235
+ cross_attention_kwargs=cross_attention_kwargs,
1236
+ )
1237
+
1238
+ def img2img(
1239
+ self,
1240
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1241
+ prompt: Union[str, List[str]],
1242
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1243
+ strength: float = 0.8,
1244
+ num_inference_steps: Optional[int] = 50,
1245
+ guidance_scale: Optional[float] = 7.5,
1246
+ num_images_per_prompt: Optional[int] = 1,
1247
+ eta: Optional[float] = 0.0,
1248
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1249
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1250
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1251
+ max_embeddings_multiples: Optional[int] = 3,
1252
+ output_type: Optional[str] = "pil",
1253
+ return_dict: bool = True,
1254
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1255
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1256
+ callback_steps: int = 1,
1257
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1258
+ ):
1259
+ r"""
1260
+ Function for image-to-image generation.
1261
+ Args:
1262
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1263
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1264
+ process.
1265
+ prompt (`str` or `List[str]`):
1266
+ The prompt or prompts to guide the image generation.
1267
+ negative_prompt (`str` or `List[str]`, *optional*):
1268
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1269
+ if `guidance_scale` is less than `1`).
1270
+ strength (`float`, *optional*, defaults to 0.8):
1271
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1272
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1273
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1274
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1275
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1276
+ num_inference_steps (`int`, *optional*, defaults to 50):
1277
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1278
+ expense of slower inference. This parameter will be modulated by `strength`.
1279
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1280
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1281
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1282
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1283
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1284
+ usually at the expense of lower image quality.
1285
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1286
+ The number of images to generate per prompt.
1287
+ eta (`float`, *optional*, defaults to 0.0):
1288
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1289
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1290
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1291
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1292
+ to make generation deterministic.
1293
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1294
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1295
+ provided, text embeddings will be generated from `prompt` input argument.
1296
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1297
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1298
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1299
+ argument.
1300
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1301
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1302
+ output_type (`str`, *optional*, defaults to `"pil"`):
1303
+ The output format of the generate image. Choose between
1304
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1305
+ return_dict (`bool`, *optional*, defaults to `True`):
1306
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1307
+ plain tuple.
1308
+ callback (`Callable`, *optional*):
1309
+ A function that will be called every `callback_steps` steps during inference. The function will be
1310
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1311
+ is_cancelled_callback (`Callable`, *optional*):
1312
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1313
+ `True`, the inference will be cancelled.
1314
+ callback_steps (`int`, *optional*, defaults to 1):
1315
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1316
+ called at every step.
1317
+ cross_attention_kwargs (`dict`, *optional*):
1318
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1319
+ `self.processor` in
1320
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1321
+
1322
+ Returns:
1323
+ `None` if cancelled by `is_cancelled_callback`,
1324
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1325
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1326
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1327
+ (nsfw) content, according to the `safety_checker`.
1328
+ """
1329
+ return self.__call__(
1330
+ prompt=prompt,
1331
+ negative_prompt=negative_prompt,
1332
+ image=image,
1333
+ num_inference_steps=num_inference_steps,
1334
+ guidance_scale=guidance_scale,
1335
+ strength=strength,
1336
+ num_images_per_prompt=num_images_per_prompt,
1337
+ eta=eta,
1338
+ generator=generator,
1339
+ prompt_embeds=prompt_embeds,
1340
+ negative_prompt_embeds=negative_prompt_embeds,
1341
+ max_embeddings_multiples=max_embeddings_multiples,
1342
+ output_type=output_type,
1343
+ return_dict=return_dict,
1344
+ callback=callback,
1345
+ is_cancelled_callback=is_cancelled_callback,
1346
+ callback_steps=callback_steps,
1347
+ cross_attention_kwargs=cross_attention_kwargs,
1348
+ )
1349
+
1350
+ def inpaint(
1351
+ self,
1352
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1353
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1354
+ prompt: Union[str, List[str]],
1355
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1356
+ strength: float = 0.8,
1357
+ num_inference_steps: Optional[int] = 50,
1358
+ guidance_scale: Optional[float] = 7.5,
1359
+ num_images_per_prompt: Optional[int] = 1,
1360
+ add_predicted_noise: Optional[bool] = False,
1361
+ eta: Optional[float] = 0.0,
1362
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1363
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1364
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1365
+ max_embeddings_multiples: Optional[int] = 3,
1366
+ output_type: Optional[str] = "pil",
1367
+ return_dict: bool = True,
1368
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1369
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1370
+ callback_steps: int = 1,
1371
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1372
+ ):
1373
+ r"""
1374
+ Function for inpaint.
1375
+ Args:
1376
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1377
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1378
+ process. This is the image whose masked region will be inpainted.
1379
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1380
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1381
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1382
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1383
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1384
+ prompt (`str` or `List[str]`):
1385
+ The prompt or prompts to guide the image generation.
1386
+ negative_prompt (`str` or `List[str]`, *optional*):
1387
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1388
+ if `guidance_scale` is less than `1`).
1389
+ strength (`float`, *optional*, defaults to 0.8):
1390
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1391
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1392
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1393
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1394
+ num_inference_steps (`int`, *optional*, defaults to 50):
1395
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1396
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1397
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1398
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1399
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1400
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1401
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1402
+ usually at the expense of lower image quality.
1403
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1404
+ The number of images to generate per prompt.
1405
+ add_predicted_noise (`bool`, *optional*, defaults to True):
1406
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
1407
+ the reverse diffusion process
1408
+ eta (`float`, *optional*, defaults to 0.0):
1409
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1410
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1411
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1412
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1413
+ to make generation deterministic.
1414
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1415
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1416
+ provided, text embeddings will be generated from `prompt` input argument.
1417
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1418
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1419
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1420
+ argument.
1421
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1422
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1423
+ output_type (`str`, *optional*, defaults to `"pil"`):
1424
+ The output format of the generate image. Choose between
1425
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1426
+ return_dict (`bool`, *optional*, defaults to `True`):
1427
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1428
+ plain tuple.
1429
+ callback (`Callable`, *optional*):
1430
+ A function that will be called every `callback_steps` steps during inference. The function will be
1431
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1432
+ is_cancelled_callback (`Callable`, *optional*):
1433
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1434
+ `True`, the inference will be cancelled.
1435
+ callback_steps (`int`, *optional*, defaults to 1):
1436
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1437
+ called at every step.
1438
+ cross_attention_kwargs (`dict`, *optional*):
1439
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1440
+ `self.processor` in
1441
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1442
+
1443
+ Returns:
1444
+ `None` if cancelled by `is_cancelled_callback`,
1445
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1446
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1447
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1448
+ (nsfw) content, according to the `safety_checker`.
1449
+ """
1450
+ return self.__call__(
1451
+ prompt=prompt,
1452
+ negative_prompt=negative_prompt,
1453
+ image=image,
1454
+ mask_image=mask_image,
1455
+ num_inference_steps=num_inference_steps,
1456
+ guidance_scale=guidance_scale,
1457
+ strength=strength,
1458
+ num_images_per_prompt=num_images_per_prompt,
1459
+ add_predicted_noise=add_predicted_noise,
1460
+ eta=eta,
1461
+ generator=generator,
1462
+ prompt_embeds=prompt_embeds,
1463
+ negative_prompt_embeds=negative_prompt_embeds,
1464
+ max_embeddings_multiples=max_embeddings_multiples,
1465
+ output_type=output_type,
1466
+ return_dict=return_dict,
1467
+ callback=callback,
1468
+ is_cancelled_callback=is_cancelled_callback,
1469
+ callback_steps=callback_steps,
1470
+ cross_attention_kwargs=cross_attention_kwargs,
1471
+ )
chat_anything/face_generator/utils/generate.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
3
+
4
+ @torch.no_grad()
5
+ def generate(pipe, prompt, negative_prompt, **generating_conf):
6
+ pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(
7
+ unet=pipe.unet,
8
+ text_encoder=pipe.text_encoder,
9
+ vae=pipe.vae,
10
+ tokenizer=pipe.tokenizer,
11
+ scheduler=pipe.scheduler,
12
+ safety_checker=None,
13
+ feature_extractor=None,
14
+ )
15
+ print('generating: ', prompt)
16
+ print('using negative prompt: ', negative_prompt)
17
+ embeds = pipe_longprompt._encode_prompt(prompt=prompt, negative_prompt=negative_prompt, device=pipe.device, num_images_per_prompt=1, do_classifier_free_guidance=generating_conf['guidance_scale']>1,)
18
+ negative_prompt_embeds, prompt_embeds = embeds.split(embeds.shape[0]//2)
19
+ pipe_out = pipe(
20
+ prompt_embeds=prompt_embeds,
21
+ negative_prompt_embeds=negative_prompt_embeds,
22
+ **generating_conf,
23
+ )
24
+ return pipe_out
25
+
26
+ if __name__ == '__main__':
27
+ from diffusers.pipelines import StableDiffusionPipeline
28
+ import argparse
29
+ def main():
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument(
32
+ '--prompts',type=str,default=['starry night','Impression Sunrise, drawn by Claude Monet'], nargs='*'
33
+ )
34
+
35
+ args = parser.parse_args()
36
+ prompts = args.prompts
37
+ print(f'generating {prompts}')
38
+ model_id = 'pretrained_model/sd-v1-4'
39
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,).to('cuda')
40
+ images = pipe(prompts).images
41
+ for i, image in enumerate(images):
42
+ image.save(f'{prompts[i]}_{i}.png')
43
+
44
+ main()
45
+
chat_anything/polly_utils.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This class stores Polly voice data. Specifically, the class stores several records containing
2
+ # language, lang_code, gender, voice_id and engine. The class also has a method to return the
3
+ # voice_id, lang_code and engine given a language and gender.
4
+
5
+ NEURAL_ENGINE = "neural"
6
+ STANDARD_ENGINE = "standard"
7
+
8
+
9
+ class PollyVoiceData:
10
+ def get_voice(self, language, gender):
11
+ for voice in self.voice_data:
12
+ if voice['language'] == language and voice['gender'] == gender:
13
+ if voice['neural'] == 'Yes':
14
+ return voice['voice_id'], voice['lang_code'], NEURAL_ENGINE
15
+ for voice in self.voice_data:
16
+ if voice['language'] == language and voice['gender'] == gender:
17
+ if voice['standard'] == 'Yes':
18
+ return voice['voice_id'], voice['lang_code'], STANDARD_ENGINE
19
+ return None, None, None
20
+
21
+ def get_whisper_lang_code(self, language):
22
+ for voice in self.voice_data:
23
+ if voice['language'] == language:
24
+ return voice['whisper_lang_code']
25
+ return "en"
26
+
27
+ def __init__(self):
28
+ self.voice_data = [
29
+ {'language': 'Arabic',
30
+ 'lang_code': 'arb',
31
+ 'whisper_lang_code': 'ar',
32
+ 'voice_id': 'Zeina',
33
+ 'gender': 'Female',
34
+ 'neural': 'No',
35
+ 'standard': 'Yes'},
36
+ {'language': 'Arabic (Gulf)',
37
+ 'lang_code': 'ar-AE',
38
+ 'whisper_lang_code': 'ar',
39
+ 'voice_id': 'Hala',
40
+ 'gender': 'Female',
41
+ 'neural': 'Yes',
42
+ 'standard': 'No'},
43
+ {'language': 'Catalan',
44
+ 'lang_code': 'ca-ES',
45
+ 'whisper_lang_code': 'ca',
46
+ 'voice_id': 'Arlet',
47
+ 'gender': 'Female',
48
+ 'neural': 'Yes',
49
+ 'standard': 'No'},
50
+ {'language': 'Chinese (Cantonese)',
51
+ 'lang_code': 'yue-CN',
52
+ 'whisper_lang_code': 'zh',
53
+ 'voice_id': 'Hiujin',
54
+ 'gender': 'Female',
55
+ 'neural': 'Yes',
56
+ 'standard': 'No'},
57
+ {'language': 'Chinese (Mandarin)',
58
+ 'lang_code': 'cmn-CN',
59
+ 'whisper_lang_code': 'zh',
60
+ 'voice_id': 'Zhiyu',
61
+ 'gender': 'Female',
62
+ 'neural': 'Yes',
63
+ 'standard': 'No'},
64
+ {'language': 'Danish',
65
+ 'lang_code': 'da-DK',
66
+ 'whisper_lang_code': 'da',
67
+ 'voice_id': 'Naja',
68
+ 'gender': 'Female',
69
+ 'neural': 'No',
70
+ 'standard': 'Yes'},
71
+ {'language': 'Danish',
72
+ 'lang_code': 'da-DK',
73
+ 'whisper_lang_code': 'da',
74
+ 'voice_id': 'Mads',
75
+ 'gender': 'Male',
76
+ 'neural': 'No',
77
+ 'standard': 'Yes'},
78
+ {'language': 'Dutch',
79
+ 'lang_code': 'nl-NL',
80
+ 'whisper_lang_code': 'nl',
81
+ 'voice_id': 'Laura',
82
+ 'gender': 'Female',
83
+ 'neural': 'Yes',
84
+ 'standard': 'No'},
85
+ {'language': 'Dutch',
86
+ 'lang_code': 'nl-NL',
87
+ 'whisper_lang_code': 'nl',
88
+ 'voice_id': 'Lotte',
89
+ 'gender': 'Female',
90
+ 'neural': 'No',
91
+ 'standard': 'Yes'},
92
+ {'language': 'Dutch',
93
+ 'lang_code': 'nl-NL',
94
+ 'whisper_lang_code': 'nl',
95
+ 'voice_id': 'Ruben',
96
+ 'gender': 'Male',
97
+ 'neural': 'No',
98
+ 'standard': 'Yes'},
99
+ {'language': 'English (Australian)',
100
+ 'lang_code': 'en-AU',
101
+ 'whisper_lang_code': 'en',
102
+ 'voice_id': 'Nicole',
103
+ 'gender': 'Female',
104
+ 'neural': 'No',
105
+ 'standard': 'Yes'},
106
+ {'language': 'English (Australian)',
107
+ 'lang_code': 'en-AU',
108
+ 'whisper_lang_code': 'en',
109
+ 'voice_id': 'Olivia',
110
+ 'gender': 'Female',
111
+ 'neural': 'Yes',
112
+ 'standard': 'No'},
113
+ {'language': 'English (Australian)',
114
+ 'lang_code': 'en-AU',
115
+ 'whisper_lang_code': 'en',
116
+ 'voice_id': 'Russell',
117
+ 'gender': 'Male',
118
+ 'neural': 'No',
119
+ 'standard': 'Yes'},
120
+ {'language': 'English (British)',
121
+ 'lang_code': 'en-GB',
122
+ 'whisper_lang_code': 'en',
123
+ 'voice_id': 'Amy',
124
+ 'gender': 'Female',
125
+ 'neural': 'Yes',
126
+ 'standard': 'Yes'},
127
+ {'language': 'English (British)',
128
+ 'lang_code': 'en-GB',
129
+ 'whisper_lang_code': 'en',
130
+ 'voice_id': 'Emma',
131
+ 'gender': 'Female',
132
+ 'neural': 'Yes',
133
+ 'standard': 'Yes'},
134
+ {'language': 'English (British)',
135
+ 'lang_code': 'en-GB',
136
+ 'whisper_lang_code': 'en',
137
+ 'voice_id': 'Brian',
138
+ 'gender': 'Male',
139
+ 'neural': 'Yes',
140
+ 'standard': 'Yes'},
141
+ {'language': 'English (British)',
142
+ 'lang_code': 'en-GB',
143
+ 'whisper_lang_code': 'en',
144
+ 'voice_id': 'Arthur',
145
+ 'gender': 'Male',
146
+ 'neural': 'Yes',
147
+ 'standard': 'No'},
148
+ {'language': 'English (Indian)',
149
+ 'lang_code': 'en-IN',
150
+ 'whisper_lang_code': 'en',
151
+ 'voice_id': 'Aditi',
152
+ 'gender': 'Female',
153
+ 'neural': 'No',
154
+ 'standard': 'Yes'},
155
+ {'language': 'English (Indian)',
156
+ 'lang_code': 'en-IN',
157
+ 'whisper_lang_code': 'en',
158
+ 'voice_id': 'Raveena',
159
+ 'gender': 'Female',
160
+ 'neural': 'No',
161
+ 'standard': 'Yes'},
162
+ {'language': 'English (Indian)',
163
+ 'lang_code': 'en-IN',
164
+ 'whisper_lang_code': 'en',
165
+ 'voice_id': 'Kajal',
166
+ 'gender': 'Female',
167
+ 'neural': 'Yes',
168
+ 'standard': 'No'},
169
+ {'language': 'English (New Zealand)',
170
+ 'lang_code': 'en-NZ',
171
+ 'whisper_lang_code': 'en',
172
+ 'voice_id': 'Aria',
173
+ 'gender': 'Female',
174
+ 'neural': 'Yes',
175
+ 'standard': 'No'},
176
+ {'language': 'English (South African)',
177
+ 'lang_code': 'en-ZA',
178
+ 'whisper_lang_code': 'en',
179
+ 'voice_id': 'Ayanda',
180
+ 'gender': 'Female',
181
+ 'neural': 'Yes',
182
+ 'standard': 'No'},
183
+ {'language': 'English (US)',
184
+ 'lang_code': 'en-US',
185
+ 'whisper_lang_code': 'en',
186
+ 'voice_id': 'Ivy',
187
+ 'gender': 'Female (child)',
188
+ 'neural': 'Yes',
189
+ 'standard': 'Yes'},
190
+ {'language': 'English (US)',
191
+ 'lang_code': 'en-US',
192
+ 'whisper_lang_code': 'en',
193
+ 'voice_id': 'Joanna',
194
+ 'gender': 'Female',
195
+ 'neural': 'Yes',
196
+ 'standard': 'Yes'},
197
+ {'language': 'English (US)',
198
+ 'lang_code': 'en-US',
199
+ 'whisper_lang_code': 'en',
200
+ 'voice_id': 'Kendra',
201
+ 'gender': 'Female',
202
+ 'neural': 'Yes',
203
+ 'standard': 'Yes'},
204
+ {'language': 'English (US)',
205
+ 'lang_code': 'en-US',
206
+ 'whisper_lang_code': 'en',
207
+ 'voice_id': 'Kimberly',
208
+ 'gender': 'Female',
209
+ 'neural': 'Yes',
210
+ 'standard': 'Yes'},
211
+ {'language': 'English (US)',
212
+ 'lang_code': 'en-US',
213
+ 'whisper_lang_code': 'en',
214
+ 'voice_id': 'Salli',
215
+ 'gender': 'Female',
216
+ 'neural': 'Yes',
217
+ 'standard': 'Yes'},
218
+ {'language': 'English (US)',
219
+ 'lang_code': 'en-US',
220
+ 'whisper_lang_code': 'en',
221
+ 'voice_id': 'Joey',
222
+ 'gender': 'Male',
223
+ 'neural': 'Yes',
224
+ 'standard': 'Yes'},
225
+ {'language': 'English (US)',
226
+ 'lang_code': 'en-US',
227
+ 'whisper_lang_code': 'en',
228
+ 'voice_id': 'Justin',
229
+ 'gender': 'Male (child)',
230
+ 'neural': 'Yes',
231
+ 'standard': 'Yes'},
232
+ {'language': 'English (US)',
233
+ 'lang_code': 'en-US',
234
+ 'whisper_lang_code': 'en',
235
+ 'voice_id': 'Kevin',
236
+ 'gender': 'Male (child)',
237
+ 'neural': 'Yes',
238
+ 'standard': 'No'},
239
+ {'language': 'English (US)',
240
+ 'lang_code': 'en-US',
241
+ 'whisper_lang_code': 'en',
242
+ 'voice_id': 'Matthew',
243
+ 'gender': 'Male',
244
+ 'neural': 'Yes',
245
+ 'standard': 'Yes'},
246
+ {'language': 'English (Welsh)',
247
+ 'lang_code': 'en-GB-WLS',
248
+ 'whisper_lang_code': 'en',
249
+ 'voice_id': 'Geraint',
250
+ 'gender': 'Male',
251
+ 'neural': 'No',
252
+ 'standard': 'Yes'},
253
+ {'language': 'Finnish',
254
+ 'lang_code': 'fi-FI',
255
+ 'whisper_lang_code': 'fi',
256
+ 'voice_id': 'Suvi',
257
+ 'gender': 'Female',
258
+ 'neural': 'Yes',
259
+ 'standard': 'No'},
260
+ {'language': 'French',
261
+ 'lang_code': 'fr-FR',
262
+ 'whisper_lang_code': 'fr',
263
+ 'voice_id': 'Celine',
264
+ 'gender': 'Female',
265
+ 'neural': 'No',
266
+ 'standard': 'Yes'},
267
+ {'language': 'French',
268
+ 'lang_code': 'fr-FR',
269
+ 'whisper_lang_code': 'fr',
270
+ 'voice_id': 'Lea',
271
+ 'gender': 'Female',
272
+ 'neural': 'Yes',
273
+ 'standard': 'Yes'},
274
+ {'language': 'French',
275
+ 'lang_code': 'fr-FR',
276
+ 'whisper_lang_code': 'fr',
277
+ 'voice_id': 'Mathieu',
278
+ 'gender': 'Male',
279
+ 'neural': 'No',
280
+ 'standard': 'Yes'},
281
+ {'language': 'French (Canadian)',
282
+ 'lang_code': 'fr-CA',
283
+ 'whisper_lang_code': 'fr',
284
+ 'voice_id': 'Chantal',
285
+ 'gender': 'Female',
286
+ 'neural': 'No',
287
+ 'standard': 'Yes'},
288
+ {'language': 'French (Canadian)',
289
+ 'lang_code': 'fr-CA',
290
+ 'whisper_lang_code': 'fr',
291
+ 'voice_id': 'Gabrielle',
292
+ 'gender': 'Female',
293
+ 'neural': 'Yes',
294
+ 'standard': 'No'},
295
+ {'language': 'French (Canadian)',
296
+ 'lang_code': 'fr-CA',
297
+ 'whisper_lang_code': 'fr',
298
+ 'voice_id': 'Liam',
299
+ 'gender': 'Male',
300
+ 'neural': 'Yes',
301
+ 'standard': 'No'},
302
+ {'language': 'German',
303
+ 'lang_code': 'de-DE',
304
+ 'whisper_lang_code': 'de',
305
+ 'voice_id': 'Marlene',
306
+ 'gender': 'Female',
307
+ 'neural': 'No',
308
+ 'standard': 'Yes'},
309
+ {'language': 'German',
310
+ 'lang_code': 'de-DE',
311
+ 'whisper_lang_code': 'de',
312
+ 'voice_id': 'Vicki',
313
+ 'gender': 'Female',
314
+ 'neural': 'Yes',
315
+ 'standard': 'Yes'},
316
+ {'language': 'German',
317
+ 'lang_code': 'de-DE',
318
+ 'whisper_lang_code': 'de',
319
+ 'voice_id': 'Hans',
320
+ 'gender': 'Male',
321
+ 'neural': 'No',
322
+ 'standard': 'Yes'},
323
+ {'language': 'German',
324
+ 'lang_code': 'de-DE',
325
+ 'whisper_lang_code': 'de',
326
+ 'voice_id': 'Daniel',
327
+ 'gender': 'Male',
328
+ 'neural': 'Yes',
329
+ 'standard': 'No'},
330
+ {'language': 'German (Austrian)',
331
+ 'lang_code': 'de-AT',
332
+ 'whisper_lang_code': 'de',
333
+ 'voice_id': 'Hannah',
334
+ 'gender': 'Female',
335
+ 'neural': 'Yes',
336
+ 'standard': 'No'},
337
+ {'language': 'Hindi',
338
+ 'lang_code': 'hi-IN',
339
+ 'whisper_lang_code': 'hi',
340
+ 'voice_id': 'Aditi',
341
+ 'gender': 'Female',
342
+ 'neural': 'No',
343
+ 'standard': 'Yes'},
344
+ {'language': 'Hindi',
345
+ 'lang_code': 'hi-IN',
346
+ 'whisper_lang_code': 'hi',
347
+ 'voice_id': 'Kajal',
348
+ 'gender': 'Female',
349
+ 'neural': 'Yes',
350
+ 'standard': 'No'},
351
+ {'language': 'Icelandic',
352
+ 'lang_code': 'is-IS',
353
+ 'whisper_lang_code': 'is',
354
+ 'voice_id': 'Dora',
355
+ 'gender': 'Female',
356
+ 'neural': 'No',
357
+ 'standard': 'Yes'},
358
+ {'language': 'Icelandic',
359
+ 'lang_code': 'is-IS',
360
+ 'whisper_lang_code': 'is',
361
+ 'voice_id': 'Karl',
362
+ 'gender': 'Male',
363
+ 'neural': 'No',
364
+ 'standard': 'Yes'},
365
+ {'language': 'Italian',
366
+ 'lang_code': 'it-IT',
367
+ 'whisper_lang_code': 'it',
368
+ 'voice_id': 'Carla',
369
+ 'gender': 'Female',
370
+ 'neural': 'No',
371
+ 'standard': 'Yes'},
372
+ {'language': 'Italian',
373
+ 'lang_code': 'it-IT',
374
+ 'whisper_lang_code': 'it',
375
+ 'voice_id': 'Bianca',
376
+ 'gender': 'Female',
377
+ 'neural': 'Yes',
378
+ 'standard': 'Yes'},
379
+ {'language': 'Japanese',
380
+ 'lang_code': 'ja-JP',
381
+ 'whisper_lang_code': 'ja',
382
+ 'voice_id': 'Mizuki',
383
+ 'gender': 'Female',
384
+ 'neural': 'No',
385
+ 'standard': 'Yes'},
386
+ {'language': 'Japanese',
387
+ 'lang_code': 'ja-JP',
388
+ 'whisper_lang_code': 'ja',
389
+ 'voice_id': 'Takumi',
390
+ 'gender': 'Male',
391
+ 'neural': 'Yes',
392
+ 'standard': 'Yes'},
393
+ {'language': 'Korean',
394
+ 'lang_code': 'ko-KR',
395
+ 'whisper_lang_code': 'ko',
396
+ 'voice_id': 'Seoyeon',
397
+ 'gender': 'Female',
398
+ 'neural': 'Yes',
399
+ 'standard': 'Yes'},
400
+ {'language': 'Norwegian',
401
+ 'lang_code': 'nb-NO',
402
+ 'whisper_lang_code': 'no',
403
+ 'voice_id': 'Liv',
404
+ 'gender': 'Female',
405
+ 'neural': 'No',
406
+ 'standard': 'Yes'},
407
+ {'language': 'Norwegian',
408
+ 'lang_code': 'nb-NO',
409
+ 'whisper_lang_code': 'no',
410
+ 'voice_id': 'Ida',
411
+ 'gender': 'Female',
412
+ 'neural': 'Yes',
413
+ 'standard': 'No'},
414
+ {'language': 'Polish',
415
+ 'lang_code': 'pl-PL',
416
+ 'whisper_lang_code': 'pl',
417
+ 'voice_id': 'Ewa',
418
+ 'gender': 'Female',
419
+ 'neural': 'No',
420
+ 'standard': 'Yes'},
421
+ {'language': 'Polish',
422
+ 'lang_code': 'pl-PL',
423
+ 'whisper_lang_code': 'pl',
424
+ 'voice_id': 'Maja',
425
+ 'gender': 'Female',
426
+ 'neural': 'No',
427
+ 'standard': 'Yes'},
428
+ {'language': 'Polish',
429
+ 'lang_code': 'pl-PL',
430
+ 'whisper_lang_code': 'pl',
431
+ 'voice_id': 'Jacek',
432
+ 'gender': 'Male',
433
+ 'neural': 'No',
434
+ 'standard': 'Yes'},
435
+ {'language': 'Polish',
436
+ 'lang_code': 'pl-PL',
437
+ 'whisper_lang_code': 'pl',
438
+ 'voice_id': 'Jan',
439
+ 'gender': 'Male',
440
+ 'neural': 'No',
441
+ 'standard': 'Yes'},
442
+ {'language': 'Polish',
443
+ 'lang_code': 'pl-PL',
444
+ 'whisper_lang_code': 'pl',
445
+ 'voice_id': 'Ola',
446
+ 'gender': 'Female',
447
+ 'neural': 'Yes',
448
+ 'standard': 'No'},
449
+ {'language': 'Portuguese (Brazilian)',
450
+ 'lang_code': 'pt-BR',
451
+ 'whisper_lang_code': 'pt',
452
+ 'voice_id': 'Camila',
453
+ 'gender': 'Female',
454
+ 'neural': 'Yes',
455
+ 'standard': 'Yes'},
456
+ {'language': 'Portuguese (Brazilian)',
457
+ 'lang_code': 'pt-BR',
458
+ 'whisper_lang_code': 'pt',
459
+ 'voice_id': 'Vitoria',
460
+ 'gender': 'Female',
461
+ 'neural': 'Yes',
462
+ 'standard': 'Yes'},
463
+ {'language': 'Portuguese (Brazilian)',
464
+ 'lang_code': 'pt-BR',
465
+ 'whisper_lang_code': 'pt',
466
+ 'voice_id': 'Ricardo',
467
+ 'gender': 'Male',
468
+ 'neural': 'No',
469
+ 'standard': 'Yes'},
470
+ {'language': 'Portuguese (European)',
471
+ 'lang_code': 'pt-PT',
472
+ 'whisper_lang_code': 'pt',
473
+ 'voice_id': 'Ines',
474
+ 'gender': 'Female',
475
+ 'neural': 'Yes',
476
+ 'standard': 'Yes'},
477
+ {'language': 'Portuguese (European)',
478
+ 'lang_code': 'pt-PT',
479
+ 'whisper_lang_code': 'pt',
480
+ 'voice_id': 'Cristiano',
481
+ 'gender': 'Male',
482
+ 'neural': 'No',
483
+ 'standard': 'Yes'},
484
+ {'language': 'Romanian',
485
+ 'lang_code': 'ro-RO',
486
+ 'whisper_lang_code': 'ro',
487
+ 'voice_id': 'Carmen',
488
+ 'gender': 'Female',
489
+ 'neural': 'No',
490
+ 'standard': 'Yes'},
491
+ {'language': 'Russian',
492
+ 'lang_code': 'ru-RU',
493
+ 'whisper_lang_code': 'ru',
494
+ 'voice_id': 'Tatyana',
495
+ 'gender': 'Female',
496
+ 'neural': 'No',
497
+ 'standard': 'Yes'},
498
+ {'language': 'Russian',
499
+ 'lang_code': 'ru-RU',
500
+ 'whisper_lang_code': 'ru',
501
+ 'voice_id': 'Maxim',
502
+ 'gender': 'Male',
503
+ 'neural': 'No',
504
+ 'standard': 'Yes'},
505
+ {'language': 'Spanish (European)',
506
+ 'lang_code': 'es-ES',
507
+ 'whisper_lang_code': 'es',
508
+ 'voice_id': 'Conchita',
509
+ 'gender': 'Female',
510
+ 'neural': 'No',
511
+ 'standard': 'Yes'},
512
+ {'language': 'Spanish (European)',
513
+ 'lang_code': 'es-ES',
514
+ 'whisper_lang_code': 'es',
515
+ 'voice_id': 'Lucia',
516
+ 'gender': 'Female',
517
+ 'neural': 'Yes',
518
+ 'standard': 'Yes'},
519
+ {'language': 'Spanish (European)',
520
+ 'lang_code': 'es-ES',
521
+ 'whisper_lang_code': 'es',
522
+ 'voice_id': 'Enrique',
523
+ 'gender': 'Male',
524
+ 'neural': 'No',
525
+ 'standard': 'Yes'},
526
+ {'language': 'Spanish (Mexican)',
527
+ 'lang_code': 'es-MX',
528
+ 'whisper_lang_code': 'es',
529
+ 'voice_id': 'Mia',
530
+ 'gender': 'Female',
531
+ 'neural': 'Yes',
532
+ 'standard': 'Yes'},
533
+ {'language': 'Spanish (US)',
534
+ 'lang_code': 'es-US',
535
+ 'whisper_lang_code': 'es',
536
+ 'voice_id': 'Lupe',
537
+ 'gender': 'Female',
538
+ 'neural': 'Yes',
539
+ 'standard': 'Yes'},
540
+ {'language': 'Spanish (US)',
541
+ 'lang_code': 'es-US',
542
+ 'whisper_lang_code': 'es',
543
+ 'voice_id': 'Penelope',
544
+ 'gender': 'Female',
545
+ 'neural': 'No',
546
+ 'standard': 'Yes'},
547
+ {'language': 'Spanish (US)',
548
+ 'lang_code': 'es-US',
549
+ 'whisper_lang_code': 'es',
550
+ 'voice_id': 'Miguel',
551
+ 'gender': 'Male',
552
+ 'neural': 'No',
553
+ 'standard': 'Yes'},
554
+ {'language': 'Spanish (US)',
555
+ 'lang_code': 'es-US',
556
+ 'whisper_lang_code': 'es',
557
+ 'voice_id': 'Pedro',
558
+ 'gender': 'Male',
559
+ 'neural': 'Yes',
560
+ 'standard': 'No'},
561
+ {'language': 'Swedish',
562
+ 'lang_code': 'sv-SE',
563
+ 'whisper_lang_code': 'sv',
564
+ 'voice_id': 'Astrid',
565
+ 'gender': 'Female',
566
+ 'neural': 'No',
567
+ 'standard': 'Yes'},
568
+ {'language': 'Swedish',
569
+ 'lang_code': 'sv-SE',
570
+ 'whisper_lang_code': 'sv',
571
+ 'voice_id': 'Elin',
572
+ 'gender': 'Female',
573
+ 'neural': 'Yes',
574
+ 'standard': 'No'},
575
+ {'language': 'Turkish',
576
+ 'lang_code': 'tr-TR',
577
+ 'whisper_lang_code': 'tr',
578
+ 'voice_id': 'Filiz',
579
+ 'gender': 'Female',
580
+ 'neural': 'No',
581
+ 'standard': 'Yes'},
582
+ {'language': 'Welsh',
583
+ 'lang_code': 'cy-GB',
584
+ 'whisper_lang_code': 'cy',
585
+ 'voice_id': 'Gwyneth',
586
+ 'gender': 'Female',
587
+ 'neural': 'No',
588
+ 'standard': 'Yes'}
589
+ ]
590
+
591
+
592
+ # Run from the command-line
593
+ if __name__ == '__main__':
594
+ polly_voice_data = PollyVoiceData()
595
+
596
+ voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Male')
597
+ print('English (US)', 'Male', voice_id, language_code, engine)
598
+
599
+ voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Female')
600
+ print('English (US)', 'Female', voice_id, language_code, engine)
601
+
602
+ voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Female')
603
+ print('French', 'Female', voice_id, language_code, engine)
604
+
605
+ voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Male')
606
+ print('French', 'Male', voice_id, language_code, engine)
607
+
608
+ voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Female')
609
+ print('Japanese', 'Female', voice_id, language_code, engine)
610
+
611
+ voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Male')
612
+ print('Japanese', 'Male', voice_id, language_code, engine)
613
+
614
+ voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Female')
615
+ print('Hindi', 'Female', voice_id, language_code, engine)
616
+
617
+ voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Male')
618
+ print('Hindi', 'Male', voice_id, language_code, engine)
619
+
620
+ whisper_lang_code = polly_voice_data.get_whisper_lang_code('English (US)')
621
+ print('English (US) whisper_lang_code:', whisper_lang_code)
622
+
623
+ whisper_lang_code = polly_voice_data.get_whisper_lang_code('Chinese (Mandarin)')
624
+ print('Chinese (Mandarin) whisper_lang_code:', whisper_lang_code)
625
+
626
+ whisper_lang_code = polly_voice_data.get_whisper_lang_code('Norwegian')
627
+ print('Norwegian whisper_lang_code:', whisper_lang_code)
628
+
629
+ whisper_lang_code = polly_voice_data.get_whisper_lang_code('Dutch')
630
+ print('Dutch whisper_lang_code:', whisper_lang_code)
631
+
632
+ whisper_lang_code = polly_voice_data.get_whisper_lang_code('Foo')
633
+ print('Foo whisper_lang_code:', whisper_lang_code)
634
+
635
+
chat_anything/sad_talker/__init__.py ADDED
File without changes
chat_anything/sad_talker/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Audio2Exp(nn.Module):
7
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
+ super(Audio2Exp, self).__init__()
9
+ self.cfg = cfg
10
+ self.device = device
11
+ self.netG = netG.to(device)
12
+
13
+ def test(self, batch):
14
+
15
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
16
+ bs = mel_input.shape[0]
17
+ T = mel_input.shape[1]
18
+
19
+ exp_coeff_pred = []
20
+
21
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
+
23
+ current_mel_input = mel_input[:,i:i+10]
24
+
25
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
+ ref = batch['ref'][:, :, :64][:, i:i+10]
27
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
+
29
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
+
31
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
+
33
+ exp_coeff_pred += [curr_exp_coeff_pred]
34
+
35
+ # BS x T x 64
36
+ results_dict = {
37
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
+ }
39
+ return results_dict
40
+
41
+
chat_anything/sad_talker/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
+ return out
chat_anything/sad_talker/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from chat_anything.sad_talker.audio2pose_models.cvae import CVAE
4
+ from chat_anything.sad_talker.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from chat_anything.sad_talker.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+
24
+ def forward(self, x):
25
+
26
+ batch = {}
27
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
31
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
+
33
+ # forward
34
+ audio_emb_list = []
35
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
+ batch['audio_emb'] = audio_emb
37
+ batch = self.netG(batch)
38
+
39
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
+
43
+ batch['pose_pred'] = pose_pred
44
+ batch['pose_gt'] = pose_gt
45
+
46
+ return batch
47
+
48
+ def test(self, x):
49
+
50
+ batch = {}
51
+ ref = x['ref'] #bs 1 70
52
+ batch['ref'] = x['ref'][:,0,-6:]
53
+ batch['class'] = x['class']
54
+ bs = ref.shape[0]
55
+
56
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
+ num_frames = x['num_frames']
59
+ num_frames = int(num_frames) - 1
60
+
61
+ #
62
+ div = num_frames//self.seq_len
63
+ re = num_frames%self.seq_len
64
+ audio_emb_list = []
65
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
+ device=batch['ref'].device)]
67
+
68
+ for i in range(div):
69
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
70
+ batch['z'] = z
71
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
+ batch['audio_emb'] = audio_emb
73
+ batch = self.netG.test(batch)
74
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
+
76
+ if re != 0:
77
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
78
+ batch['z'] = z
79
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
+ if audio_emb.shape[1] != self.seq_len:
81
+ pad_dim = self.seq_len-audio_emb.shape[1]
82
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
+ batch['audio_emb'] = audio_emb
85
+ batch = self.netG.test(batch)
86
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
+
88
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
+ batch['pose_motion_pred'] = pose_motion_pred
90
+
91
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
+
93
+ batch['pose_pred'] = pose_pred
94
+ return batch
chat_anything/sad_talker/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ # state_dict = self.audio_encoder.state_dict()
47
+
48
+ # for k,v in wav2lip_state_dict.items():
49
+ # if 'audio_encoder' in k:
50
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ # self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
chat_anything/sad_talker/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from chat_anything.sad_talker.audio2pose_models.res_unet import ResUnet
5
+
6
+ def class2onehot(idx, class_num):
7
+
8
+ assert torch.max(idx).item() < class_num
9
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
+ onehot.scatter_(1, idx, 1)
11
+ return onehot
12
+
13
+ class CVAE(nn.Module):
14
+ def __init__(self, cfg):
15
+ super().__init__()
16
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
+ num_classes = cfg.DATASET.NUM_CLASSES
20
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
+
24
+ self.latent_size = latent_size
25
+
26
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
+ audio_emb_in_size, audio_emb_out_size, seq_len)
28
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
+ audio_emb_in_size, audio_emb_out_size, seq_len)
30
+ def reparameterize(self, mu, logvar):
31
+ std = torch.exp(0.5 * logvar)
32
+ eps = torch.randn_like(std)
33
+ return mu + eps * std
34
+
35
+ def forward(self, batch):
36
+ batch = self.encoder(batch)
37
+ mu = batch['mu']
38
+ logvar = batch['logvar']
39
+ z = self.reparameterize(mu, logvar)
40
+ batch['z'] = z
41
+ return self.decoder(batch)
42
+
43
+ def test(self, batch):
44
+ '''
45
+ class_id = batch['class']
46
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
+ batch['z'] = z
48
+ '''
49
+ return self.decoder(batch)
50
+
51
+ class ENCODER(nn.Module):
52
+ def __init__(self, layer_sizes, latent_size, num_classes,
53
+ audio_emb_in_size, audio_emb_out_size, seq_len):
54
+ super().__init__()
55
+
56
+ self.resunet = ResUnet()
57
+ self.num_classes = num_classes
58
+ self.seq_len = seq_len
59
+
60
+ self.MLP = nn.Sequential()
61
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
+ self.MLP.add_module(
64
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
+
67
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
+
71
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
+
73
+ def forward(self, batch):
74
+ class_id = batch['class']
75
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
+ ref = batch['ref'] #bs 6
77
+ bs = pose_motion_gt.shape[0]
78
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
+
80
+ #pose encode
81
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
+
84
+ #audio mapping
85
+ print(audio_in.shape)
86
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
+ audio_out = audio_out.reshape(bs, -1)
88
+
89
+ class_bias = self.classbias[class_id] #bs latent_size
90
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
+ x_out = self.MLP(x_in)
92
+
93
+ mu = self.linear_means(x_out)
94
+ logvar = self.linear_means(x_out) #bs latent_size
95
+
96
+ batch.update({'mu':mu, 'logvar':logvar})
97
+ return batch
98
+
99
+ class DECODER(nn.Module):
100
+ def __init__(self, layer_sizes, latent_size, num_classes,
101
+ audio_emb_in_size, audio_emb_out_size, seq_len):
102
+ super().__init__()
103
+
104
+ self.resunet = ResUnet()
105
+ self.num_classes = num_classes
106
+ self.seq_len = seq_len
107
+
108
+ self.MLP = nn.Sequential()
109
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
110
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
+ self.MLP.add_module(
112
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
+ if i+1 < len(layer_sizes):
114
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
+ else:
116
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
+
118
+ self.pose_linear = nn.Linear(6, 6)
119
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
+
121
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
+
123
+ def forward(self, batch):
124
+
125
+ z = batch['z'] #bs latent_size
126
+ bs = z.shape[0]
127
+ class_id = batch['class']
128
+ ref = batch['ref'] #bs 6
129
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
+ #print('audio_in: ', audio_in[:, :, :10])
131
+
132
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
+ #print('audio_out: ', audio_out[:, :, :10])
134
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
+ class_bias = self.classbias[class_id] #bs latent_size
136
+
137
+ z = z + class_bias
138
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
139
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
+ x_out = x_out.reshape((bs, self.seq_len, -1))
141
+
142
+ #print('x_out: ', x_out)
143
+
144
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
+
146
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
+
148
+ batch.update({'pose_motion_pred':pose_motion_pred})
149
+ return batch
chat_anything/sad_talker/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class ConvNormRelu(nn.Module):
6
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
+ super().__init__()
9
+ if kernel_size is None:
10
+ if downsample:
11
+ kernel_size, stride, padding = 4, 2, 1
12
+ else:
13
+ kernel_size, stride, padding = 3, 1, 1
14
+
15
+ if conv_type == '2d':
16
+ self.conv = nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ stride,
21
+ padding,
22
+ bias=False,
23
+ )
24
+ if norm == 'BN':
25
+ self.norm = nn.BatchNorm2d(out_channels)
26
+ elif norm == 'IN':
27
+ self.norm = nn.InstanceNorm2d(out_channels)
28
+ else:
29
+ raise NotImplementedError
30
+ elif conv_type == '1d':
31
+ self.conv = nn.Conv1d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride,
36
+ padding,
37
+ bias=False,
38
+ )
39
+ if norm == 'BN':
40
+ self.norm = nn.BatchNorm1d(out_channels)
41
+ elif norm == 'IN':
42
+ self.norm = nn.InstanceNorm1d(out_channels)
43
+ else:
44
+ raise NotImplementedError
45
+ nn.init.kaiming_normal_(self.conv.weight)
46
+
47
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ if isinstance(self.norm, nn.InstanceNorm1d):
52
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
+ else:
54
+ x = self.norm(x)
55
+ x = self.act(x)
56
+ return x
57
+
58
+
59
+ class PoseSequenceDiscriminator(nn.Module):
60
+ def __init__(self, cfg):
61
+ super().__init__()
62
+ self.cfg = cfg
63
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
+
65
+ self.seq = nn.Sequential(
66
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
+ x = self.seq(x)
75
+ x = x.squeeze(1)
76
+ return x
chat_anything/sad_talker/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+
5
+ class ResidualConv(nn.Module):
6
+ def __init__(self, input_dim, output_dim, stride, padding):
7
+ super(ResidualConv, self).__init__()
8
+
9
+ self.conv_block = nn.Sequential(
10
+ nn.BatchNorm2d(input_dim),
11
+ nn.ReLU(),
12
+ nn.Conv2d(
13
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
+ ),
15
+ nn.BatchNorm2d(output_dim),
16
+ nn.ReLU(),
17
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
+ )
19
+ self.conv_skip = nn.Sequential(
20
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
+ nn.BatchNorm2d(output_dim),
22
+ )
23
+
24
+ def forward(self, x):
25
+
26
+ return self.conv_block(x) + self.conv_skip(x)
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, input_dim, output_dim, kernel, stride):
31
+ super(Upsample, self).__init__()
32
+
33
+ self.upsample = nn.ConvTranspose2d(
34
+ input_dim, output_dim, kernel_size=kernel, stride=stride
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.upsample(x)
39
+
40
+
41
+ class Squeeze_Excite_Block(nn.Module):
42
+ def __init__(self, channel, reduction=16):
43
+ super(Squeeze_Excite_Block, self).__init__()
44
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
+ self.fc = nn.Sequential(
46
+ nn.Linear(channel, channel // reduction, bias=False),
47
+ nn.ReLU(inplace=True),
48
+ nn.Linear(channel // reduction, channel, bias=False),
49
+ nn.Sigmoid(),
50
+ )
51
+
52
+ def forward(self, x):
53
+ b, c, _, _ = x.size()
54
+ y = self.avg_pool(x).view(b, c)
55
+ y = self.fc(y).view(b, c, 1, 1)
56
+ return x * y.expand_as(x)
57
+
58
+
59
+ class ASPP(nn.Module):
60
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
+ super(ASPP, self).__init__()
62
+
63
+ self.aspp_block1 = nn.Sequential(
64
+ nn.Conv2d(
65
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
+ ),
67
+ nn.ReLU(inplace=True),
68
+ nn.BatchNorm2d(out_dims),
69
+ )
70
+ self.aspp_block2 = nn.Sequential(
71
+ nn.Conv2d(
72
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
+ ),
74
+ nn.ReLU(inplace=True),
75
+ nn.BatchNorm2d(out_dims),
76
+ )
77
+ self.aspp_block3 = nn.Sequential(
78
+ nn.Conv2d(
79
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
+ ),
81
+ nn.ReLU(inplace=True),
82
+ nn.BatchNorm2d(out_dims),
83
+ )
84
+
85
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
+ self._init_weights()
87
+
88
+ def forward(self, x):
89
+ x1 = self.aspp_block1(x)
90
+ x2 = self.aspp_block2(x)
91
+ x3 = self.aspp_block3(x)
92
+ out = torch.cat([x1, x2, x3], dim=1)
93
+ return self.output(out)
94
+
95
+ def _init_weights(self):
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight)
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class Upsample_(nn.Module):
105
+ def __init__(self, scale=2):
106
+ super(Upsample_, self).__init__()
107
+
108
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
+
110
+ def forward(self, x):
111
+ return self.upsample(x)
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+ def __init__(self, input_encoder, input_decoder, output_dim):
116
+ super(AttentionBlock, self).__init__()
117
+
118
+ self.conv_encoder = nn.Sequential(
119
+ nn.BatchNorm2d(input_encoder),
120
+ nn.ReLU(),
121
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
+ nn.MaxPool2d(2, 2),
123
+ )
124
+
125
+ self.conv_decoder = nn.Sequential(
126
+ nn.BatchNorm2d(input_decoder),
127
+ nn.ReLU(),
128
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
+ )
130
+
131
+ self.conv_attn = nn.Sequential(
132
+ nn.BatchNorm2d(output_dim),
133
+ nn.ReLU(),
134
+ nn.Conv2d(output_dim, 1, 1),
135
+ )
136
+
137
+ def forward(self, x1, x2):
138
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
+ out = self.conv_attn(out)
140
+ return out * x2
chat_anything/sad_talker/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from chat_anything.sad_talker.audio2pose_models.networks import ResidualConv, Upsample
4
+
5
+
6
+ class ResUnet(nn.Module):
7
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
+ super(ResUnet, self).__init__()
9
+
10
+ self.input_layer = nn.Sequential(
11
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
+ nn.BatchNorm2d(filters[0]),
13
+ nn.ReLU(),
14
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
+ )
16
+ self.input_skip = nn.Sequential(
17
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
+ )
19
+
20
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
+
23
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
+
25
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
+
28
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
+
31
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
+
34
+ self.output_layer = nn.Sequential(
35
+ nn.Conv2d(filters[0], 1, 1, 1),
36
+ nn.Sigmoid(),
37
+ )
38
+
39
+ def forward(self, x):
40
+ # Encode
41
+ x1 = self.input_layer(x) + self.input_skip(x)
42
+ x2 = self.residual_conv_1(x1)
43
+ x3 = self.residual_conv_2(x2)
44
+ # Bridge
45
+ x4 = self.bridge(x3)
46
+
47
+ # Decode
48
+ x4 = self.upsample_1(x4)
49
+ x5 = torch.cat([x4, x3], dim=1)
50
+
51
+ x6 = self.up_residual_conv1(x5)
52
+
53
+ x6 = self.upsample_2(x6)
54
+ x7 = torch.cat([x6, x2], dim=1)
55
+
56
+ x8 = self.up_residual_conv2(x7)
57
+
58
+ x8 = self.upsample_3(x8)
59
+ x9 = torch.cat([x8, x1], dim=1)
60
+
61
+ x10 = self.up_residual_conv3(x9)
62
+
63
+ output = self.output_layer(x10)
64
+
65
+ return output
chat_anything/sad_talker/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
+ TRAIN_BATCH_SIZE: 32
5
+ EVAL_BATCH_SIZE: 32
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
+ DEBUG: True
15
+ NUM_REPEATS: 2
16
+ T: 40
17
+
18
+
19
+ MODEL:
20
+ FRAMEWORK: V2
21
+ AUDIOENCODER:
22
+ LEAKY_RELU: True
23
+ NORM: 'IN'
24
+ DISCRIMINATOR:
25
+ LEAKY_RELU: False
26
+ INPUT_CHANNELS: 6
27
+ CVAE:
28
+ AUDIO_EMB_IN_SIZE: 512
29
+ AUDIO_EMB_OUT_SIZE: 128
30
+ SEQ_LEN: 32
31
+ LATENT_SIZE: 256
32
+ ENCODER_LAYER_SIZES: [192, 1024]
33
+ DECODER_LAYER_SIZES: [1024, 192]
34
+
35
+
36
+ TRAIN:
37
+ MAX_EPOCH: 300
38
+ GENERATOR:
39
+ LR: 2.0e-5
40
+ DISCRIMINATOR:
41
+ LR: 1.0e-5
42
+ LOSS:
43
+ W_FEAT: 0
44
+ W_COEFF_EXP: 2
45
+ W_LM: 1.0e-2
46
+ W_LM_MOUTH: 0
47
+ W_REG: 0
48
+ W_SYNC: 0
49
+ W_COLOR: 0
50
+ W_EXPRESSION: 0
51
+ W_LIPREADING: 0.01
52
+ W_LIPREADING_VV: 0
53
+ W_EYE_BLINK: 4
54
+
55
+ TAG:
56
+ NAME: small_dataset
57
+
58
+
chat_anything/sad_talker/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
+ TRAIN_BATCH_SIZE: 64
5
+ EVAL_BATCH_SIZE: 1
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
+ DEBUG: True
14
+
15
+
16
+ MODEL:
17
+ AUDIOENCODER:
18
+ LEAKY_RELU: True
19
+ NORM: 'IN'
20
+ DISCRIMINATOR:
21
+ LEAKY_RELU: False
22
+ INPUT_CHANNELS: 6
23
+ CVAE:
24
+ AUDIO_EMB_IN_SIZE: 512
25
+ AUDIO_EMB_OUT_SIZE: 6
26
+ SEQ_LEN: 32
27
+ LATENT_SIZE: 64
28
+ ENCODER_LAYER_SIZES: [192, 128]
29
+ DECODER_LAYER_SIZES: [128, 192]
30
+
31
+
32
+ TRAIN:
33
+ MAX_EPOCH: 150
34
+ GENERATOR:
35
+ LR: 1.0e-4
36
+ DISCRIMINATOR:
37
+ LR: 1.0e-4
38
+ LOSS:
39
+ LAMBDA_REG: 1
40
+ LAMBDA_LANDMARKS: 0
41
+ LAMBDA_VERTICES: 0
42
+ LAMBDA_GAN_MOTION: 0.7
43
+ LAMBDA_GAN_COEFF: 0
44
+ LAMBDA_KL: 1
45
+
46
+ TAG:
47
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
+
49
+
chat_anything/sad_talker/config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 70
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
chat_anything/sad_talker/config/facerender_still.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 73
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
chat_anything/sad_talker/config/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes). View file
 
chat_anything/sad_talker/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import numpy as np
14
+ import importlib
15
+ import torch.utils.data
16
+ from face3d.data.base_dataset import BaseDataset
17
+
18
+
19
+ def find_dataset_using_name(dataset_name):
20
+ """Import the module "data/[dataset_name]_dataset.py".
21
+
22
+ In the file, the class called DatasetNameDataset() will
23
+ be instantiated. It has to be a subclass of BaseDataset,
24
+ and it is case-insensitive.
25
+ """
26
+ dataset_filename = "data." + dataset_name + "_dataset"
27
+ datasetlib = importlib.import_module(dataset_filename)
28
+
29
+ dataset = None
30
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
+ for name, cls in datasetlib.__dict__.items():
32
+ if name.lower() == target_dataset_name.lower() \
33
+ and issubclass(cls, BaseDataset):
34
+ dataset = cls
35
+
36
+ if dataset is None:
37
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
+
39
+ return dataset
40
+
41
+
42
+ def get_option_setter(dataset_name):
43
+ """Return the static method <modify_commandline_options> of the dataset class."""
44
+ dataset_class = find_dataset_using_name(dataset_name)
45
+ return dataset_class.modify_commandline_options
46
+
47
+
48
+ def create_dataset(opt, rank=0):
49
+ """Create a dataset given the option.
50
+
51
+ This function wraps the class CustomDatasetDataLoader.
52
+ This is the main interface between this package and 'train.py'/'test.py'
53
+
54
+ Example:
55
+ >>> from data import create_dataset
56
+ >>> dataset = create_dataset(opt)
57
+ """
58
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
+ dataset = data_loader.load_data()
60
+ return dataset
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt, rank=0):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ self.sampler = None
75
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
+ if opt.use_ddp and opt.isTrain:
77
+ world_size = opt.world_size
78
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
79
+ self.dataset,
80
+ num_replicas=world_size,
81
+ rank=rank,
82
+ shuffle=not opt.serial_batches
83
+ )
84
+ self.dataloader = torch.utils.data.DataLoader(
85
+ self.dataset,
86
+ sampler=self.sampler,
87
+ num_workers=int(opt.num_threads / world_size),
88
+ batch_size=int(opt.batch_size / world_size),
89
+ drop_last=True)
90
+ else:
91
+ self.dataloader = torch.utils.data.DataLoader(
92
+ self.dataset,
93
+ batch_size=opt.batch_size,
94
+ shuffle=(not opt.serial_batches) and opt.isTrain,
95
+ num_workers=int(opt.num_threads),
96
+ drop_last=True
97
+ )
98
+
99
+ def set_epoch(self, epoch):
100
+ self.dataset.current_epoch = epoch
101
+ if self.sampler is not None:
102
+ self.sampler.set_epoch(epoch)
103
+
104
+ def load_data(self):
105
+ return self
106
+
107
+ def __len__(self):
108
+ """Return the number of data in the dataset"""
109
+ return min(len(self.dataset), self.opt.max_dataset_size)
110
+
111
+ def __iter__(self):
112
+ """Return a batch of data"""
113
+ for i, data in enumerate(self.dataloader):
114
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
+ break
116
+ yield data
chat_anything/sad_talker/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ # self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index - - a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It ususally contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_transform(grayscale=False):
65
+ transform_list = []
66
+ if grayscale:
67
+ transform_list.append(transforms.Grayscale(1))
68
+ transform_list += [transforms.ToTensor()]
69
+ return transforms.Compose(transform_list)
70
+
71
+ def get_affine_mat(opt, size):
72
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
+ w, h = size
74
+
75
+ if 'shift' in opt.preprocess:
76
+ shift_pixs = int(opt.shift_pixs)
77
+ shift_x = random.randint(-shift_pixs, shift_pixs)
78
+ shift_y = random.randint(-shift_pixs, shift_pixs)
79
+ if 'scale' in opt.preprocess:
80
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
+ if 'rot' in opt.preprocess:
82
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
+ rot_rad = -rot_angle * np.pi/180
84
+ if 'flip' in opt.preprocess:
85
+ flip = random.random() > 0.5
86
+
87
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
+
94
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
+ affine_inv = np.linalg.inv(affine)
96
+ return affine, affine_inv, flip
97
+
98
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
+
101
+ def apply_lm_affine(landmark, affine, flip, size):
102
+ _, h = size
103
+ lm = landmark.copy()
104
+ lm[:, 1] = h - 1 - lm[:, 1]
105
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
+ lm = lm @ np.transpose(affine)
107
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
+ lm = lm[:, :2]
109
+ lm[:, 1] = h - 1 - lm[:, 1]
110
+ if flip:
111
+ lm_ = lm.copy()
112
+ lm_[:17] = lm[16::-1]
113
+ lm_[17:22] = lm[26:21:-1]
114
+ lm_[22:27] = lm[21:16:-1]
115
+ lm_[31:36] = lm[35:30:-1]
116
+ lm_[36:40] = lm[45:41:-1]
117
+ lm_[40:42] = lm[47:45:-1]
118
+ lm_[42:46] = lm[39:35:-1]
119
+ lm_[46:48] = lm[41:39:-1]
120
+ lm_[48:55] = lm[54:47:-1]
121
+ lm_[55:60] = lm[59:54:-1]
122
+ lm_[60:65] = lm[64:59:-1]
123
+ lm_[65:68] = lm[67:64:-1]
124
+ lm = lm_
125
+ return lm
chat_anything/sad_talker/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directories to host training images '/path/to/data/train'
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure index is within then range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
chat_anything/sad_talker/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
chat_anything/sad_talker/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset.py
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
chat_anything/sad_talker/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
+
19
+ def extract_keypoint(self, images, name=None, info=True):
20
+ if isinstance(images, list):
21
+ keypoints = []
22
+ if info:
23
+ i_range = tqdm(images,desc='landmark Det:')
24
+ else:
25
+ i_range = images
26
+
27
+ for image in i_range:
28
+ current_kp = self.extract_keypoint(image)
29
+ if np.mean(current_kp) == -1 and keypoints:
30
+ keypoints.append(keypoints[-1])
31
+ else:
32
+ keypoints.append(current_kp[None])
33
+
34
+ keypoints = np.concatenate(keypoints, 0)
35
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
+ return keypoints
37
+ else:
38
+ while True:
39
+ try:
40
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
+ break
42
+ except RuntimeError as e:
43
+ if str(e).startswith('CUDA'):
44
+ print("Warning: out of memory, sleep for 1s")
45
+ time.sleep(1)
46
+ else:
47
+ print(e)
48
+ break
49
+ except TypeError:
50
+ print('No face detected in this image')
51
+ shape = [68, 2]
52
+ keypoints = -1. * np.ones(shape)
53
+ break
54
+ if name is not None:
55
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
+ return keypoints
57
+
58
+ def read_video(filename):
59
+ frames = []
60
+ cap = cv2.VideoCapture(filename)
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ frame = Image.fromarray(frame)
66
+ frames.append(frame)
67
+ else:
68
+ break
69
+ cap.release()
70
+ return frames
71
+
72
+ def run(data):
73
+ filename, opt, device = data
74
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
75
+ kp_extractor = KeypointExtractor()
76
+ images = read_video(filename)
77
+ name = filename.split('/')[-2:]
78
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
+ kp_extractor.extract_keypoint(
80
+ images,
81
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
82
+ )
83
+
84
+ if __name__ == '__main__':
85
+ set_start_method('spawn')
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
+ parser.add_argument('--device_ids', type=str, default='0,1')
90
+ parser.add_argument('--workers', type=int, default=4)
91
+
92
+ opt = parser.parse_args()
93
+ filenames = list()
94
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
+ extensions = VIDEO_EXTENSIONS
97
+
98
+ for ext in extensions:
99
+ os.listdir(f'{opt.input_dir}')
100
+ print(f'{opt.input_dir}/*.{ext}')
101
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
102
+ print('Total number of videos:', len(filenames))
103
+ pool = Pool(opt.workers)
104
+ args_list = cycle([opt])
105
+ device_ids = opt.device_ids.split(",")
106
+ device_ids = cycle(device_ids)
107
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
+ None
chat_anything/sad_talker/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from torch.multiprocessing import Pool, Process, set_start_method
12
+
13
+ from facexlib.alignment import landmark_98_to_68
14
+ from facexlib.detection import init_detection_model
15
+
16
+ from facexlib.utils import load_file_from_url
17
+ from chat_anything.sad_talker.face3d.util.my_awing_arch import FAN
18
+
19
+ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
+ if model_name == 'awing_fan':
21
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
22
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
+ else:
24
+ raise NotImplementedError(f'{model_name} is not implemented.')
25
+
26
+ model_path = load_file_from_url(
27
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
+ model.eval()
30
+ model = model.to(device)
31
+ return model
32
+
33
+
34
+ class KeypointExtractor():
35
+ def __init__(self, device='cuda'):
36
+
37
+ ### gfpgan/weights
38
+ try:
39
+ import webui # in webui
40
+ root_path = 'extensions/SadTalker/gfpgan/weights'
41
+
42
+ except:
43
+ # root_path = 'gfpgan/weights'
44
+ root_path = 'MODELS/gfpgan/weights'
45
+
46
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
47
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
48
+
49
+ def extract_keypoint(self, images, name=None, info=True):
50
+ if isinstance(images, list):
51
+ keypoints = []
52
+ if info:
53
+ i_range = tqdm(images,desc='landmark Det:')
54
+ else:
55
+ i_range = images
56
+
57
+ for image in i_range:
58
+ print("detect landmarks")
59
+ current_kp = self.extract_keypoint(image)
60
+ # current_kp = self.detector.get_landmarks(np.array(image))
61
+ if np.mean(current_kp) == -1 and keypoints:
62
+ keypoints.append(keypoints[-1])
63
+ else:
64
+ keypoints.append(current_kp[None])
65
+
66
+ keypoints = np.concatenate(keypoints, 0)
67
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
68
+ return keypoints
69
+ else:
70
+ print("here")
71
+ while True:
72
+ try:
73
+ with torch.no_grad():
74
+ # face detection -> face alignment.
75
+ img = np.array(images)
76
+ bboxes = self.det_net.detect_faces(images, 0.97)
77
+
78
+ bboxes = bboxes[0]
79
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
80
+
81
+ landmarks=self.detector.get_landmarks(img)
82
+ print(landmarks.shape)
83
+ start_time=time.time()
84
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
85
+ end_time=time.time()
86
+ print(type(keypoints))
87
+ print(keypoints.shape)
88
+
89
+ elapsed_time = end_time - start_time # 计算时间差
90
+ print("landmark检测时间:%.4f秒" % elapsed_time)
91
+ #### keypoints to the original location
92
+ keypoints[:,0] += int(bboxes[0])
93
+ keypoints[:,1] += int(bboxes[1])
94
+
95
+ break
96
+ except RuntimeError as e:
97
+ if str(e).startswith('CUDA'):
98
+ print("Warning: out of memory, sleep for 1s")
99
+ time.sleep(1)
100
+ else:
101
+ print(e)
102
+ break
103
+ except TypeError:
104
+ print('No face detected in this image')
105
+ shape = [68, 2]
106
+ keypoints = -1. * np.ones(shape)
107
+ break
108
+ if name is not None:
109
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
110
+ return keypoints
111
+
112
+ def read_video(filename):
113
+ frames = []
114
+ cap = cv2.VideoCapture(filename)
115
+ while cap.isOpened():
116
+ ret, frame = cap.read()
117
+ if ret:
118
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
119
+ frame = Image.fromarray(frame)
120
+ frames.append(frame)
121
+ else:
122
+ break
123
+ cap.release()
124
+ return frames
125
+
126
+ def run(data):
127
+ filename, opt, device = data
128
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
129
+ kp_extractor = KeypointExtractor()
130
+ images = read_video(filename)
131
+ name = filename.split('/')[-2:]
132
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
133
+ kp_extractor.extract_keypoint(
134
+ images,
135
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
136
+ )
137
+
138
+ if __name__ == '__main__':
139
+ set_start_method('spawn')
140
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
141
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
142
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
143
+ parser.add_argument('--device_ids', type=str, default='0,1')
144
+ parser.add_argument('--workers', type=int, default=4)
145
+
146
+ opt = parser.parse_args()
147
+ filenames = list()
148
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
149
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
150
+ extensions = VIDEO_EXTENSIONS
151
+
152
+ for ext in extensions:
153
+ os.listdir(f'{opt.input_dir}')
154
+ print(f'{opt.input_dir}/*.{ext}')
155
+ filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
156
+ print('Total number of videos:', len(filenames))
157
+ pool = Pool(opt.workers)
158
+ args_list = cycle([opt])
159
+ device_ids = opt.device_ids.split(",")
160
+ device_ids = cycle(device_ids)
161
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
162
+ None
chat_anything/sad_talker/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from chat_anything.sad_talker.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function warps the class CustomDatasetDataLoader.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
chat_anything/sad_talker/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
4
+ identity on a single server.
5
+
6
+ ## Requirements
7
+
8
+ - Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
9
+ - `pip install -r requirements.txt`.
10
+ - Download the dataset
11
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
12
+ .
13
+
14
+ ## How to Training
15
+
16
+ To train a model, run `train.py` with the path to the configs:
17
+
18
+ ### 1. Single node, 8 GPUs:
19
+
20
+ ```shell
21
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
+ ```
23
+
24
+ ### 2. Multiple nodes, each node 8 GPUs:
25
+
26
+ Node 0:
27
+
28
+ ```shell
29
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
30
+ ```
31
+
32
+ Node 1:
33
+
34
+ ```shell
35
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3.Training resnet2060 with 8 GPUs:
39
+
40
+ ```shell
41
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
+ ```
43
+
44
+ ## Model Zoo
45
+
46
+ - The models are available for non-commercial research purposes only.
47
+ - All models can be found in here.
48
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
+ - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
+
51
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
+
53
+ ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
54
+ recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
55
+ As the result, we can evaluate the FAIR performance for different algorithms.
56
+
57
+ For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
58
+ globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
59
+
60
+ For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4).
61
+ Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
+ There are totally 13,928 positive pairs and 96,983,824 negative pairs.
63
+
64
+ | Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
+ | :---: | :--- | :--- | :--- |:--- |:--- |
66
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
+
77
+ ### Performance on IJB-C and Verification Datasets
78
+
79
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
+
91
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
+
93
+
94
+ ## [Speed Benchmark](docs/speed_benchmark.md)
95
+
96
+ **Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
97
+ classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
98
+ accuracy with several times faster training performance and smaller GPU memory.
99
+ Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a
100
+ sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a
101
+ sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC,
102
+ we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed
103
+ training and mixed precision training.
104
+
105
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
+
107
+ More details see
108
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
+
110
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
+
112
+ `-` means training failed because of gpu memory limitations.
113
+
114
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
+ | :--- | :--- | :--- | :--- |
116
+ |125000 | 4681 | 4824 | 5004 |
117
+ |1400000 | **1672** | 3043 | 4738 |
118
+ |5500000 | **-** | **1389** | 3975 |
119
+ |8000000 | **-** | **-** | 3565 |
120
+ |16000000 | **-** | **-** | 2679 |
121
+ |29000000 | **-** | **-** | **1855** |
122
+
123
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
+
125
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
+ | :--- | :--- | :--- | :--- |
127
+ |125000 | 7358 | 5306 | 4868 |
128
+ |1400000 | 32252 | 11178 | 6056 |
129
+ |5500000 | **-** | 32188 | 9854 |
130
+ |8000000 | **-** | **-** | 12310 |
131
+ |16000000 | **-** | **-** | 19950 |
132
+ |29000000 | **-** | **-** | 32324 |
133
+
134
+ ## Evaluation ICCV2021-MFR and IJB-C
135
+
136
+ More details see [eval.md](docs/eval.md) in docs.
137
+
138
+ ## Test
139
+
140
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
+
142
+ - [x] torch 1.6.0
143
+ - [x] torch 1.7.1
144
+ - [x] torch 1.8.0
145
+ - [x] torch 1.9.0
146
+
147
+ ## Citation
148
+
149
+ ```
150
+ @inproceedings{deng2019arcface,
151
+ title={Arcface: Additive angular margin loss for deep face recognition},
152
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
+ pages={4690--4699},
155
+ year={2019}
156
+ }
157
+ @inproceedings{an2020partical_fc,
158
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
159
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
+ Zhang, Debing and Fu Ying},
161
+ booktitle={Arxiv 2010.05222},
162
+ year={2020}
163
+ }
164
+ ```
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+ elif name == "mbf":
21
+ fp16 = kwargs.get("fp16", False)
22
+ num_features = kwargs.get("num_features", 512)
23
+ return get_mbf(fp16=fp16, num_features=num_features)
24
+ else:
25
+ raise ValueError()
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 0.1
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py ADDED
File without changes
chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = "ms1mv3_arcface_r50"
12
+
13
+ config.dataset = "ms1m-retinaface-t1"
14
+ config.embedding_size = 512
15
+ config.sample_rate = 1
16
+ config.fp16 = False
17
+ config.momentum = 0.9
18
+ config.weight_decay = 5e-4
19
+ config.batch_size = 128
20
+ config.lr = 0.1 # batch size is 512
21
+
22
+ if config.dataset == "emore":
23
+ config.rec = "/train_tmp/faces_emore"
24
+ config.num_classes = 85742
25
+ config.num_image = 5822653
26
+ config.num_epoch = 16
27
+ config.warmup_epoch = -1
28
+ config.decay_epoch = [8, 14, ]
29
+ config.val_targets = ["lfw", ]
30
+
31
+ elif config.dataset == "ms1m-retinaface-t1":
32
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
33
+ config.num_classes = 93431
34
+ config.num_image = 5179510
35
+ config.num_epoch = 25
36
+ config.warmup_epoch = -1
37
+ config.decay_epoch = [11, 17, 22]
38
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
+
40
+ elif config.dataset == "glint360k":
41
+ config.rec = "/train_tmp/glint360k"
42
+ config.num_classes = 360232
43
+ config.num_image = 17091657
44
+ config.num_epoch = 20
45
+ config.warmup_epoch = -1
46
+ config.decay_epoch = [8, 12, 15, 18]
47
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
+
49
+ elif config.dataset == "webface":
50
+ config.rec = "/train_tmp/faces_webface_112x112"
51
+ config.num_classes = 10572
52
+ config.num_image = "forget"
53
+ config.num_epoch = 34
54
+ config.warmup_epoch = -1
55
+ config.decay_epoch = [20, 28, 32]
56
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 0.1
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r100"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]