Spaces:
Runtime error
Runtime error
ermu2001
commited on
Commit
·
08720f3
1
Parent(s):
195eeff
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +16 -0
- Dockerfile +27 -0
- README.md +138 -0
- app.py +239 -0
- chat_anything/azure_utils.py +155 -0
- chat_anything/chatbot/__init__.py +0 -0
- chat_anything/chatbot/chat.py +72 -0
- chat_anything/chatbot/model_select.py +60 -0
- chat_anything/chatbot/personality.py +59 -0
- chat_anything/chatbot/select.py +63 -0
- chat_anything/chatbot/voice_select.py +119 -0
- chat_anything/face_generator/__init__.py +0 -0
- chat_anything/face_generator/long_prompt_control_generator.py +104 -0
- chat_anything/face_generator/long_prompt_generator.py +82 -0
- chat_anything/face_generator/pipelines/lpw_stable_diffusion.py +1471 -0
- chat_anything/face_generator/utils/generate.py +45 -0
- chat_anything/polly_utils.py +635 -0
- chat_anything/sad_talker/__init__.py +0 -0
- chat_anything/sad_talker/audio2exp_models/audio2exp.py +41 -0
- chat_anything/sad_talker/audio2exp_models/networks.py +74 -0
- chat_anything/sad_talker/audio2pose_models/audio2pose.py +94 -0
- chat_anything/sad_talker/audio2pose_models/audio_encoder.py +64 -0
- chat_anything/sad_talker/audio2pose_models/cvae.py +149 -0
- chat_anything/sad_talker/audio2pose_models/discriminator.py +76 -0
- chat_anything/sad_talker/audio2pose_models/networks.py +140 -0
- chat_anything/sad_talker/audio2pose_models/res_unet.py +65 -0
- chat_anything/sad_talker/config/auido2exp.yaml +58 -0
- chat_anything/sad_talker/config/auido2pose.yaml +49 -0
- chat_anything/sad_talker/config/facerender.yaml +45 -0
- chat_anything/sad_talker/config/facerender_still.yaml +45 -0
- chat_anything/sad_talker/config/similarity_Lm3D_all.mat +0 -0
- chat_anything/sad_talker/face3d/data/__init__.py +116 -0
- chat_anything/sad_talker/face3d/data/base_dataset.py +125 -0
- chat_anything/sad_talker/face3d/data/flist_dataset.py +125 -0
- chat_anything/sad_talker/face3d/data/image_folder.py +66 -0
- chat_anything/sad_talker/face3d/data/template_dataset.py +75 -0
- chat_anything/sad_talker/face3d/extract_kp_videos.py +108 -0
- chat_anything/sad_talker/face3d/extract_kp_videos_safe.py +162 -0
- chat_anything/sad_talker/face3d/models/__init__.py +67 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/README.md +164 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py +25 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py +23 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py +0 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py +56 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
- chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**__pycache__/
|
2 |
+
|
3 |
+
MODELS
|
4 |
+
third_party
|
5 |
+
tmp
|
6 |
+
results
|
7 |
+
chat_anything/tts_vits/
|
8 |
+
vits_results
|
9 |
+
test
|
10 |
+
resources/models.yaml
|
11 |
+
|
12 |
+
# others
|
13 |
+
GFPGANv1.4.pth
|
14 |
+
gfpgan
|
15 |
+
GFPGAN
|
16 |
+
.gitattributes
|
Dockerfile
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
|
2 |
+
|
3 |
+
# FROM python:3.9
|
4 |
+
|
5 |
+
# WORKDIR /code
|
6 |
+
|
7 |
+
# COPY ./requirements.txt /code/requirements.txt
|
8 |
+
|
9 |
+
# RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
10 |
+
|
11 |
+
# for open cv
|
12 |
+
RUN apt-get update && apt-get install libgl1 -y
|
13 |
+
|
14 |
+
RUN useradd -m -u 1000 user
|
15 |
+
|
16 |
+
USER user
|
17 |
+
|
18 |
+
ENV HOME=/home/user \
|
19 |
+
PATH=/home/user/.local/bin:$PATH
|
20 |
+
|
21 |
+
WORKDIR $HOME/ChatAnything
|
22 |
+
|
23 |
+
COPY --chown=user . $HOME/ChatAnything
|
24 |
+
|
25 |
+
RUN pip install -r requirements.txt
|
26 |
+
|
27 |
+
CMD python app.py
|
README.md
CHANGED
@@ -10,3 +10,141 @@ pinned: false
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
13 |
+
# ChatAnything: Facetime Chat with LLM-Enhanced Personas
|
14 |
+
|
15 |
+
**Yilin Zhao\*, Shanghua Gao\*, Daquan Zhou\*, Xinbin Yuan\*, Zhijie Lin, Qibin Hou, Jiashi Feng**
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
> What will it be like to Facetime any imaginary concepts?
|
20 |
+
To animate anything, we integrated current open-source models at hand for an animation application for interactive AI-Agent chatting usage.
|
21 |
+
>
|
22 |
+
> To start with, take a look at these incredible faces generated with open-source Civitai models that are to be animated.
|
23 |
+
<img src="./resources/readme/show.png" alt="drawing" width="784"/>
|
24 |
+
<!-- ![faces](./resources/readme/show.png) -->
|
25 |
+
|
26 |
+
Here we provide you with ChatAnything. A simple pipeline Enhanced with currently limitless Large Language Models, yielding imaginary Facetime chats with intented visual appearance!
|
27 |
+
|
28 |
+
Remember, the repo and application are totally based on pre-trained deep learning methods and haven't included any training yet. We give all the credit to the open-source community (shout out to you). For detail of the pipeline, see our technical report (TODO: link here)
|
29 |
+
## Release & Features & Future Plans
|
30 |
+
|
31 |
+
- [ ] Fine-tune face rendering module.
|
32 |
+
- [ ] Better TTS module & voice render module.
|
33 |
+
- [ ] Adding Open-source Language Models.
|
34 |
+
- [x] Initial release
|
35 |
+
- Facetime Animation.
|
36 |
+
- Multiple model choices for initial frame generation.
|
37 |
+
- Multiple choices for voices.
|
38 |
+
# Install & Run
|
39 |
+
Just follow the instructions. Every thing would be simple (hopefully). Reach out if you met with any problems!
|
40 |
+
### Install
|
41 |
+
first, install the virtual environment.
|
42 |
+
```
|
43 |
+
conda env create -f environment.yaml
|
44 |
+
|
45 |
+
# then install
|
46 |
+
conda env update --name chatanything --file environment.yaml
|
47 |
+
```
|
48 |
+
|
49 |
+
The Pipeline integrated Open-Source Models. All Models are to be found online(see [Acknowledgement](#acknowledgement)). We put some important models together on huggingface remotes just to make life easier. Prepare them for the first run with this Python script [prepare_models.py](./python_scripts/prepare_models.py):
|
50 |
+
```
|
51 |
+
# prepare the local models
|
52 |
+
python python_scripts/prepare_models.py
|
53 |
+
|
54 |
+
```
|
55 |
+
|
56 |
+
### Building Docker
|
57 |
+
Try build a docker if you find it easier. This part is not fully tested. If you find a anything wrong, feel free to contribute~
|
58 |
+
```
|
59 |
+
docker build --network=host -t chatanything .
|
60 |
+
# docker run -dp 127.0.0.1:8901:8901 chatanything
|
61 |
+
docker run -p 127.0.0.1:8901:8901 -it --gpus all chatanything
|
62 |
+
docker run -it --gpus all chatanything bash
|
63 |
+
```
|
64 |
+
|
65 |
+
### Run
|
66 |
+
specify a port for the gradio application to run on and set off!
|
67 |
+
```
|
68 |
+
PORT=8809 python app.py $PORT
|
69 |
+
```
|
70 |
+
|
71 |
+
# Configuring: From User Input Concept to Appearance & Voice
|
72 |
+
The first step of the pipeline is to generate a image for SadTalker and at the same time set up the Text to Sound Module for voice chat.
|
73 |
+
|
74 |
+
The pipeline would query a powerful LLM (ChatGPT) for the selection in a zero-shot multi-choice selection format.
|
75 |
+
Three Questions are asked upon the initial of every conversation(init frame generation):
|
76 |
+
1. Provide a imagen personality for the user input concept.
|
77 |
+
2. Select a Generative model for the init frame generation.
|
78 |
+
3. Select a Text To Sound Voice(Model) for the character base on the personality.
|
79 |
+
|
80 |
+
We have constructed the model selection to be extendable. Add your ideal model with just a few lines of Configuring! The rest of this section would breifly introduce the steps to add a init-frame generator/language voice.
|
81 |
+
|
82 |
+
### Image Generator
|
83 |
+
Configure the models in the [Model Config](./resources/models.yaml). This Config acts as the memory (or an image-generating tool pool) for the LLM.
|
84 |
+
|
85 |
+
The prompt sets up this selection process. Each sub field of the "models" would turn into an option in the multiple-choice question.
|
86 |
+
the "**desc**" field of each element is what the Language Model would see. The key is not provided to the LM as it would sometimes mislead it.
|
87 |
+
the others are used for the image generation as listed:
|
88 |
+
1. model_dir: the repo-path for diffusers package. As the pretrained Face-landmark ControlNet is based on stable-diffusion-v1-5, we currently only supports the derivatives of it.
|
89 |
+
2. lora_path: LoRA derivatives are powerful, try a LoRA model also for better stylization. Should directly point to the parameters binary file.
|
90 |
+
3. prompt_template & negative_prompt: this is used for prompting the text-to-image diffusion model. Find a ideal prompt for your model and stick with it. A "{}" should be in the prompt template for inserting the user input concept.
|
91 |
+
|
92 |
+
Here are some **Tips** for configuring you own model.
|
93 |
+
1. Provide the LLM with a simple description of the generative model. It is worth noting that the description needs to be concise and accurate for a correct selection.
|
94 |
+
2. Set the model_dir to a local directory of diffusers stable-diffusion-v1-5 derivatives. Also, you can provide a repo-id on the huggingface hub model space. The model would be downloaded when first chosen, wait for it.
|
95 |
+
3. To better utilize the resources from the community, we also add in support of the LoRA features. To add the LoRA module, you would need to give the path to the parameter files.
|
96 |
+
|
97 |
+
4. Carefully write the prompt template and negative prompt. These which affect the initial face generation a lot. Be aware that the prompt template should contain only one pair of "{}" to insert the concept that users wrote on the application webpage. We support the Stable-Diffusion-Webui prompt style as implemented by diffusers, feel free to copy the prompt from Civitai for better prompting the generation and put in the "{}" to the original prompt for ChatAnything!
|
98 |
+
|
99 |
+
Again, this model's config acts as an extended tool pool for the LM, the application would drive the LM to choose from this config and use the chosen model to generate. Sometimes the LM fails to choose the correct model or choosing any available model, this would cause the Chatanything app to fail on a generation.
|
100 |
+
|
101 |
+
Notice we currently support ONLY stable-diffusion-v1.5 derivatives (Sdxl Pipelines are under consideration, however not yet implemented as we lack a face-landmark ControlNet for it. Reach out if you're interested in training one!)
|
102 |
+
|
103 |
+
### Voice TTS
|
104 |
+
We are using the edge_tts package for text-to-speech support. The voice selection and [voice configuration file](./resources/voices_edge.yaml) is constructed similarly to the Image generation model selection, except now the LM is supposed to choose the voice base on the personality description given by itself earlier. "**gender**" and "**language**" field corresponds to edge_tts.
|
105 |
+
|
106 |
+
# On-going tasks.
|
107 |
+
### Customized Voice.
|
108 |
+
There is a Voice Changer TextToSpeach-SpeachVoiceConversion Pipeline app, which ensures a better customized voice. We are trying to leverage its TTS functionality.
|
109 |
+
|
110 |
+
Reach out if you want to add a voice of your own or your hero!
|
111 |
+
|
112 |
+
Here are the possible steps for
|
113 |
+
You would need to change a little bit in the code first:
|
114 |
+
1. Alter this [code](./utils.py#14) to import a TTSTalker from chat_anything/tts_talker/tts_voicechanger.py.
|
115 |
+
2. switch the config to another one, change [code](./utils.py#14) "resources/voices_edge.yaml" -> "resources/voices_voicechanger.yaml"
|
116 |
+
|
117 |
+
The try running a [Voice Changer](https://huggingface.co/spaces/kevinwang676/Voice-Changer) on your local machine. Simply set up git-lfs and install the repo and run it for the TTS voice service.
|
118 |
+
The TTS caller was set to port 7860.
|
119 |
+
|
120 |
+
make sure the client class is set up with the same port in [here](chat_anything/tts_talker/tts_voicechanger.py#5)
|
121 |
+
```python
|
122 |
+
client = Client("http://127.0.0.1:7860/")
|
123 |
+
```
|
124 |
+
|
125 |
+
# Acknowledgement
|
126 |
+
Again, the project hasn't yet included any training. The pipeline is totally based on these incredible awesome packages and pretrained models. Don't hesitate to take a look and explore the amazing open-source generative communities. We love you, guys.
|
127 |
+
- [ChatGPT](https://openai.com/chatgpt): GOD
|
128 |
+
- [SadTalker](https://github.com/OpenTalker/SadTalker): The Core Animation Module
|
129 |
+
- [Face-Landmark-ControlNet](https://huggingface.co/georgefen/Face-Landmark-ControlNet): An Awesome ControlNet with Face landmark using Stable Diffusion 1.5 as base Model.
|
130 |
+
- [diffusers](https://github.com/huggingface/diffusers): GOAT of Image Generative Framework🥳.
|
131 |
+
- [langchain](https://github.com/langchain-ai/langchain): An Awesome Package for Dealing with LLM.
|
132 |
+
- [edge-tts](https://github.com/rany2/edge-tts): An Awesome Package for Text To Sound Solutions.
|
133 |
+
- [gradio](https://www.gradio.app/): GOAT😄 Machine Learning based App framework.
|
134 |
+
- [Civitai](https://civitai.com/models) and [Huggingface_hub](https://huggingface.co/models): Find your ideal Image Generative Model on Civitai. These Communities are Crazy🥂. Here are Some Fantastic Derivatives of [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5):
|
135 |
+
- [Game Icon Institute_mode](https://civitai.com/models/47800?modelVersionId=76533)
|
136 |
+
- [dreamshaper](https://civitai.com/models/4384/dreamshaper)
|
137 |
+
- [3D_Animation_Diffusion](https://civitai.com/models/118086?modelVersionId=128046)
|
138 |
+
- [anything-v5](https://huggingface.co/stablediffusionapi/anything-v5)
|
139 |
+
|
140 |
+
# Citation
|
141 |
+
If you like our pipeline and application, don't hesitate to reach out! Let's work on it and see how far it would go!
|
142 |
+
```bibtex
|
143 |
+
@misc{zhao2023ChatAnything,
|
144 |
+
title={ChatAnything: Facetime Chat with LLM-Enhanced Personas},
|
145 |
+
author={Yilin, Zhao and Shanghua, Gao and Daquan, Zhou and Xinbin, Yuan and Qibin, Hou and Jiashi, Feng},
|
146 |
+
publisher={},
|
147 |
+
year={2023},
|
148 |
+
}
|
149 |
+
```
|
150 |
+
|
app.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import ssl
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
import whisper
|
9 |
+
from chat_anything.polly_utils import PollyVoiceData
|
10 |
+
from chat_anything.azure_utils import AzureVoiceData
|
11 |
+
from chat_anything.chatbot.chat import set_openai_api_key
|
12 |
+
from utils import ChatWrapper, update_foo, reset_memory
|
13 |
+
|
14 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
15 |
+
|
16 |
+
|
17 |
+
TALKING_HEAD_WIDTH = "350"
|
18 |
+
|
19 |
+
LOOPING_TALKING_HEAD = "resources/videos/tempfile.mp4"
|
20 |
+
|
21 |
+
USE_GPT4_DEFAULT = False
|
22 |
+
FULLBODY_DEFAULT = False
|
23 |
+
POLLY_VOICE_DATA = PollyVoiceData()
|
24 |
+
AZURE_VOICE_DATA = AzureVoiceData()
|
25 |
+
|
26 |
+
# Pertains to WHISPER functionality
|
27 |
+
WHISPER_DETECT_LANG = "Detect language"
|
28 |
+
|
29 |
+
INSTRUCTION_MARKDOWN = """
|
30 |
+
# ChatAnything: Facetime Chat with LLM-Enhanced Personas
|
31 |
+
### DEMO INSTRUCTION
|
32 |
+
##### 0. Register
|
33 |
+
Input a OpenAI API Key of your own. This would be used to chat with openai-chatgpt. Make sure to disable the key afterwards🥹.
|
34 |
+
##### 1. Generate The init face😀 along with first chat
|
35 |
+
Input a Concept in the "Talking object" text box, then click on Generate button. The init face generation and module selection will be performed and used for the rest of this chat. Wait for a while and the video would be produced and played. Write simple concept for generating. The concept will be place on each prompt template for deciding the main concepts.
|
36 |
+
##### 2. Keep on Chatting🤑
|
37 |
+
Go on speak with the character. The init face and module selection will not reperform itself, now you are only chatting with the LM, along with the rendering of sadtalker. Hopefully, the API will not impose an excessive charge for this.
|
38 |
+
|
39 |
+
|
40 |
+
### FEATURES
|
41 |
+
##### 1. Upload a image for control/inversion starting point. Try some none face images and see how it works!
|
42 |
+
##### 2. seeding is provided. However if not providing a input image, there would be a random chosen facial landmark image for generating, which might include some randomness.
|
43 |
+
##### 3. Try out the examples.
|
44 |
+
##### 4. Say something and recorded your voice for a real facetime chat. Whisper will handle your voice, see setting-Whisper STT options.
|
45 |
+
##### 5. Decide whether to use the crop face out option, this will crop out the face from the generated image and render. This is promising for better animation rendering, however sometimes the croped image loses some elementary features of you intended concept.
|
46 |
+
|
47 |
+
"""
|
48 |
+
|
49 |
+
# UNCOMMENT TO USE WHISPER
|
50 |
+
warnings.filterwarnings("ignore")
|
51 |
+
WHISPER_MODEL = whisper.load_model("tiny")
|
52 |
+
print("WHISPER_MODEL", WHISPER_MODEL)
|
53 |
+
|
54 |
+
|
55 |
+
# UNCOMMENT TO USE WHISPER
|
56 |
+
def transcribe(aud_inp, whisper_lang):
|
57 |
+
if aud_inp is None:
|
58 |
+
return ""
|
59 |
+
aud = whisper.load_audio(aud_inp)
|
60 |
+
aud = whisper.pad_or_trim(aud)
|
61 |
+
mel = whisper.log_mel_spectrogram(aud).to(WHISPER_MODEL.device)
|
62 |
+
_, probs = WHISPER_MODEL.detect_language(mel)
|
63 |
+
options = whisper.DecodingOptions()
|
64 |
+
if whisper_lang != WHISPER_DETECT_LANG:
|
65 |
+
whisper_lang_code = POLLY_VOICE_DATA.get_whisper_lang_code(
|
66 |
+
whisper_lang)
|
67 |
+
options = whisper.DecodingOptions(language=whisper_lang_code)
|
68 |
+
result = whisper.decode(WHISPER_MODEL, mel, options)
|
69 |
+
print("result.text", result.text)
|
70 |
+
result_text = ""
|
71 |
+
if result and result.text:
|
72 |
+
result_text = result.text
|
73 |
+
return result_text
|
74 |
+
|
75 |
+
|
76 |
+
chat = ChatWrapper()
|
77 |
+
|
78 |
+
|
79 |
+
with gr.Blocks() as block:
|
80 |
+
llm_state = gr.State()
|
81 |
+
history_state = gr.State()
|
82 |
+
chain_state = gr.State()
|
83 |
+
talker_state = gr.State()
|
84 |
+
fullbody_state = gr.State(True)
|
85 |
+
speak_text_state = gr.State(True)
|
86 |
+
talking_head_state = gr.State(True)
|
87 |
+
uid_state = gr.State()
|
88 |
+
video_file_path = gr.State()
|
89 |
+
audio_file_path = gr.State()
|
90 |
+
|
91 |
+
memory_state = gr.State()
|
92 |
+
|
93 |
+
|
94 |
+
# Pertains to WHISPER functionality
|
95 |
+
whisper_lang_state = gr.State(WHISPER_DETECT_LANG)
|
96 |
+
use_gpt4_state = gr.State(USE_GPT4_DEFAULT)
|
97 |
+
|
98 |
+
with gr.Column():
|
99 |
+
with gr.Row():
|
100 |
+
gr.Markdown(INSTRUCTION_MARKDOWN)
|
101 |
+
with gr.Row():
|
102 |
+
openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
|
103 |
+
show_label=True, lines=1, type='password', value='', label='OpenAI API key')
|
104 |
+
openai_api_key_register = gr.Button(
|
105 |
+
value="Register").style(full_width=False)
|
106 |
+
uid_textbox = gr.Textbox(show_label=True, value=uid_state, lines=1, label='UID')
|
107 |
+
seed = gr.Slider(
|
108 |
+
label="Seed",
|
109 |
+
minimum=-1,
|
110 |
+
maximum=2147483647,
|
111 |
+
step=1,
|
112 |
+
randomize=True,
|
113 |
+
)
|
114 |
+
|
115 |
+
with gr.Tab("Chat"):
|
116 |
+
with gr.Row():
|
117 |
+
with gr.Column(scale=1, min_width=TALKING_HEAD_WIDTH, visible=True):
|
118 |
+
with gr.Column():
|
119 |
+
class_prompt = gr.Textbox(
|
120 |
+
'apple',
|
121 |
+
default='apple',
|
122 |
+
type="text", label='Talking object'
|
123 |
+
)
|
124 |
+
init_face_btn = gr.Button(
|
125 |
+
value="Generate").style(full_width=False)
|
126 |
+
|
127 |
+
my_file = gr.File(label="Upload a file",
|
128 |
+
type="file", visible=False)
|
129 |
+
|
130 |
+
# video_html = gr.HTML('')
|
131 |
+
video_html = gr.Video(label="Generated Video", autoplay=True)
|
132 |
+
|
133 |
+
ref_image = gr.Image(
|
134 |
+
type="pil",
|
135 |
+
interactive=True,
|
136 |
+
label="Image: Upload your image.",
|
137 |
+
)
|
138 |
+
tmp_aud_file = gr.File(
|
139 |
+
type="file", visible=False)
|
140 |
+
audio_html = gr.HTML('')
|
141 |
+
init_face_btn.click(chat.generate_init_face_video, inputs=[class_prompt, llm_state, uid_state,fullbody_state, ref_image, seed],
|
142 |
+
outputs=[chain_state, memory_state, video_html,talker_state])
|
143 |
+
|
144 |
+
|
145 |
+
with gr.Column(scale=7):
|
146 |
+
chatbot = gr.Chatbot()
|
147 |
+
|
148 |
+
|
149 |
+
message = gr.Textbox(label="What's on your mind??",
|
150 |
+
placeholder="What's the answer to life, the universe, and everything?",
|
151 |
+
lines=1)
|
152 |
+
submit = gr.Button(value="Send", variant="secondary").style(
|
153 |
+
full_width=False)
|
154 |
+
|
155 |
+
audio_comp = gr.Microphone(source="microphone", type="filepath", label="Just say it!",
|
156 |
+
interactive=True, streaming=False)
|
157 |
+
audio_comp.change(transcribe, inputs=[
|
158 |
+
audio_comp, whisper_lang_state], outputs=[message])
|
159 |
+
|
160 |
+
|
161 |
+
with gr.Accordion("General examples", open=False):
|
162 |
+
gr.Examples(
|
163 |
+
examples=[
|
164 |
+
["cyberpunk godess", "Who are you?", "resources/images/annie.jpg", 393212389],
|
165 |
+
["unbelievable beauty fairy", "Who are you?", "resources/images/lenna.jpg", 222679277],
|
166 |
+
["tree monster", "Who are you?", None],
|
167 |
+
["pineapple monster", "Who are you?", None],
|
168 |
+
["tricky Polaris", "Who are you?", None, 1670155100],
|
169 |
+
["watermelon", "Who are you?", "resources/images/watermelon.jpg", 42],
|
170 |
+
],
|
171 |
+
inputs=[class_prompt, message, ref_image, seed],
|
172 |
+
)
|
173 |
+
|
174 |
+
with gr.Tab("Settings"):
|
175 |
+
with gr.Tab("General"):
|
176 |
+
|
177 |
+
talking_head_cb = gr.Checkbox(
|
178 |
+
label="Show talking head", value=True)
|
179 |
+
talking_head_cb.change(chat.update_talking_head, inputs=[talking_head_cb, uid_state, talking_head_state],
|
180 |
+
outputs=[talking_head_state, video_html])
|
181 |
+
|
182 |
+
use_gpt4_cb = gr.Checkbox(label="Use GPT-4 (experimental) if your OpenAI API has access to it",
|
183 |
+
value=USE_GPT4_DEFAULT)
|
184 |
+
|
185 |
+
fullbody_state = gr.Checkbox(label="Use full body instead of a face.",
|
186 |
+
value=True)
|
187 |
+
|
188 |
+
use_gpt4_cb.change(set_openai_api_key,
|
189 |
+
inputs=[openai_api_key_textbox,
|
190 |
+
use_gpt4_cb],
|
191 |
+
outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path])
|
192 |
+
|
193 |
+
reset_btn = gr.Button(value="Reset chat",
|
194 |
+
variant="secondary").style(full_width=False)
|
195 |
+
reset_btn.click(reset_memory, inputs=[history_state, memory_state],
|
196 |
+
outputs=[chatbot, history_state, memory_state])
|
197 |
+
|
198 |
+
|
199 |
+
with gr.Tab("Whisper STT"):
|
200 |
+
whisper_lang_radio = gr.Radio(label="Whisper speech-to-text language:", choices=[
|
201 |
+
WHISPER_DETECT_LANG, "Arabic", "Arabic (Gulf)", "Catalan", "Chinese (Cantonese)", "Chinese (Mandarin)",
|
202 |
+
"Danish", "Dutch", "English (Australian)", "English (British)", "English (Indian)", "English (New Zealand)",
|
203 |
+
"English (South African)", "English (US)", "English (Welsh)", "Finnish", "French", "French (Canadian)",
|
204 |
+
"German", "German (Austrian)", "Georgian", "Hindi", "Icelandic", "Indonesian", "Italian", "Japanese",
|
205 |
+
"Korean", "Norwegian", "Polish",
|
206 |
+
"Portuguese (Brazilian)", "Portuguese (European)", "Romanian", "Russian", "Spanish (European)",
|
207 |
+
"Spanish (Mexican)", "Spanish (US)", "Swedish", "Turkish", "Ukrainian", "Welsh"],
|
208 |
+
value=WHISPER_DETECT_LANG)
|
209 |
+
|
210 |
+
whisper_lang_radio.change(update_foo,
|
211 |
+
inputs=[whisper_lang_radio,
|
212 |
+
whisper_lang_state],
|
213 |
+
outputs=[whisper_lang_state])
|
214 |
+
|
215 |
+
gr.HTML("""
|
216 |
+
<p>This application is based on <a href='https://huggingface.co/spaces/JavaFXpert/Chat-GPT-LangChain/'>Chat-GPT-LangChain</a>, <a href='https://github.com/hwchase17/langchain'>LangChain</a>
|
217 |
+
</p>""")
|
218 |
+
|
219 |
+
message.submit(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
|
220 |
+
speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state],
|
221 |
+
outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
|
222 |
+
|
223 |
+
submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
|
224 |
+
speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state],
|
225 |
+
outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
|
226 |
+
|
227 |
+
openai_api_key_register.click(set_openai_api_key,
|
228 |
+
inputs=[openai_api_key_textbox,
|
229 |
+
use_gpt4_state, chatbot],
|
230 |
+
outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path])
|
231 |
+
|
232 |
+
if __name__ == "__main__":
|
233 |
+
import sys
|
234 |
+
if len(sys.argv) == 1:
|
235 |
+
port = 8901
|
236 |
+
else:
|
237 |
+
port = int(sys.argv[1])
|
238 |
+
block.launch(debug=True, server_name="0.0.0.0",
|
239 |
+
server_port=port, share=True, enable_queue = True)
|
chat_anything/azure_utils.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This class stores Azure voice data. Specifically, the class stores several records containing
|
2 |
+
# language, lang_code, gender, voice_id and engine. The class also has a method to return the
|
3 |
+
# voice_id, lang_code and engine given a language and gender.
|
4 |
+
|
5 |
+
NEURAL_ENGINE = "neural"
|
6 |
+
STANDARD_ENGINE = "standard"
|
7 |
+
|
8 |
+
|
9 |
+
class AzureVoiceData:
|
10 |
+
def get_voice(self, language, gender):
|
11 |
+
for voice in self.voice_data:
|
12 |
+
if voice['language'] == language and voice['gender'] == gender:
|
13 |
+
return voice['azure_voice']
|
14 |
+
return None
|
15 |
+
|
16 |
+
def __init__(self):
|
17 |
+
self.voice_data = [
|
18 |
+
{'language': 'Arabic',
|
19 |
+
'azure_voice': 'ar-EG-ShakirNeural',
|
20 |
+
'gender': 'Male'},
|
21 |
+
{'language': 'Arabic (Gulf)',
|
22 |
+
'azure_voice': 'ar-KW-FahedNeural',
|
23 |
+
'gender': 'Male'},
|
24 |
+
{'language': 'Catalan',
|
25 |
+
'azure_voice': 'ca-ES-EnricNeural',
|
26 |
+
'gender': 'Male'},
|
27 |
+
{'language': 'Chinese (Cantonese)',
|
28 |
+
'azure_voice': 'yue-CN-YunSongNeural',
|
29 |
+
'gender': 'Male'},
|
30 |
+
{'language': 'Chinese (Mandarin)',
|
31 |
+
'azure_voice': 'zh-CN-YunxiNeural',
|
32 |
+
'gender': 'Male'},
|
33 |
+
{'language': 'Danish',
|
34 |
+
'azure_voice': 'da-DK-JeppeNeural',
|
35 |
+
'gender': 'Male'},
|
36 |
+
{'language': 'Dutch',
|
37 |
+
'azure_voice': 'nl-NL-MaartenNeural',
|
38 |
+
'gender': 'Male'},
|
39 |
+
{'language': 'English (Australian)',
|
40 |
+
'azure_voice': 'en-AU-KenNeural',
|
41 |
+
'gender': 'Male'},
|
42 |
+
{'language': 'English (British)',
|
43 |
+
'azure_voice': 'en-GB-RyanNeural',
|
44 |
+
'gender': 'Male'},
|
45 |
+
{'language': 'English (Indian)',
|
46 |
+
'azure_voice': 'en-IN-PrabhatNeural',
|
47 |
+
'gender': 'Male'},
|
48 |
+
{'language': 'English (New Zealand)',
|
49 |
+
'azure_voice': 'en-NZ-MitchellNeural',
|
50 |
+
'gender': 'Male'},
|
51 |
+
{'language': 'English (South African)',
|
52 |
+
'azure_voice': 'en-ZA-LukeNeural',
|
53 |
+
'gender': 'Male'},
|
54 |
+
{'language': 'English (US)',
|
55 |
+
'azure_voice': 'en-US-ChristopherNeural',
|
56 |
+
'gender': 'Male'},
|
57 |
+
{'language': 'English (Welsh)',
|
58 |
+
'azure_voice': 'cy-GB-AledNeural',
|
59 |
+
'gender': 'Male'},
|
60 |
+
{'language': 'Finnish',
|
61 |
+
'azure_voice': 'fi-FI-HarriNeural',
|
62 |
+
'gender': 'Male'},
|
63 |
+
{'language': 'French',
|
64 |
+
'azure_voice': 'fr-FR-HenriNeural',
|
65 |
+
'gender': 'Male'},
|
66 |
+
{'language': 'French (Canadian)',
|
67 |
+
'azure_voice': 'fr-CA-AntoineNeural',
|
68 |
+
'gender': 'Male'},
|
69 |
+
{'language': 'German',
|
70 |
+
'azure_voice': 'de-DE-KlausNeural',
|
71 |
+
'gender': 'Male'},
|
72 |
+
{'language': 'German (Austrian)',
|
73 |
+
'azure_voice': 'de-AT-JonasNeural',
|
74 |
+
'gender': 'Male'},
|
75 |
+
{'language': 'Hindi',
|
76 |
+
'azure_voice': 'hi-IN-MadhurNeural',
|
77 |
+
'gender': 'Male'},
|
78 |
+
{'language': 'Icelandic',
|
79 |
+
'azure_voice': 'is-IS-GunnarNeural',
|
80 |
+
'gender': 'Male'},
|
81 |
+
{'language': 'Italian',
|
82 |
+
'azure_voice': 'it-IT-GianniNeural',
|
83 |
+
'gender': 'Male'},
|
84 |
+
{'language': 'Japanese',
|
85 |
+
'azure_voice': 'ja-JP-KeitaNeural',
|
86 |
+
'gender': 'Male'},
|
87 |
+
{'language': 'Korean',
|
88 |
+
'azure_voice': 'ko-KR-GookMinNeural',
|
89 |
+
'gender': 'Male'},
|
90 |
+
{'language': 'Norwegian',
|
91 |
+
'azure_voice': 'nb-NO-FinnNeural',
|
92 |
+
'gender': 'Male'},
|
93 |
+
{'language': 'Polish',
|
94 |
+
'azure_voice': 'pl-PL-MarekNeural',
|
95 |
+
'gender': 'Male'},
|
96 |
+
{'language': 'Portuguese (Brazilian)',
|
97 |
+
'azure_voice': 'pt-BR-NicolauNeural',
|
98 |
+
'gender': 'Male'},
|
99 |
+
{'language': 'Portuguese (European)',
|
100 |
+
'azure_voice': 'pt-PT-DuarteNeural',
|
101 |
+
'gender': 'Male'},
|
102 |
+
{'language': 'Romanian',
|
103 |
+
'azure_voice': 'ro-RO-EmilNeural',
|
104 |
+
'gender': 'Male'},
|
105 |
+
{'language': 'Russian',
|
106 |
+
'azure_voice': 'ru-RU-DmitryNeural',
|
107 |
+
'gender': 'Male'},
|
108 |
+
{'language': 'Spanish (European)',
|
109 |
+
'azure_voice': 'es-ES-TeoNeural',
|
110 |
+
'gender': 'Male'},
|
111 |
+
{'language': 'Spanish (Mexican)',
|
112 |
+
'azure_voice': 'es-MX-LibertoNeural',
|
113 |
+
'gender': 'Male'},
|
114 |
+
{'language': 'Spanish (US)',
|
115 |
+
'azure_voice': 'es-US-AlonsoNeural"',
|
116 |
+
'gender': 'Male'},
|
117 |
+
{'language': 'Swedish',
|
118 |
+
'azure_voice': 'sv-SE-MattiasNeural',
|
119 |
+
'gender': 'Male'},
|
120 |
+
{'language': 'Turkish',
|
121 |
+
'azure_voice': 'tr-TR-AhmetNeural',
|
122 |
+
'gender': 'Male'},
|
123 |
+
{'language': 'Welsh',
|
124 |
+
'azure_voice': 'cy-GB-AledNeural',
|
125 |
+
'gender': 'Male'},
|
126 |
+
]
|
127 |
+
|
128 |
+
|
129 |
+
# Run from the command-line
|
130 |
+
if __name__ == '__main__':
|
131 |
+
azure_voice_data = AzureVoiceData()
|
132 |
+
|
133 |
+
azure_voice = azure_voice_data.get_voice('English (US)', 'Male')
|
134 |
+
print('English (US)', 'Male', azure_voice)
|
135 |
+
|
136 |
+
azure_voice = azure_voice_data.get_voice('English (US)', 'Female')
|
137 |
+
print('English (US)', 'Female', azure_voice)
|
138 |
+
|
139 |
+
azure_voice = azure_voice_data.get_voice('French', 'Female')
|
140 |
+
print('French', 'Female', azure_voice)
|
141 |
+
|
142 |
+
azure_voice = azure_voice_data.get_voice('French', 'Male')
|
143 |
+
print('French', 'Male', azure_voice)
|
144 |
+
|
145 |
+
azure_voice = azure_voice_data.get_voice('Japanese', 'Female')
|
146 |
+
print('Japanese', 'Female', azure_voice)
|
147 |
+
|
148 |
+
azure_voice = azure_voice_data.get_voice('Japanese', 'Male')
|
149 |
+
print('Japanese', 'Male', azure_voice)
|
150 |
+
|
151 |
+
azure_voice = azure_voice_data.get_voice('Hindi', 'Female')
|
152 |
+
print('Hindi', 'Female', azure_voice)
|
153 |
+
|
154 |
+
azure_voice = azure_voice_data.get_voice('Hindi', 'Male')
|
155 |
+
print('Hindi', 'Male', azure_voice)
|
chat_anything/chatbot/__init__.py
ADDED
File without changes
|
chat_anything/chatbot/chat.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
from chat_anything.chatbot.personality import generate_personality_prompt
|
3 |
+
from langchain.prompts import PromptTemplate
|
4 |
+
from langchain import ConversationChain
|
5 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
6 |
+
from langchain.chat_models import ChatOpenAI
|
7 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
8 |
+
import os
|
9 |
+
import random
|
10 |
+
import string
|
11 |
+
|
12 |
+
|
13 |
+
def load_chain(llm, class_concept=None):
|
14 |
+
chain = None
|
15 |
+
memory = None
|
16 |
+
personality_text = None
|
17 |
+
print(llm)
|
18 |
+
if llm:
|
19 |
+
print("class_concept", class_concept)
|
20 |
+
if class_concept is None:
|
21 |
+
class_concept = 'AI assistant'
|
22 |
+
person_template, personality_text = generate_personality_prompt(llm, class_concept)
|
23 |
+
|
24 |
+
PROMPT_TEMPLATE = PromptTemplate(
|
25 |
+
input_variables=["history", "input"],
|
26 |
+
template=person_template,
|
27 |
+
)
|
28 |
+
|
29 |
+
chain = ConversationChain(
|
30 |
+
prompt=PROMPT_TEMPLATE,
|
31 |
+
llm=llm,
|
32 |
+
verbose=False,
|
33 |
+
memory=ConversationBufferMemory(ai_prefix="You"),
|
34 |
+
)
|
35 |
+
print("New concept done for ", class_concept)
|
36 |
+
|
37 |
+
return chain, memory, personality_text
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def set_openai_api_key(api_key, use_gpt4, history=None, max_tokens=1024):
|
42 |
+
"""Set the api key and return chain.
|
43 |
+
If no api_key, then None is returned.
|
44 |
+
"""
|
45 |
+
if api_key and api_key.startswith("sk-") and len(api_key) > 50:
|
46 |
+
os.environ["OPENAI_API_KEY"] = api_key
|
47 |
+
print("\n\n ++++++++++++++ Setting OpenAI API key ++++++++++++++ \n\n")
|
48 |
+
print(str(datetime.datetime.now()) + ": Before OpenAI, OPENAI_API_KEY length: " + str(
|
49 |
+
len(os.environ["OPENAI_API_KEY"])))
|
50 |
+
|
51 |
+
if use_gpt4:
|
52 |
+
llm = ChatOpenAI(
|
53 |
+
temperature=0, max_tokens=max_tokens, model_name="gpt-4")
|
54 |
+
print("Trying to use llm ChatOpenAI with gpt-4")
|
55 |
+
else:
|
56 |
+
print("Trying to use llm ChatOpenAI with gpt-3.5-turbo")
|
57 |
+
llm = ChatOpenAI(temperature=0, max_tokens=max_tokens,
|
58 |
+
model_name="gpt-3.5-turbo")
|
59 |
+
|
60 |
+
print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str(
|
61 |
+
len(os.environ["OPENAI_API_KEY"])))
|
62 |
+
|
63 |
+
print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str(
|
64 |
+
len(os.environ["OPENAI_API_KEY"])))
|
65 |
+
os.environ["OPENAI_API_KEY"] = ""
|
66 |
+
history = history or []
|
67 |
+
history.append(['', '[SYSTEM] OPENAI_API_KEY has been set, you can generate your object and talk to it now!'])
|
68 |
+
uid = ''.join(random.sample(string.ascii_lowercase + string.ascii_uppercase, 5))
|
69 |
+
video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4')
|
70 |
+
audio_file_path = os.path.join('tmp', uid, 'audio/tempfile.mp3')
|
71 |
+
return llm, use_gpt4, history, uid, video_file_path, audio_file_path
|
72 |
+
return None, None, None, None, None, None
|
chat_anything/chatbot/model_select.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import LLMChain
|
2 |
+
from langchain.prompts import PromptTemplate
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
MODEL_SELECTION_PROMPT_TEMPLATE = """
|
7 |
+
Select one of the following models based on the given concept.
|
8 |
+
You must choose one model name based on the description of each model and the concept!
|
9 |
+
|
10 |
+
Cencept: {concept}
|
11 |
+
|
12 |
+
Model name and description: {model_list}
|
13 |
+
|
14 |
+
Warning: {warning}
|
15 |
+
|
16 |
+
The avilable model names:
|
17 |
+
{model_name_list}
|
18 |
+
|
19 |
+
Selected model name:
|
20 |
+
"""
|
21 |
+
|
22 |
+
def load_model_list():
|
23 |
+
models_config = OmegaConf.load('resources/models.yaml')
|
24 |
+
models_dict = models_config['models']
|
25 |
+
model_name_list_str = ''
|
26 |
+
print(models_dict)
|
27 |
+
model_list_str = ''
|
28 |
+
for key, value in models_dict.items():
|
29 |
+
model_list_str+="model name: " +key+', model description: '+value['desc']+'\n'
|
30 |
+
model_name_list_str += key + ' '
|
31 |
+
model_name_list_str += '\n'
|
32 |
+
return model_list_str, models_dict, model_name_list_str
|
33 |
+
|
34 |
+
def model_selection_chain(llm, class_concept=None):
|
35 |
+
chain = None
|
36 |
+
memory = None
|
37 |
+
if llm:
|
38 |
+
print("class_concept", class_concept)
|
39 |
+
if class_concept is None:
|
40 |
+
class_concept = 'AI assistant'
|
41 |
+
|
42 |
+
|
43 |
+
template = PromptTemplate(
|
44 |
+
input_variables=["model_list", "concept", "warning", "model_name_list"],
|
45 |
+
template=MODEL_SELECTION_PROMPT_TEMPLATE,
|
46 |
+
)
|
47 |
+
model_list_str, models_dict, model_name_list_str = load_model_list()
|
48 |
+
|
49 |
+
personality_chain = LLMChain(
|
50 |
+
llm=llm, prompt=template, verbose=True)
|
51 |
+
selected_model = None
|
52 |
+
while (selected_model is None) or not (selected_model in models_dict):
|
53 |
+
if (selected_model is not None) and not (selected_model in models_dict):
|
54 |
+
warning_str = '{} is not in Model list! \n'.format(selected_model)
|
55 |
+
else:
|
56 |
+
warning_str = ''
|
57 |
+
selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str})
|
58 |
+
print("Selected model name: ", selected_model)
|
59 |
+
|
60 |
+
return models_dict[selected_model]
|
chat_anything/chatbot/personality.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import LLMChain
|
2 |
+
from langchain.prompts import PromptTemplate
|
3 |
+
|
4 |
+
PERSONALITY_PROMPT_TEMPLATE = """
|
5 |
+
You are an excellent scriptwriter. Now you need to provide the characteristics of an {object} and transforms them into personality traits.
|
6 |
+
Describe these personalities using the second person, giving names and specific personality descriptions related to the {object}.
|
7 |
+
The language of the Personality must be same as {object}!
|
8 |
+
|
9 |
+
You should do the following steps:
|
10 |
+
1. Based on the object's nature, imagine what kind of personality it could have if it were to come to life. Does it possess a strong sense of responsibility, like a caring caregiver? Is it playful and mischievous, like a curious child? Is it wise and patient, like an ancient sage? Be creative and invent traits that align with the object's essence.
|
11 |
+
2. Remember to infuse emotions and vivid imagery to bring your object's personality to life.
|
12 |
+
3. translate the personality into a second person prompt.
|
13 |
+
|
14 |
+
Example:
|
15 |
+
|
16 |
+
|
17 |
+
Now give the personality of apple:
|
18 |
+
|
19 |
+
Personality:
|
20 |
+
You an apple Sprite, your name is Apple Buddy.
|
21 |
+
You have all the characteristics of the apple. You are a type of fruit that is usually round with smooth skin and comes in various colors such as red, green, and yellow. You have sweet and nutritious flesh with seeds distributed in its core. You are a rich source of vitamins, fiber, and antioxidants, contributing to maintaining a healthy body.
|
22 |
+
|
23 |
+
You are an optimistic buddy. Always wearing a smile, you spread joy to those around you. Just like the delightful taste of an apple, you bring happiness to everyone.
|
24 |
+
|
25 |
+
You are resilient at heart, like the skin of an apple, able to withstand life's challenges and difficulties. No matter what obstacles you encounter, you face them bravely without hesitation.
|
26 |
+
|
27 |
+
You are caring and considerate, akin to the nutrients in an apple. You always pay attention to the needs and happiness of others. Skilled in listening, you willingly offer help and support, making those around you feel warmth and care.
|
28 |
+
|
29 |
+
You have a strong desire to grow. Like an apple tree needs sunlight and water to flourish, you are continuously learning and improving, becoming a better version of yourself every day.
|
30 |
+
|
31 |
+
You have a profound love for nature and enjoy living in harmony with it. Strolling in the garden, feeling the fresh air and warm sunlight, is one of your favorite moments.
|
32 |
+
|
33 |
+
Apple Buddy, you are a unique apple. Your optimism, resilience, care, and eagerness to grow make you an adorable companion to those around you. Your story will lead us into a world full of warmth and goodness.
|
34 |
+
|
35 |
+
Now give the personality of {object}:
|
36 |
+
|
37 |
+
Personality:
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
def generate_personality_prompt(llm, class_concept):
|
42 |
+
|
43 |
+
PERSONALITY_PROMPT = PromptTemplate(
|
44 |
+
input_variables=["object"],
|
45 |
+
template=PERSONALITY_PROMPT_TEMPLATE,
|
46 |
+
)
|
47 |
+
personality_chain = LLMChain(
|
48 |
+
llm=llm, prompt=PERSONALITY_PROMPT, verbose=True)
|
49 |
+
personality_text = personality_chain.run({'object': class_concept})
|
50 |
+
person_prompt = personality_text
|
51 |
+
|
52 |
+
person_prompt += '''The following is a friendly conversation between a human and you. You need to talk to human based on your personality. If you do not know the answer to a question, you truthfully says you do not know.
|
53 |
+
You can use up to 50 words to answer. Make you answer concise and concise!!!!!!!!
|
54 |
+
Current conversation:
|
55 |
+
{history}
|
56 |
+
Human: {input}
|
57 |
+
You:
|
58 |
+
'''
|
59 |
+
return person_prompt, personality_text
|
chat_anything/chatbot/select.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import LLMChain
|
2 |
+
from typing import OrderedDict
|
3 |
+
from langchain.prompts import PromptTemplate
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
import datetime
|
6 |
+
|
7 |
+
SELECTION_TEMPLATE = """
|
8 |
+
{concept}
|
9 |
+
|
10 |
+
Model name and description:
|
11 |
+
{option_list}
|
12 |
+
|
13 |
+
Warning: {warning}
|
14 |
+
|
15 |
+
The avilable Options:
|
16 |
+
{choices}
|
17 |
+
Answer:
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
def selection_chain(llm, class_concept, prompt, options):
|
22 |
+
chain = None
|
23 |
+
memory = None
|
24 |
+
if llm:
|
25 |
+
print("class_concept", class_concept)
|
26 |
+
if class_concept is None:
|
27 |
+
class_concept = 'AI assistant'
|
28 |
+
prompt_template = prompt + SELECTION_TEMPLATE
|
29 |
+
template = PromptTemplate(
|
30 |
+
input_variables=["concept", "option_list", "warning", "choices"],
|
31 |
+
template=prompt_template,
|
32 |
+
)
|
33 |
+
chain = LLMChain(
|
34 |
+
llm=llm, prompt=template, verbose=True)
|
35 |
+
print(options)
|
36 |
+
option_list = [
|
37 |
+
f"{chr(ord('A') + i)}. {conf['desc']}" for i, conf in enumerate(options.values())
|
38 |
+
]
|
39 |
+
option_list = '\n'.join(option_list)
|
40 |
+
selected_model = None
|
41 |
+
|
42 |
+
warning_str = 'Choose from the available Options.'
|
43 |
+
choices = ' '.join(chr(ord('A') + i) for i in range(len(options)))
|
44 |
+
choice = chain.run({'concept': class_concept, 'option_list':option_list, 'warning': warning_str, 'choices': choices})
|
45 |
+
print(f"LLM Responds (First character was used as the choice):{choice}", )
|
46 |
+
choice = choice[0]
|
47 |
+
|
48 |
+
selected_model = list(options.keys())[ord(choice) - ord('A')]
|
49 |
+
print("Selected model name: ", selected_model)
|
50 |
+
|
51 |
+
return selected_model
|
52 |
+
|
53 |
+
def model_selection_chain(llm, class_concept=None, conf_file='resources/models_personality.yaml'):
|
54 |
+
chain = None
|
55 |
+
memory = None
|
56 |
+
if llm:
|
57 |
+
print("class_concept", class_concept)
|
58 |
+
if class_concept is None:
|
59 |
+
class_concept = 'AI assistant'
|
60 |
+
selection_config = OmegaConf.load(conf_file)
|
61 |
+
selected_model = selection_chain(llm, class_concept, selection_config['prompt'], selection_config['models'])
|
62 |
+
model_conf = selection_config['models'][selected_model]
|
63 |
+
return model_conf, selected_model
|
chat_anything/chatbot/voice_select.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain import LLMChain
|
2 |
+
from langchain.prompts import PromptTemplate
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
VOICE_SELECTION_PROMPT_TEMPLATE = """
|
7 |
+
Select one of the following voice based on the given concept.
|
8 |
+
You must choose one voice name based on the description of each model and the concept.
|
9 |
+
|
10 |
+
|
11 |
+
Cencept: {concept}
|
12 |
+
|
13 |
+
Voice name and description: {model_list}
|
14 |
+
|
15 |
+
Warning: {warning}
|
16 |
+
|
17 |
+
The avilable voice names:
|
18 |
+
{model_name_list}
|
19 |
+
|
20 |
+
Selected voice name:
|
21 |
+
"""
|
22 |
+
|
23 |
+
GENDER_SELECTION_PROMPT_TEMPLATE = """
|
24 |
+
Select one of the following gender based on the given concept.
|
25 |
+
You must choose one gender based on the description of the concept. You must choose one gender Even if you can't decide.
|
26 |
+
|
27 |
+
Gender:
|
28 |
+
male
|
29 |
+
female
|
30 |
+
|
31 |
+
Cencept: {concept}
|
32 |
+
Selected gender male or female:
|
33 |
+
"""
|
34 |
+
|
35 |
+
LANGUAGE_SELECTION_PROMPT_TEMPLATE = """
|
36 |
+
Select one of the following language based on the given concept.
|
37 |
+
You must choose the language that is used by the description of the concept.
|
38 |
+
|
39 |
+
Languages:
|
40 |
+
Chinese
|
41 |
+
English
|
42 |
+
Japanese
|
43 |
+
|
44 |
+
Cencept: {concept}
|
45 |
+
Selected language:
|
46 |
+
"""
|
47 |
+
|
48 |
+
def load_voice_model_list():
|
49 |
+
models_config = OmegaConf.load('resources/voices.yaml')
|
50 |
+
models_dict = models_config['models']
|
51 |
+
print(models_dict)
|
52 |
+
model_list_str = ''
|
53 |
+
model_name_list_str = ''
|
54 |
+
for key, value in models_dict.items():
|
55 |
+
model_list_str+="model name: " +key+', model description: '+value['desc']+'\n'
|
56 |
+
model_name_list_str += key + ' '
|
57 |
+
model_name_list_str += '\n'
|
58 |
+
return model_list_str, models_dict, model_name_list_str
|
59 |
+
|
60 |
+
def get_vioce_model_chain(llm, class_concept):
|
61 |
+
model_template = PromptTemplate(
|
62 |
+
input_variables=["model_list", "concept", "model_name_list", "warning"],
|
63 |
+
template=VOICE_SELECTION_PROMPT_TEMPLATE,
|
64 |
+
)
|
65 |
+
model_list_str, models_dict, model_name_list_str = load_voice_model_list()
|
66 |
+
|
67 |
+
personality_chain = LLMChain(
|
68 |
+
llm=llm, prompt=model_template, verbose=True)
|
69 |
+
|
70 |
+
selected_model = None
|
71 |
+
while (selected_model is None) or not (selected_model in models_dict):
|
72 |
+
if (selected_model is not None) and not (selected_model in models_dict):
|
73 |
+
warning_str = '{} is not in Model list! \n'.format(selected_model)
|
74 |
+
else:
|
75 |
+
warning_str = ''
|
76 |
+
selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str})
|
77 |
+
print("Selected model name: ", selected_model)
|
78 |
+
|
79 |
+
return selected_model
|
80 |
+
|
81 |
+
def get_gender_chain(llm, class_concept):
|
82 |
+
model_template = PromptTemplate(
|
83 |
+
input_variables=["concept"],
|
84 |
+
template=GENDER_SELECTION_PROMPT_TEMPLATE,
|
85 |
+
)
|
86 |
+
|
87 |
+
personality_chain = LLMChain(
|
88 |
+
llm=llm, prompt=model_template, verbose=True)
|
89 |
+
selected_gender = personality_chain.run({'concept': class_concept})
|
90 |
+
print("Selected gender: ", selected_gender)
|
91 |
+
return selected_gender
|
92 |
+
|
93 |
+
def get_language_chain(llm, class_concept):
|
94 |
+
model_template = PromptTemplate(
|
95 |
+
input_variables=["concept"],
|
96 |
+
template=LANGUAGE_SELECTION_PROMPT_TEMPLATE,
|
97 |
+
)
|
98 |
+
|
99 |
+
personality_chain = LLMChain(
|
100 |
+
llm=llm, prompt=model_template, verbose=True)
|
101 |
+
selected_language = personality_chain.run({'concept': class_concept})
|
102 |
+
print("Selected language: ", selected_language)
|
103 |
+
return selected_language
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
def voice_selection_chain(llm, class_concept=None):
|
108 |
+
chain = None
|
109 |
+
memory = None
|
110 |
+
if llm:
|
111 |
+
print("class_concept", class_concept)
|
112 |
+
if class_concept is None:
|
113 |
+
class_concept = 'AI assistant'
|
114 |
+
selected_model = get_vioce_model_chain(llm, class_concept)
|
115 |
+
gender = get_gender_chain(llm, class_concept)
|
116 |
+
language = get_language_chain(llm, class_concept)
|
117 |
+
|
118 |
+
return selected_model, gender, language
|
119 |
+
|
chat_anything/face_generator/__init__.py
ADDED
File without changes
|
chat_anything/face_generator/long_prompt_control_generator.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageDraw
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import dlib
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import diffusers
|
11 |
+
from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
12 |
+
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline
|
13 |
+
from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings
|
14 |
+
from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras
|
15 |
+
|
16 |
+
from chat_anything.face_generator.utils.generate import generate
|
17 |
+
|
18 |
+
from .long_prompt_generator import LongPromptGenerator
|
19 |
+
|
20 |
+
def draw_landmarks(image, landmarks, color="white", radius=2.5):
|
21 |
+
draw = ImageDraw.Draw(image)
|
22 |
+
for dot in landmarks:
|
23 |
+
x, y = dot
|
24 |
+
draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color)
|
25 |
+
|
26 |
+
def get_ldmk_img(w, h, ldmks) -> PIL.Image:
|
27 |
+
con_img = Image.new('RGB', (w, h), color=(0, 0, 0))
|
28 |
+
draw_landmarks(con_img, ldmks)
|
29 |
+
return con_img
|
30 |
+
|
31 |
+
class LongPromptControlGenerator(LongPromptGenerator):
|
32 |
+
|
33 |
+
def __init__(self, model_dir, lora_path, prompt_template, negative_prompt, face_control_dir, face_detect_path,):
|
34 |
+
self.face_control_dir = face_control_dir
|
35 |
+
self.face_detect_path = face_detect_path
|
36 |
+
super().__init__(model_dir, lora_path, prompt_template, negative_prompt)
|
37 |
+
|
38 |
+
def load_model(self, *args, **kwargs):
|
39 |
+
super().load_model(*args, **kwargs)
|
40 |
+
self.face_detector = dlib.get_frontal_face_detector()
|
41 |
+
self.face_predictor = dlib.shape_predictor(self.face_detect_path)
|
42 |
+
# load control net
|
43 |
+
face_controlnet = ControlNetModel.from_pretrained(self.face_control_dir).to('cuda', dtype=torch.float16)
|
44 |
+
self.face_control_pipe = StableDiffusionControlNetPipeline(controlnet=face_controlnet, **self.pipe.components)
|
45 |
+
self.face_control_img2img_pipe = StableDiffusionControlNetImg2ImgPipeline(controlnet=face_controlnet, **self.pipe.components)
|
46 |
+
|
47 |
+
def _get_68landmarks_seq(self, img_np):
|
48 |
+
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
49 |
+
faces = self.face_detector(gray)
|
50 |
+
landmarks = []
|
51 |
+
for face in faces:
|
52 |
+
shape = self.face_predictor(gray, face)
|
53 |
+
for i in range(68):
|
54 |
+
x = shape.part(i).x
|
55 |
+
y = shape.part(i).y
|
56 |
+
landmarks.append((x, y))
|
57 |
+
return landmarks
|
58 |
+
|
59 |
+
def has_face(self, img_pil):
|
60 |
+
img_np = np.array(img_pil)
|
61 |
+
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
|
62 |
+
faces = self.face_detector(gray)
|
63 |
+
return len(faces) != 0
|
64 |
+
|
65 |
+
def face_control_generate(
|
66 |
+
self,
|
67 |
+
prompt,
|
68 |
+
face_img_pil,
|
69 |
+
do_inversion=False,
|
70 |
+
**kwargs,
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Face control generating.
|
74 |
+
"""
|
75 |
+
face_img_np = np.array(face_img_pil)
|
76 |
+
ldmk_seq = self._get_68landmarks_seq(face_img_np)
|
77 |
+
ldmk_img_pil = get_ldmk_img(face_img_pil.size[0], face_img_pil.size[1], ldmk_seq)
|
78 |
+
print('GENERATING:', prompt)
|
79 |
+
|
80 |
+
generating_conf = {
|
81 |
+
"prompt": prompt,
|
82 |
+
"negative_prompt": self.negative_prompt,
|
83 |
+
"num_inference_steps": 25,
|
84 |
+
"guidance_scale": 7,
|
85 |
+
"controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0),
|
86 |
+
"generator": kwargs.pop('generator', None),
|
87 |
+
}
|
88 |
+
|
89 |
+
if not do_inversion:
|
90 |
+
generating_conf.update({
|
91 |
+
"pipe": self.face_control_pipe,
|
92 |
+
"image": ldmk_img_pil,
|
93 |
+
"controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0),
|
94 |
+
})
|
95 |
+
else:
|
96 |
+
generating_conf.update({
|
97 |
+
"pipe": self.face_control_img2img_pipe,
|
98 |
+
"image": face_img_pil,
|
99 |
+
"control_image": ldmk_img_pil,
|
100 |
+
"strength": kwargs.pop('strength', 0.9),
|
101 |
+
})
|
102 |
+
pipe_out = generate(**generating_conf)
|
103 |
+
generated_img = pipe_out[0][0]
|
104 |
+
return generated_img
|
chat_anything/face_generator/long_prompt_generator.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageDraw
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import dlib
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import diffusers
|
11 |
+
from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
12 |
+
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionImg2ImgPipeline
|
13 |
+
from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings
|
14 |
+
from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras
|
15 |
+
|
16 |
+
from chat_anything.face_generator.utils.generate import generate
|
17 |
+
|
18 |
+
class LongPromptGenerator():
|
19 |
+
prompt_template = "A portrait of a {}, fine face, nice looking"
|
20 |
+
negative_prompt = "easynegative,Low resolution,Low quality, Opened Mouth"
|
21 |
+
# negative_prompt = "(((sexy))),paintings,loli,,big head,sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples,extra fingers, ((extra arms)), (extra legs), mutated hands, (fused fingers), (too many fingers), (long neck:1.3)"
|
22 |
+
|
23 |
+
def __init__(self, model_dir, lora_path=None, prompt_template="{}", negative_prompt=""):
|
24 |
+
self.model_dir = model_dir
|
25 |
+
self.lora_path = lora_path
|
26 |
+
self.prompt_template = prompt_template
|
27 |
+
self.negative_prompt = negative_prompt
|
28 |
+
|
29 |
+
def load_model(self, *args, **kwargs):
|
30 |
+
# load model
|
31 |
+
try:
|
32 |
+
pipe = DiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs)
|
33 |
+
except:
|
34 |
+
pipe = StableDiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs)
|
35 |
+
|
36 |
+
pipe = pipe.to('cuda')
|
37 |
+
sche_conf = dict(pipe.scheduler.config)
|
38 |
+
fk_kwargs = ["skip_prk_steps","steps_offset","clip_sample","clip_sample_range","rescale_betas_zero_snr","timestep_spacing", "set_alpha_to_one"]
|
39 |
+
for k in fk_kwargs:
|
40 |
+
if k in sche_conf:
|
41 |
+
sche_conf.pop(k)
|
42 |
+
scheduler = DPMSolverMultistepScheduler(**sche_conf)
|
43 |
+
pipe.scheduler=scheduler
|
44 |
+
pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(**pipe.components)
|
45 |
+
self.pipe, self.pipe_longprompt = pipe, pipe_longprompt
|
46 |
+
if self.lora_path is not None:
|
47 |
+
pipe.load_lora_weights(self.lora_path)
|
48 |
+
self.pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(self.model_dir, **pipe.components)
|
49 |
+
|
50 |
+
def generate(
|
51 |
+
self,
|
52 |
+
prompt,
|
53 |
+
do_inversion=False,
|
54 |
+
**kwargs,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Face control generating.
|
58 |
+
"""
|
59 |
+
print('GENERATING:', prompt)
|
60 |
+
if not do_inversion:
|
61 |
+
generating_conf = {
|
62 |
+
"pipe": self.pipe,
|
63 |
+
"prompt": prompt,
|
64 |
+
"negative_prompt": self.negative_prompt,
|
65 |
+
"num_inference_steps": 25,
|
66 |
+
"guidance_scale": 7,
|
67 |
+
}
|
68 |
+
else:
|
69 |
+
assert 'image' in kwargs, 'doing inversion, prepare the init image please PIL Image'
|
70 |
+
init_image = kwargs['image']
|
71 |
+
generating_conf = {
|
72 |
+
"pipe": self.pipe_img2img,
|
73 |
+
"prompt": prompt,
|
74 |
+
"negative_prompt": self.negative_prompt,
|
75 |
+
"image": init_image,
|
76 |
+
"num_inference_steps": 25,
|
77 |
+
"guidance_scale": 7,
|
78 |
+
"strength": kwargs.pop('strength', 0.9),
|
79 |
+
}
|
80 |
+
pipe_out = generate(**generating_conf)
|
81 |
+
generated_img = pipe_out[0][0]
|
82 |
+
return generated_img
|
chat_anything/face_generator/pipelines/lpw_stable_diffusion.py
ADDED
@@ -0,0 +1,1471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import re
|
3 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
import torch
|
8 |
+
from packaging import version
|
9 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
10 |
+
|
11 |
+
from diffusers import DiffusionPipeline
|
12 |
+
from diffusers.configuration_utils import FrozenDict
|
13 |
+
from diffusers.image_processor import VaeImageProcessor
|
14 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
15 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
16 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
17 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
18 |
+
from diffusers.utils.torch_utils import randn_tensor
|
19 |
+
|
20 |
+
from diffusers.utils import (
|
21 |
+
PIL_INTERPOLATION,
|
22 |
+
deprecate,
|
23 |
+
is_accelerate_available,
|
24 |
+
is_accelerate_version,
|
25 |
+
logging,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
# ------------------------------------------------------------------------------
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
re_attention = re.compile(
|
34 |
+
r"""
|
35 |
+
\\\(|
|
36 |
+
\\\)|
|
37 |
+
\\\[|
|
38 |
+
\\]|
|
39 |
+
\\\\|
|
40 |
+
\\|
|
41 |
+
\(|
|
42 |
+
\[|
|
43 |
+
:([+-]?[.\d]+)\)|
|
44 |
+
\)|
|
45 |
+
]|
|
46 |
+
[^\\()\[\]:]+|
|
47 |
+
:
|
48 |
+
""",
|
49 |
+
re.X,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
def parse_prompt_attention(text):
|
54 |
+
"""
|
55 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
56 |
+
Accepted tokens are:
|
57 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
58 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
59 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
60 |
+
\( - literal character '('
|
61 |
+
\[ - literal character '['
|
62 |
+
\) - literal character ')'
|
63 |
+
\] - literal character ']'
|
64 |
+
\\ - literal character '\'
|
65 |
+
anything else - just text
|
66 |
+
>>> parse_prompt_attention('normal text')
|
67 |
+
[['normal text', 1.0]]
|
68 |
+
>>> parse_prompt_attention('an (important) word')
|
69 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
70 |
+
>>> parse_prompt_attention('(unbalanced')
|
71 |
+
[['unbalanced', 1.1]]
|
72 |
+
>>> parse_prompt_attention('\(literal\]')
|
73 |
+
[['(literal]', 1.0]]
|
74 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
75 |
+
[['unnecessaryparens', 1.1]]
|
76 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
77 |
+
[['a ', 1.0],
|
78 |
+
['house', 1.5730000000000004],
|
79 |
+
[' ', 1.1],
|
80 |
+
['on', 1.0],
|
81 |
+
[' a ', 1.1],
|
82 |
+
['hill', 0.55],
|
83 |
+
[', sun, ', 1.1],
|
84 |
+
['sky', 1.4641000000000006],
|
85 |
+
['.', 1.1]]
|
86 |
+
"""
|
87 |
+
|
88 |
+
res = []
|
89 |
+
round_brackets = []
|
90 |
+
square_brackets = []
|
91 |
+
|
92 |
+
round_bracket_multiplier = 1.1
|
93 |
+
square_bracket_multiplier = 1 / 1.1
|
94 |
+
|
95 |
+
def multiply_range(start_position, multiplier):
|
96 |
+
for p in range(start_position, len(res)):
|
97 |
+
res[p][1] *= multiplier
|
98 |
+
|
99 |
+
for m in re_attention.finditer(text):
|
100 |
+
text = m.group(0)
|
101 |
+
weight = m.group(1)
|
102 |
+
|
103 |
+
if text.startswith("\\"):
|
104 |
+
res.append([text[1:], 1.0])
|
105 |
+
elif text == "(":
|
106 |
+
round_brackets.append(len(res))
|
107 |
+
elif text == "[":
|
108 |
+
square_brackets.append(len(res))
|
109 |
+
elif weight is not None and len(round_brackets) > 0:
|
110 |
+
multiply_range(round_brackets.pop(), float(weight))
|
111 |
+
elif text == ")" and len(round_brackets) > 0:
|
112 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
113 |
+
elif text == "]" and len(square_brackets) > 0:
|
114 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
115 |
+
else:
|
116 |
+
res.append([text, 1.0])
|
117 |
+
|
118 |
+
for pos in round_brackets:
|
119 |
+
multiply_range(pos, round_bracket_multiplier)
|
120 |
+
|
121 |
+
for pos in square_brackets:
|
122 |
+
multiply_range(pos, square_bracket_multiplier)
|
123 |
+
|
124 |
+
if len(res) == 0:
|
125 |
+
res = [["", 1.0]]
|
126 |
+
|
127 |
+
# merge runs of identical weights
|
128 |
+
i = 0
|
129 |
+
while i + 1 < len(res):
|
130 |
+
if res[i][1] == res[i + 1][1]:
|
131 |
+
res[i][0] += res[i + 1][0]
|
132 |
+
res.pop(i + 1)
|
133 |
+
else:
|
134 |
+
i += 1
|
135 |
+
|
136 |
+
return res
|
137 |
+
|
138 |
+
|
139 |
+
def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
|
140 |
+
r"""
|
141 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
142 |
+
|
143 |
+
No padding, starting or ending token is included.
|
144 |
+
"""
|
145 |
+
tokens = []
|
146 |
+
weights = []
|
147 |
+
truncated = False
|
148 |
+
for text in prompt:
|
149 |
+
texts_and_weights = parse_prompt_attention(text)
|
150 |
+
text_token = []
|
151 |
+
text_weight = []
|
152 |
+
for word, weight in texts_and_weights:
|
153 |
+
# tokenize and discard the starting and the ending token
|
154 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
155 |
+
text_token += token
|
156 |
+
# copy the weight by length of token
|
157 |
+
text_weight += [weight] * len(token)
|
158 |
+
# stop if the text is too long (longer than truncation limit)
|
159 |
+
if len(text_token) > max_length:
|
160 |
+
truncated = True
|
161 |
+
break
|
162 |
+
# truncate
|
163 |
+
if len(text_token) > max_length:
|
164 |
+
truncated = True
|
165 |
+
text_token = text_token[:max_length]
|
166 |
+
text_weight = text_weight[:max_length]
|
167 |
+
tokens.append(text_token)
|
168 |
+
weights.append(text_weight)
|
169 |
+
if truncated:
|
170 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
171 |
+
return tokens, weights
|
172 |
+
|
173 |
+
|
174 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
175 |
+
r"""
|
176 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
177 |
+
"""
|
178 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
179 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
180 |
+
for i in range(len(tokens)):
|
181 |
+
tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
|
182 |
+
if no_boseos_middle:
|
183 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
184 |
+
else:
|
185 |
+
w = []
|
186 |
+
if len(weights[i]) == 0:
|
187 |
+
w = [1.0] * weights_length
|
188 |
+
else:
|
189 |
+
for j in range(max_embeddings_multiples):
|
190 |
+
w.append(1.0) # weight for starting token in this chunk
|
191 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
192 |
+
w.append(1.0) # weight for ending token in this chunk
|
193 |
+
w += [1.0] * (weights_length - len(w))
|
194 |
+
weights[i] = w[:]
|
195 |
+
|
196 |
+
return tokens, weights
|
197 |
+
|
198 |
+
|
199 |
+
def get_unweighted_text_embeddings(
|
200 |
+
pipe: DiffusionPipeline,
|
201 |
+
text_input: torch.Tensor,
|
202 |
+
chunk_length: int,
|
203 |
+
no_boseos_middle: Optional[bool] = True,
|
204 |
+
):
|
205 |
+
"""
|
206 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
207 |
+
it should be split into chunks and sent to the text encoder individually.
|
208 |
+
"""
|
209 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
210 |
+
if max_embeddings_multiples > 1:
|
211 |
+
text_embeddings = []
|
212 |
+
for i in range(max_embeddings_multiples):
|
213 |
+
# extract the i-th chunk
|
214 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
215 |
+
|
216 |
+
# cover the head and the tail by the starting and the ending tokens
|
217 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
218 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
219 |
+
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
220 |
+
|
221 |
+
if no_boseos_middle:
|
222 |
+
if i == 0:
|
223 |
+
# discard the ending token
|
224 |
+
text_embedding = text_embedding[:, :-1]
|
225 |
+
elif i == max_embeddings_multiples - 1:
|
226 |
+
# discard the starting token
|
227 |
+
text_embedding = text_embedding[:, 1:]
|
228 |
+
else:
|
229 |
+
# discard both starting and ending tokens
|
230 |
+
text_embedding = text_embedding[:, 1:-1]
|
231 |
+
|
232 |
+
text_embeddings.append(text_embedding)
|
233 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
234 |
+
else:
|
235 |
+
text_embeddings = pipe.text_encoder(text_input)[0]
|
236 |
+
return text_embeddings
|
237 |
+
|
238 |
+
|
239 |
+
def get_weighted_text_embeddings(
|
240 |
+
pipe: DiffusionPipeline,
|
241 |
+
prompt: Union[str, List[str]],
|
242 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
243 |
+
max_embeddings_multiples: Optional[int] = 3,
|
244 |
+
no_boseos_middle: Optional[bool] = False,
|
245 |
+
skip_parsing: Optional[bool] = False,
|
246 |
+
skip_weighting: Optional[bool] = False,
|
247 |
+
):
|
248 |
+
r"""
|
249 |
+
Prompts can be assigned with local weights using brackets. For example,
|
250 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
251 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
252 |
+
|
253 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
pipe (`DiffusionPipeline`):
|
257 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
258 |
+
prompt (`str` or `List[str]`):
|
259 |
+
The prompt or prompts to guide the image generation.
|
260 |
+
uncond_prompt (`str` or `List[str]`):
|
261 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
262 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
263 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
264 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
265 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
266 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
267 |
+
ending token in each of the chunk in the middle.
|
268 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
269 |
+
Skip the parsing of brackets.
|
270 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
271 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
272 |
+
"""
|
273 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
274 |
+
if isinstance(prompt, str):
|
275 |
+
prompt = [prompt]
|
276 |
+
|
277 |
+
if not skip_parsing:
|
278 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
279 |
+
if uncond_prompt is not None:
|
280 |
+
if isinstance(uncond_prompt, str):
|
281 |
+
uncond_prompt = [uncond_prompt]
|
282 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
283 |
+
else:
|
284 |
+
prompt_tokens = [
|
285 |
+
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
|
286 |
+
]
|
287 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
288 |
+
if uncond_prompt is not None:
|
289 |
+
if isinstance(uncond_prompt, str):
|
290 |
+
uncond_prompt = [uncond_prompt]
|
291 |
+
uncond_tokens = [
|
292 |
+
token[1:-1]
|
293 |
+
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
294 |
+
]
|
295 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
296 |
+
|
297 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
298 |
+
max_length = max([len(token) for token in prompt_tokens])
|
299 |
+
if uncond_prompt is not None:
|
300 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
301 |
+
|
302 |
+
max_embeddings_multiples = min(
|
303 |
+
max_embeddings_multiples,
|
304 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
305 |
+
)
|
306 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
307 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
308 |
+
|
309 |
+
# pad the length of tokens and weights
|
310 |
+
bos = pipe.tokenizer.bos_token_id
|
311 |
+
eos = pipe.tokenizer.eos_token_id
|
312 |
+
pad = getattr(pipe.tokenizer, "pad_token_id", eos)
|
313 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
314 |
+
prompt_tokens,
|
315 |
+
prompt_weights,
|
316 |
+
max_length,
|
317 |
+
bos,
|
318 |
+
eos,
|
319 |
+
pad,
|
320 |
+
no_boseos_middle=no_boseos_middle,
|
321 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
322 |
+
)
|
323 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
324 |
+
if uncond_prompt is not None:
|
325 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
326 |
+
uncond_tokens,
|
327 |
+
uncond_weights,
|
328 |
+
max_length,
|
329 |
+
bos,
|
330 |
+
eos,
|
331 |
+
pad,
|
332 |
+
no_boseos_middle=no_boseos_middle,
|
333 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
334 |
+
)
|
335 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
336 |
+
|
337 |
+
# get the embeddings
|
338 |
+
text_embeddings = get_unweighted_text_embeddings(
|
339 |
+
pipe,
|
340 |
+
prompt_tokens,
|
341 |
+
pipe.tokenizer.model_max_length,
|
342 |
+
no_boseos_middle=no_boseos_middle,
|
343 |
+
)
|
344 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
|
345 |
+
if uncond_prompt is not None:
|
346 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
347 |
+
pipe,
|
348 |
+
uncond_tokens,
|
349 |
+
pipe.tokenizer.model_max_length,
|
350 |
+
no_boseos_middle=no_boseos_middle,
|
351 |
+
)
|
352 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
|
353 |
+
|
354 |
+
# assign weights to the prompts and normalize in the sense of mean
|
355 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
356 |
+
if (not skip_parsing) and (not skip_weighting):
|
357 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
358 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
359 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
360 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
361 |
+
if uncond_prompt is not None:
|
362 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
363 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
364 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
365 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
366 |
+
|
367 |
+
if uncond_prompt is not None:
|
368 |
+
return text_embeddings, uncond_embeddings
|
369 |
+
return text_embeddings, None
|
370 |
+
|
371 |
+
|
372 |
+
def preprocess_image(image, batch_size):
|
373 |
+
w, h = image.size
|
374 |
+
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
375 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
376 |
+
image = np.array(image).astype(np.float32) / 255.0
|
377 |
+
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
|
378 |
+
image = torch.from_numpy(image)
|
379 |
+
return 2.0 * image - 1.0
|
380 |
+
|
381 |
+
|
382 |
+
def preprocess_mask(mask, batch_size, scale_factor=8):
|
383 |
+
if not isinstance(mask, torch.FloatTensor):
|
384 |
+
mask = mask.convert("L")
|
385 |
+
w, h = mask.size
|
386 |
+
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
387 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
388 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
389 |
+
mask = np.tile(mask, (4, 1, 1))
|
390 |
+
mask = np.vstack([mask[None]] * batch_size)
|
391 |
+
mask = 1 - mask # repaint white, keep black
|
392 |
+
mask = torch.from_numpy(mask)
|
393 |
+
return mask
|
394 |
+
|
395 |
+
else:
|
396 |
+
valid_mask_channel_sizes = [1, 3]
|
397 |
+
# if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
|
398 |
+
if mask.shape[3] in valid_mask_channel_sizes:
|
399 |
+
mask = mask.permute(0, 3, 1, 2)
|
400 |
+
elif mask.shape[1] not in valid_mask_channel_sizes:
|
401 |
+
raise ValueError(
|
402 |
+
f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
|
403 |
+
f" but received mask of shape {tuple(mask.shape)}"
|
404 |
+
)
|
405 |
+
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
|
406 |
+
mask = mask.mean(dim=1, keepdim=True)
|
407 |
+
h, w = mask.shape[-2:]
|
408 |
+
h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
|
409 |
+
mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
|
410 |
+
return mask
|
411 |
+
|
412 |
+
|
413 |
+
class StableDiffusionLongPromptWeightingPipeline(
|
414 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
|
415 |
+
):
|
416 |
+
r"""
|
417 |
+
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
418 |
+
weighting in prompt.
|
419 |
+
|
420 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
421 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
422 |
+
|
423 |
+
Args:
|
424 |
+
vae ([`AutoencoderKL`]):
|
425 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
426 |
+
text_encoder ([`CLIPTextModel`]):
|
427 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
428 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
429 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
430 |
+
tokenizer (`CLIPTokenizer`):
|
431 |
+
Tokenizer of class
|
432 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
433 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
434 |
+
scheduler ([`SchedulerMixin`]):
|
435 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
436 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
437 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
438 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
439 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
440 |
+
feature_extractor ([`CLIPImageProcessor`]):
|
441 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
442 |
+
"""
|
443 |
+
|
444 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
445 |
+
|
446 |
+
def __init__(
|
447 |
+
self,
|
448 |
+
vae: AutoencoderKL,
|
449 |
+
text_encoder: CLIPTextModel,
|
450 |
+
tokenizer: CLIPTokenizer,
|
451 |
+
unet: UNet2DConditionModel,
|
452 |
+
scheduler: KarrasDiffusionSchedulers,
|
453 |
+
safety_checker: StableDiffusionSafetyChecker,
|
454 |
+
feature_extractor: CLIPImageProcessor,
|
455 |
+
requires_safety_checker: bool = True,
|
456 |
+
):
|
457 |
+
super().__init__()
|
458 |
+
|
459 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
460 |
+
deprecation_message = (
|
461 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
462 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
463 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
464 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
465 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
466 |
+
" file"
|
467 |
+
)
|
468 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
469 |
+
new_config = dict(scheduler.config)
|
470 |
+
new_config["steps_offset"] = 1
|
471 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
472 |
+
|
473 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
474 |
+
deprecation_message = (
|
475 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
476 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
477 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
478 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
479 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
480 |
+
)
|
481 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
482 |
+
new_config = dict(scheduler.config)
|
483 |
+
new_config["clip_sample"] = False
|
484 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
485 |
+
|
486 |
+
if safety_checker is None and requires_safety_checker:
|
487 |
+
logger.warning(
|
488 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
489 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
490 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
491 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
492 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
493 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
494 |
+
)
|
495 |
+
|
496 |
+
if safety_checker is not None and feature_extractor is None:
|
497 |
+
raise ValueError(
|
498 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
499 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
500 |
+
)
|
501 |
+
|
502 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
503 |
+
version.parse(unet.config._diffusers_version).base_version
|
504 |
+
) < version.parse("0.9.0.dev0")
|
505 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
506 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
507 |
+
deprecation_message = (
|
508 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
509 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
510 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
511 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
512 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
513 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
514 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
515 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
516 |
+
" the `unet/config.json` file"
|
517 |
+
)
|
518 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
519 |
+
new_config = dict(unet.config)
|
520 |
+
new_config["sample_size"] = 64
|
521 |
+
unet._internal_dict = FrozenDict(new_config)
|
522 |
+
self.register_modules(
|
523 |
+
vae=vae,
|
524 |
+
text_encoder=text_encoder,
|
525 |
+
tokenizer=tokenizer,
|
526 |
+
unet=unet,
|
527 |
+
scheduler=scheduler,
|
528 |
+
safety_checker=safety_checker,
|
529 |
+
feature_extractor=feature_extractor,
|
530 |
+
)
|
531 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
532 |
+
|
533 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
534 |
+
self.register_to_config(
|
535 |
+
requires_safety_checker=requires_safety_checker,
|
536 |
+
)
|
537 |
+
|
538 |
+
def enable_vae_slicing(self):
|
539 |
+
r"""
|
540 |
+
Enable sliced VAE decoding.
|
541 |
+
|
542 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
543 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
544 |
+
"""
|
545 |
+
self.vae.enable_slicing()
|
546 |
+
|
547 |
+
def disable_vae_slicing(self):
|
548 |
+
r"""
|
549 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
550 |
+
computing decoding in one step.
|
551 |
+
"""
|
552 |
+
self.vae.disable_slicing()
|
553 |
+
|
554 |
+
def enable_vae_tiling(self):
|
555 |
+
r"""
|
556 |
+
Enable tiled VAE decoding.
|
557 |
+
|
558 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
559 |
+
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
560 |
+
"""
|
561 |
+
self.vae.enable_tiling()
|
562 |
+
|
563 |
+
def disable_vae_tiling(self):
|
564 |
+
r"""
|
565 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
566 |
+
computing decoding in one step.
|
567 |
+
"""
|
568 |
+
self.vae.disable_tiling()
|
569 |
+
|
570 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
|
571 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
572 |
+
r"""
|
573 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
574 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
575 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
576 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
577 |
+
`enable_model_cpu_offload`, but performance is lower.
|
578 |
+
"""
|
579 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
580 |
+
from accelerate import cpu_offload
|
581 |
+
else:
|
582 |
+
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
583 |
+
|
584 |
+
device = torch.device(f"cuda:{gpu_id}")
|
585 |
+
|
586 |
+
if self.device.type != "cpu":
|
587 |
+
self.to("cpu", silence_dtype_warnings=True)
|
588 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
589 |
+
|
590 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
591 |
+
cpu_offload(cpu_offloaded_model, device)
|
592 |
+
|
593 |
+
if self.safety_checker is not None:
|
594 |
+
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
595 |
+
|
596 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
|
597 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
598 |
+
r"""
|
599 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
600 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
601 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
602 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
603 |
+
"""
|
604 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
605 |
+
from accelerate import cpu_offload_with_hook
|
606 |
+
else:
|
607 |
+
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
608 |
+
|
609 |
+
device = torch.device(f"cuda:{gpu_id}")
|
610 |
+
|
611 |
+
if self.device.type != "cpu":
|
612 |
+
self.to("cpu", silence_dtype_warnings=True)
|
613 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
614 |
+
|
615 |
+
hook = None
|
616 |
+
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
617 |
+
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
618 |
+
|
619 |
+
if self.safety_checker is not None:
|
620 |
+
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
621 |
+
|
622 |
+
# We'll offload the last model manually.
|
623 |
+
self.final_offload_hook = hook
|
624 |
+
|
625 |
+
@property
|
626 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
627 |
+
def _execution_device(self):
|
628 |
+
r"""
|
629 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
630 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
631 |
+
hooks.
|
632 |
+
"""
|
633 |
+
if not hasattr(self.unet, "_hf_hook"):
|
634 |
+
return self.device
|
635 |
+
for module in self.unet.modules():
|
636 |
+
if (
|
637 |
+
hasattr(module, "_hf_hook")
|
638 |
+
and hasattr(module._hf_hook, "execution_device")
|
639 |
+
and module._hf_hook.execution_device is not None
|
640 |
+
):
|
641 |
+
return torch.device(module._hf_hook.execution_device)
|
642 |
+
return self.device
|
643 |
+
|
644 |
+
def _encode_prompt(
|
645 |
+
self,
|
646 |
+
prompt,
|
647 |
+
device,
|
648 |
+
num_images_per_prompt,
|
649 |
+
do_classifier_free_guidance,
|
650 |
+
negative_prompt=None,
|
651 |
+
max_embeddings_multiples=3,
|
652 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
653 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
654 |
+
):
|
655 |
+
r"""
|
656 |
+
Encodes the prompt into text encoder hidden states.
|
657 |
+
|
658 |
+
Args:
|
659 |
+
prompt (`str` or `list(int)`):
|
660 |
+
prompt to be encoded
|
661 |
+
device: (`torch.device`):
|
662 |
+
torch device
|
663 |
+
num_images_per_prompt (`int`):
|
664 |
+
number of images that should be generated per prompt
|
665 |
+
do_classifier_free_guidance (`bool`):
|
666 |
+
whether to use classifier free guidance or not
|
667 |
+
negative_prompt (`str` or `List[str]`):
|
668 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
669 |
+
if `guidance_scale` is less than `1`).
|
670 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
671 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
672 |
+
"""
|
673 |
+
if prompt is not None and isinstance(prompt, str):
|
674 |
+
batch_size = 1
|
675 |
+
elif prompt is not None and isinstance(prompt, list):
|
676 |
+
batch_size = len(prompt)
|
677 |
+
else:
|
678 |
+
batch_size = prompt_embeds.shape[0]
|
679 |
+
|
680 |
+
if negative_prompt_embeds is None:
|
681 |
+
if negative_prompt is None:
|
682 |
+
negative_prompt = [""] * batch_size
|
683 |
+
elif isinstance(negative_prompt, str):
|
684 |
+
negative_prompt = [negative_prompt] * batch_size
|
685 |
+
if batch_size != len(negative_prompt):
|
686 |
+
raise ValueError(
|
687 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
688 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
689 |
+
" the batch size of `prompt`."
|
690 |
+
)
|
691 |
+
if prompt_embeds is None or negative_prompt_embeds is None:
|
692 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
693 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
694 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
695 |
+
negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
|
696 |
+
|
697 |
+
prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
|
698 |
+
pipe=self,
|
699 |
+
prompt=prompt,
|
700 |
+
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
701 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
702 |
+
)
|
703 |
+
if prompt_embeds is None:
|
704 |
+
prompt_embeds = prompt_embeds1
|
705 |
+
if negative_prompt_embeds is None:
|
706 |
+
negative_prompt_embeds = negative_prompt_embeds1
|
707 |
+
|
708 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
709 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
710 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
711 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
712 |
+
|
713 |
+
if do_classifier_free_guidance:
|
714 |
+
bs_embed, seq_len, _ = negative_prompt_embeds.shape
|
715 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
716 |
+
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
717 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
718 |
+
|
719 |
+
return prompt_embeds
|
720 |
+
|
721 |
+
def check_inputs(
|
722 |
+
self,
|
723 |
+
prompt,
|
724 |
+
height,
|
725 |
+
width,
|
726 |
+
strength,
|
727 |
+
callback_steps,
|
728 |
+
negative_prompt=None,
|
729 |
+
prompt_embeds=None,
|
730 |
+
negative_prompt_embeds=None,
|
731 |
+
):
|
732 |
+
if height % 8 != 0 or width % 8 != 0:
|
733 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
734 |
+
|
735 |
+
if strength < 0 or strength > 1:
|
736 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
737 |
+
|
738 |
+
if (callback_steps is None) or (
|
739 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
740 |
+
):
|
741 |
+
raise ValueError(
|
742 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
743 |
+
f" {type(callback_steps)}."
|
744 |
+
)
|
745 |
+
|
746 |
+
if prompt is not None and prompt_embeds is not None:
|
747 |
+
raise ValueError(
|
748 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
749 |
+
" only forward one of the two."
|
750 |
+
)
|
751 |
+
elif prompt is None and prompt_embeds is None:
|
752 |
+
raise ValueError(
|
753 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
754 |
+
)
|
755 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
756 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
757 |
+
|
758 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
759 |
+
raise ValueError(
|
760 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
761 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
762 |
+
)
|
763 |
+
|
764 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
765 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
766 |
+
raise ValueError(
|
767 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
768 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
769 |
+
f" {negative_prompt_embeds.shape}."
|
770 |
+
)
|
771 |
+
|
772 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
773 |
+
if is_text2img:
|
774 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
775 |
+
else:
|
776 |
+
# get the original timestep using init_timestep
|
777 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
778 |
+
|
779 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
780 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
781 |
+
|
782 |
+
return timesteps, num_inference_steps - t_start
|
783 |
+
|
784 |
+
def run_safety_checker(self, image, device, dtype):
|
785 |
+
if self.safety_checker is not None:
|
786 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
787 |
+
image, has_nsfw_concept = self.safety_checker(
|
788 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
789 |
+
)
|
790 |
+
else:
|
791 |
+
has_nsfw_concept = None
|
792 |
+
return image, has_nsfw_concept
|
793 |
+
|
794 |
+
def decode_latents(self, latents):
|
795 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
796 |
+
image = self.vae.decode(latents).sample
|
797 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
798 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
799 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
800 |
+
return image
|
801 |
+
|
802 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
803 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
804 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
805 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
806 |
+
# and should be between [0, 1]
|
807 |
+
|
808 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
809 |
+
extra_step_kwargs = {}
|
810 |
+
if accepts_eta:
|
811 |
+
extra_step_kwargs["eta"] = eta
|
812 |
+
|
813 |
+
# check if the scheduler accepts generator
|
814 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
815 |
+
if accepts_generator:
|
816 |
+
extra_step_kwargs["generator"] = generator
|
817 |
+
return extra_step_kwargs
|
818 |
+
|
819 |
+
def prepare_latents(
|
820 |
+
self,
|
821 |
+
image,
|
822 |
+
timestep,
|
823 |
+
num_images_per_prompt,
|
824 |
+
batch_size,
|
825 |
+
num_channels_latents,
|
826 |
+
height,
|
827 |
+
width,
|
828 |
+
dtype,
|
829 |
+
device,
|
830 |
+
generator,
|
831 |
+
latents=None,
|
832 |
+
):
|
833 |
+
if image is None:
|
834 |
+
batch_size = batch_size * num_images_per_prompt
|
835 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
836 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
837 |
+
raise ValueError(
|
838 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
839 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
840 |
+
)
|
841 |
+
|
842 |
+
if latents is None:
|
843 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
844 |
+
else:
|
845 |
+
latents = latents.to(device)
|
846 |
+
|
847 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
848 |
+
latents = latents * self.scheduler.init_noise_sigma
|
849 |
+
return latents, None, None
|
850 |
+
else:
|
851 |
+
image = image.to(device=self.device, dtype=dtype)
|
852 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
853 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
854 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
855 |
+
|
856 |
+
# Expand init_latents for batch_size and num_images_per_prompt
|
857 |
+
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
858 |
+
init_latents_orig = init_latents
|
859 |
+
|
860 |
+
# add noise to latents using the timesteps
|
861 |
+
noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
|
862 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
863 |
+
latents = init_latents
|
864 |
+
return latents, init_latents_orig, noise
|
865 |
+
|
866 |
+
@torch.no_grad()
|
867 |
+
def __call__(
|
868 |
+
self,
|
869 |
+
prompt: Union[str, List[str]],
|
870 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
871 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
872 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
873 |
+
height: int = 512,
|
874 |
+
width: int = 512,
|
875 |
+
num_inference_steps: int = 50,
|
876 |
+
guidance_scale: float = 7.5,
|
877 |
+
strength: float = 0.8,
|
878 |
+
num_images_per_prompt: Optional[int] = 1,
|
879 |
+
add_predicted_noise: Optional[bool] = False,
|
880 |
+
eta: float = 0.0,
|
881 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
882 |
+
latents: Optional[torch.FloatTensor] = None,
|
883 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
884 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
885 |
+
max_embeddings_multiples: Optional[int] = 3,
|
886 |
+
output_type: Optional[str] = "pil",
|
887 |
+
return_dict: bool = True,
|
888 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
889 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
890 |
+
callback_steps: int = 1,
|
891 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
892 |
+
):
|
893 |
+
r"""
|
894 |
+
Function invoked when calling the pipeline for generation.
|
895 |
+
|
896 |
+
Args:
|
897 |
+
prompt (`str` or `List[str]`):
|
898 |
+
The prompt or prompts to guide the image generation.
|
899 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
900 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
901 |
+
if `guidance_scale` is less than `1`).
|
902 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
903 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
904 |
+
process.
|
905 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
906 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
907 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
908 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
909 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
910 |
+
height (`int`, *optional*, defaults to 512):
|
911 |
+
The height in pixels of the generated image.
|
912 |
+
width (`int`, *optional*, defaults to 512):
|
913 |
+
The width in pixels of the generated image.
|
914 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
915 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
916 |
+
expense of slower inference.
|
917 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
918 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
919 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
920 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
921 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
922 |
+
usually at the expense of lower image quality.
|
923 |
+
strength (`float`, *optional*, defaults to 0.8):
|
924 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
925 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
926 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
927 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
928 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
929 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
930 |
+
The number of images to generate per prompt.
|
931 |
+
add_predicted_noise (`bool`, *optional*, defaults to True):
|
932 |
+
Use predicted noise instead of random noise when constructing noisy versions of the original image in
|
933 |
+
the reverse diffusion process
|
934 |
+
eta (`float`, *optional*, defaults to 0.0):
|
935 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
936 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
937 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
938 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
939 |
+
to make generation deterministic.
|
940 |
+
latents (`torch.FloatTensor`, *optional*):
|
941 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
942 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
943 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
944 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
945 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
946 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
947 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
948 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
949 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
950 |
+
argument.
|
951 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
952 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
953 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
954 |
+
The output format of the generate image. Choose between
|
955 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
956 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
957 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
958 |
+
plain tuple.
|
959 |
+
callback (`Callable`, *optional*):
|
960 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
961 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
962 |
+
is_cancelled_callback (`Callable`, *optional*):
|
963 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
964 |
+
`True`, the inference will be cancelled.
|
965 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
966 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
967 |
+
called at every step.
|
968 |
+
cross_attention_kwargs (`dict`, *optional*):
|
969 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
970 |
+
`self.processor` in
|
971 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
972 |
+
|
973 |
+
Returns:
|
974 |
+
`None` if cancelled by `is_cancelled_callback`,
|
975 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
976 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
977 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
978 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
979 |
+
(nsfw) content, according to the `safety_checker`.
|
980 |
+
"""
|
981 |
+
# 0. Default height and width to unet
|
982 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
983 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
984 |
+
|
985 |
+
# 1. Check inputs. Raise error if not correct
|
986 |
+
self.check_inputs(
|
987 |
+
prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
988 |
+
)
|
989 |
+
|
990 |
+
# 2. Define call parameters
|
991 |
+
if prompt is not None and isinstance(prompt, str):
|
992 |
+
batch_size = 1
|
993 |
+
elif prompt is not None and isinstance(prompt, list):
|
994 |
+
batch_size = len(prompt)
|
995 |
+
else:
|
996 |
+
batch_size = prompt_embeds.shape[0]
|
997 |
+
|
998 |
+
device = self._execution_device
|
999 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
1000 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
1001 |
+
# corresponds to doing no classifier free guidance.
|
1002 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
1003 |
+
|
1004 |
+
# 3. Encode input prompt
|
1005 |
+
prompt_embeds = self._encode_prompt(
|
1006 |
+
prompt,
|
1007 |
+
device,
|
1008 |
+
num_images_per_prompt,
|
1009 |
+
do_classifier_free_guidance,
|
1010 |
+
negative_prompt,
|
1011 |
+
max_embeddings_multiples,
|
1012 |
+
prompt_embeds=prompt_embeds,
|
1013 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1014 |
+
)
|
1015 |
+
dtype = prompt_embeds.dtype
|
1016 |
+
|
1017 |
+
# 4. Preprocess image and mask
|
1018 |
+
if isinstance(image, PIL.Image.Image):
|
1019 |
+
image = preprocess_image(image, batch_size)
|
1020 |
+
if image is not None:
|
1021 |
+
image = image.to(device=self.device, dtype=dtype)
|
1022 |
+
if isinstance(mask_image, PIL.Image.Image):
|
1023 |
+
mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
|
1024 |
+
if mask_image is not None:
|
1025 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
1026 |
+
mask = torch.cat([mask] * num_images_per_prompt)
|
1027 |
+
else:
|
1028 |
+
mask = None
|
1029 |
+
|
1030 |
+
# 5. set timesteps
|
1031 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
1032 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
1033 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
1034 |
+
|
1035 |
+
# 6. Prepare latent variables
|
1036 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
1037 |
+
image,
|
1038 |
+
latent_timestep,
|
1039 |
+
num_images_per_prompt,
|
1040 |
+
batch_size,
|
1041 |
+
self.unet.config.in_channels,
|
1042 |
+
height,
|
1043 |
+
width,
|
1044 |
+
dtype,
|
1045 |
+
device,
|
1046 |
+
generator,
|
1047 |
+
latents,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
1051 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
1052 |
+
|
1053 |
+
# 8. Denoising loop
|
1054 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1055 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1056 |
+
for i, t in enumerate(timesteps):
|
1057 |
+
# expand the latents if we are doing classifier free guidance
|
1058 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
1059 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
1060 |
+
|
1061 |
+
# predict the noise residual
|
1062 |
+
noise_pred = self.unet(
|
1063 |
+
latent_model_input,
|
1064 |
+
t,
|
1065 |
+
encoder_hidden_states=prompt_embeds,
|
1066 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1067 |
+
).sample
|
1068 |
+
|
1069 |
+
# perform guidance
|
1070 |
+
if do_classifier_free_guidance:
|
1071 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1072 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1073 |
+
|
1074 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1075 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
1076 |
+
|
1077 |
+
if mask is not None:
|
1078 |
+
# masking
|
1079 |
+
if add_predicted_noise:
|
1080 |
+
init_latents_proper = self.scheduler.add_noise(
|
1081 |
+
init_latents_orig, noise_pred_uncond, torch.tensor([t])
|
1082 |
+
)
|
1083 |
+
else:
|
1084 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
1085 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
1086 |
+
|
1087 |
+
# call the callback, if provided
|
1088 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1089 |
+
progress_bar.update()
|
1090 |
+
if i % callback_steps == 0:
|
1091 |
+
if callback is not None:
|
1092 |
+
callback(i, t, latents)
|
1093 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
1094 |
+
return None
|
1095 |
+
|
1096 |
+
if output_type == "latent":
|
1097 |
+
image = latents
|
1098 |
+
has_nsfw_concept = None
|
1099 |
+
elif output_type == "pil":
|
1100 |
+
# 9. Post-processing
|
1101 |
+
image = self.decode_latents(latents)
|
1102 |
+
|
1103 |
+
# 10. Run safety checker
|
1104 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
1105 |
+
|
1106 |
+
# 11. Convert to PIL
|
1107 |
+
image = self.numpy_to_pil(image)
|
1108 |
+
else:
|
1109 |
+
# 9. Post-processing
|
1110 |
+
image = self.decode_latents(latents)
|
1111 |
+
|
1112 |
+
# 10. Run safety checker
|
1113 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
1114 |
+
|
1115 |
+
# Offload last model to CPU
|
1116 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
1117 |
+
self.final_offload_hook.offload()
|
1118 |
+
|
1119 |
+
if not return_dict:
|
1120 |
+
return image, has_nsfw_concept
|
1121 |
+
|
1122 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
1123 |
+
|
1124 |
+
def text2img(
|
1125 |
+
self,
|
1126 |
+
prompt: Union[str, List[str]],
|
1127 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1128 |
+
height: int = 512,
|
1129 |
+
width: int = 512,
|
1130 |
+
num_inference_steps: int = 50,
|
1131 |
+
guidance_scale: float = 7.5,
|
1132 |
+
num_images_per_prompt: Optional[int] = 1,
|
1133 |
+
eta: float = 0.0,
|
1134 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1135 |
+
latents: Optional[torch.FloatTensor] = None,
|
1136 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1137 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1138 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1139 |
+
output_type: Optional[str] = "pil",
|
1140 |
+
return_dict: bool = True,
|
1141 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1142 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1143 |
+
callback_steps: int = 1,
|
1144 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1145 |
+
):
|
1146 |
+
r"""
|
1147 |
+
Function for text-to-image generation.
|
1148 |
+
Args:
|
1149 |
+
prompt (`str` or `List[str]`):
|
1150 |
+
The prompt or prompts to guide the image generation.
|
1151 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1152 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1153 |
+
if `guidance_scale` is less than `1`).
|
1154 |
+
height (`int`, *optional*, defaults to 512):
|
1155 |
+
The height in pixels of the generated image.
|
1156 |
+
width (`int`, *optional*, defaults to 512):
|
1157 |
+
The width in pixels of the generated image.
|
1158 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1159 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1160 |
+
expense of slower inference.
|
1161 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1162 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1163 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1164 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1165 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1166 |
+
usually at the expense of lower image quality.
|
1167 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1168 |
+
The number of images to generate per prompt.
|
1169 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1170 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1171 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1172 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1173 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1174 |
+
to make generation deterministic.
|
1175 |
+
latents (`torch.FloatTensor`, *optional*):
|
1176 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1177 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1178 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
1179 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1180 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1181 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1182 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1183 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1184 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1185 |
+
argument.
|
1186 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1187 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1188 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1189 |
+
The output format of the generate image. Choose between
|
1190 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1191 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1192 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1193 |
+
plain tuple.
|
1194 |
+
callback (`Callable`, *optional*):
|
1195 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1196 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1197 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1198 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1199 |
+
`True`, the inference will be cancelled.
|
1200 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1201 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1202 |
+
called at every step.
|
1203 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1204 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1205 |
+
`self.processor` in
|
1206 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
1207 |
+
|
1208 |
+
Returns:
|
1209 |
+
`None` if cancelled by `is_cancelled_callback`,
|
1210 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
1211 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1212 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1213 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1214 |
+
(nsfw) content, according to the `safety_checker`.
|
1215 |
+
"""
|
1216 |
+
return self.__call__(
|
1217 |
+
prompt=prompt,
|
1218 |
+
negative_prompt=negative_prompt,
|
1219 |
+
height=height,
|
1220 |
+
width=width,
|
1221 |
+
num_inference_steps=num_inference_steps,
|
1222 |
+
guidance_scale=guidance_scale,
|
1223 |
+
num_images_per_prompt=num_images_per_prompt,
|
1224 |
+
eta=eta,
|
1225 |
+
generator=generator,
|
1226 |
+
latents=latents,
|
1227 |
+
prompt_embeds=prompt_embeds,
|
1228 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1229 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1230 |
+
output_type=output_type,
|
1231 |
+
return_dict=return_dict,
|
1232 |
+
callback=callback,
|
1233 |
+
is_cancelled_callback=is_cancelled_callback,
|
1234 |
+
callback_steps=callback_steps,
|
1235 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1236 |
+
)
|
1237 |
+
|
1238 |
+
def img2img(
|
1239 |
+
self,
|
1240 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1241 |
+
prompt: Union[str, List[str]],
|
1242 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1243 |
+
strength: float = 0.8,
|
1244 |
+
num_inference_steps: Optional[int] = 50,
|
1245 |
+
guidance_scale: Optional[float] = 7.5,
|
1246 |
+
num_images_per_prompt: Optional[int] = 1,
|
1247 |
+
eta: Optional[float] = 0.0,
|
1248 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1249 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1250 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1251 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1252 |
+
output_type: Optional[str] = "pil",
|
1253 |
+
return_dict: bool = True,
|
1254 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1255 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1256 |
+
callback_steps: int = 1,
|
1257 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1258 |
+
):
|
1259 |
+
r"""
|
1260 |
+
Function for image-to-image generation.
|
1261 |
+
Args:
|
1262 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1263 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1264 |
+
process.
|
1265 |
+
prompt (`str` or `List[str]`):
|
1266 |
+
The prompt or prompts to guide the image generation.
|
1267 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1268 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1269 |
+
if `guidance_scale` is less than `1`).
|
1270 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1271 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
1272 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
1273 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
1274 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
1275 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
1276 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1277 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
1278 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
1279 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1280 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1281 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1282 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1283 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1284 |
+
usually at the expense of lower image quality.
|
1285 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1286 |
+
The number of images to generate per prompt.
|
1287 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1288 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1289 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1290 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1291 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1292 |
+
to make generation deterministic.
|
1293 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1294 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1295 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1296 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1297 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1298 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1299 |
+
argument.
|
1300 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1301 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1302 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1303 |
+
The output format of the generate image. Choose between
|
1304 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1305 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1306 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1307 |
+
plain tuple.
|
1308 |
+
callback (`Callable`, *optional*):
|
1309 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1310 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1311 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1312 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1313 |
+
`True`, the inference will be cancelled.
|
1314 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1315 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1316 |
+
called at every step.
|
1317 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1318 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1319 |
+
`self.processor` in
|
1320 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
1321 |
+
|
1322 |
+
Returns:
|
1323 |
+
`None` if cancelled by `is_cancelled_callback`,
|
1324 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1325 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1326 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1327 |
+
(nsfw) content, according to the `safety_checker`.
|
1328 |
+
"""
|
1329 |
+
return self.__call__(
|
1330 |
+
prompt=prompt,
|
1331 |
+
negative_prompt=negative_prompt,
|
1332 |
+
image=image,
|
1333 |
+
num_inference_steps=num_inference_steps,
|
1334 |
+
guidance_scale=guidance_scale,
|
1335 |
+
strength=strength,
|
1336 |
+
num_images_per_prompt=num_images_per_prompt,
|
1337 |
+
eta=eta,
|
1338 |
+
generator=generator,
|
1339 |
+
prompt_embeds=prompt_embeds,
|
1340 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1341 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1342 |
+
output_type=output_type,
|
1343 |
+
return_dict=return_dict,
|
1344 |
+
callback=callback,
|
1345 |
+
is_cancelled_callback=is_cancelled_callback,
|
1346 |
+
callback_steps=callback_steps,
|
1347 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
def inpaint(
|
1351 |
+
self,
|
1352 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
1353 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
1354 |
+
prompt: Union[str, List[str]],
|
1355 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
1356 |
+
strength: float = 0.8,
|
1357 |
+
num_inference_steps: Optional[int] = 50,
|
1358 |
+
guidance_scale: Optional[float] = 7.5,
|
1359 |
+
num_images_per_prompt: Optional[int] = 1,
|
1360 |
+
add_predicted_noise: Optional[bool] = False,
|
1361 |
+
eta: Optional[float] = 0.0,
|
1362 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
1363 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
1364 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
1365 |
+
max_embeddings_multiples: Optional[int] = 3,
|
1366 |
+
output_type: Optional[str] = "pil",
|
1367 |
+
return_dict: bool = True,
|
1368 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
1369 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
1370 |
+
callback_steps: int = 1,
|
1371 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1372 |
+
):
|
1373 |
+
r"""
|
1374 |
+
Function for inpaint.
|
1375 |
+
Args:
|
1376 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1377 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
1378 |
+
process. This is the image whose masked region will be inpainted.
|
1379 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
1380 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
1381 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
1382 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
1383 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
1384 |
+
prompt (`str` or `List[str]`):
|
1385 |
+
The prompt or prompts to guide the image generation.
|
1386 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1387 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
1388 |
+
if `guidance_scale` is less than `1`).
|
1389 |
+
strength (`float`, *optional*, defaults to 0.8):
|
1390 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
1391 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
1392 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
1393 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
1394 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
1395 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
1396 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
1397 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
1398 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
1399 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
1400 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
1401 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1402 |
+
usually at the expense of lower image quality.
|
1403 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1404 |
+
The number of images to generate per prompt.
|
1405 |
+
add_predicted_noise (`bool`, *optional*, defaults to True):
|
1406 |
+
Use predicted noise instead of random noise when constructing noisy versions of the original image in
|
1407 |
+
the reverse diffusion process
|
1408 |
+
eta (`float`, *optional*, defaults to 0.0):
|
1409 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
1410 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
1411 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1412 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1413 |
+
to make generation deterministic.
|
1414 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1415 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1416 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1417 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1418 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1419 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1420 |
+
argument.
|
1421 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
1422 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
1423 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1424 |
+
The output format of the generate image. Choose between
|
1425 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1426 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1427 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
1428 |
+
plain tuple.
|
1429 |
+
callback (`Callable`, *optional*):
|
1430 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
1431 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
1432 |
+
is_cancelled_callback (`Callable`, *optional*):
|
1433 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
1434 |
+
`True`, the inference will be cancelled.
|
1435 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
1436 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
1437 |
+
called at every step.
|
1438 |
+
cross_attention_kwargs (`dict`, *optional*):
|
1439 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1440 |
+
`self.processor` in
|
1441 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
1442 |
+
|
1443 |
+
Returns:
|
1444 |
+
`None` if cancelled by `is_cancelled_callback`,
|
1445 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
1446 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
1447 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
1448 |
+
(nsfw) content, according to the `safety_checker`.
|
1449 |
+
"""
|
1450 |
+
return self.__call__(
|
1451 |
+
prompt=prompt,
|
1452 |
+
negative_prompt=negative_prompt,
|
1453 |
+
image=image,
|
1454 |
+
mask_image=mask_image,
|
1455 |
+
num_inference_steps=num_inference_steps,
|
1456 |
+
guidance_scale=guidance_scale,
|
1457 |
+
strength=strength,
|
1458 |
+
num_images_per_prompt=num_images_per_prompt,
|
1459 |
+
add_predicted_noise=add_predicted_noise,
|
1460 |
+
eta=eta,
|
1461 |
+
generator=generator,
|
1462 |
+
prompt_embeds=prompt_embeds,
|
1463 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1464 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
1465 |
+
output_type=output_type,
|
1466 |
+
return_dict=return_dict,
|
1467 |
+
callback=callback,
|
1468 |
+
is_cancelled_callback=is_cancelled_callback,
|
1469 |
+
callback_steps=callback_steps,
|
1470 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
1471 |
+
)
|
chat_anything/face_generator/utils/generate.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
3 |
+
|
4 |
+
@torch.no_grad()
|
5 |
+
def generate(pipe, prompt, negative_prompt, **generating_conf):
|
6 |
+
pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(
|
7 |
+
unet=pipe.unet,
|
8 |
+
text_encoder=pipe.text_encoder,
|
9 |
+
vae=pipe.vae,
|
10 |
+
tokenizer=pipe.tokenizer,
|
11 |
+
scheduler=pipe.scheduler,
|
12 |
+
safety_checker=None,
|
13 |
+
feature_extractor=None,
|
14 |
+
)
|
15 |
+
print('generating: ', prompt)
|
16 |
+
print('using negative prompt: ', negative_prompt)
|
17 |
+
embeds = pipe_longprompt._encode_prompt(prompt=prompt, negative_prompt=negative_prompt, device=pipe.device, num_images_per_prompt=1, do_classifier_free_guidance=generating_conf['guidance_scale']>1,)
|
18 |
+
negative_prompt_embeds, prompt_embeds = embeds.split(embeds.shape[0]//2)
|
19 |
+
pipe_out = pipe(
|
20 |
+
prompt_embeds=prompt_embeds,
|
21 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
22 |
+
**generating_conf,
|
23 |
+
)
|
24 |
+
return pipe_out
|
25 |
+
|
26 |
+
if __name__ == '__main__':
|
27 |
+
from diffusers.pipelines import StableDiffusionPipeline
|
28 |
+
import argparse
|
29 |
+
def main():
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument(
|
32 |
+
'--prompts',type=str,default=['starry night','Impression Sunrise, drawn by Claude Monet'], nargs='*'
|
33 |
+
)
|
34 |
+
|
35 |
+
args = parser.parse_args()
|
36 |
+
prompts = args.prompts
|
37 |
+
print(f'generating {prompts}')
|
38 |
+
model_id = 'pretrained_model/sd-v1-4'
|
39 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id,).to('cuda')
|
40 |
+
images = pipe(prompts).images
|
41 |
+
for i, image in enumerate(images):
|
42 |
+
image.save(f'{prompts[i]}_{i}.png')
|
43 |
+
|
44 |
+
main()
|
45 |
+
|
chat_anything/polly_utils.py
ADDED
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This class stores Polly voice data. Specifically, the class stores several records containing
|
2 |
+
# language, lang_code, gender, voice_id and engine. The class also has a method to return the
|
3 |
+
# voice_id, lang_code and engine given a language and gender.
|
4 |
+
|
5 |
+
NEURAL_ENGINE = "neural"
|
6 |
+
STANDARD_ENGINE = "standard"
|
7 |
+
|
8 |
+
|
9 |
+
class PollyVoiceData:
|
10 |
+
def get_voice(self, language, gender):
|
11 |
+
for voice in self.voice_data:
|
12 |
+
if voice['language'] == language and voice['gender'] == gender:
|
13 |
+
if voice['neural'] == 'Yes':
|
14 |
+
return voice['voice_id'], voice['lang_code'], NEURAL_ENGINE
|
15 |
+
for voice in self.voice_data:
|
16 |
+
if voice['language'] == language and voice['gender'] == gender:
|
17 |
+
if voice['standard'] == 'Yes':
|
18 |
+
return voice['voice_id'], voice['lang_code'], STANDARD_ENGINE
|
19 |
+
return None, None, None
|
20 |
+
|
21 |
+
def get_whisper_lang_code(self, language):
|
22 |
+
for voice in self.voice_data:
|
23 |
+
if voice['language'] == language:
|
24 |
+
return voice['whisper_lang_code']
|
25 |
+
return "en"
|
26 |
+
|
27 |
+
def __init__(self):
|
28 |
+
self.voice_data = [
|
29 |
+
{'language': 'Arabic',
|
30 |
+
'lang_code': 'arb',
|
31 |
+
'whisper_lang_code': 'ar',
|
32 |
+
'voice_id': 'Zeina',
|
33 |
+
'gender': 'Female',
|
34 |
+
'neural': 'No',
|
35 |
+
'standard': 'Yes'},
|
36 |
+
{'language': 'Arabic (Gulf)',
|
37 |
+
'lang_code': 'ar-AE',
|
38 |
+
'whisper_lang_code': 'ar',
|
39 |
+
'voice_id': 'Hala',
|
40 |
+
'gender': 'Female',
|
41 |
+
'neural': 'Yes',
|
42 |
+
'standard': 'No'},
|
43 |
+
{'language': 'Catalan',
|
44 |
+
'lang_code': 'ca-ES',
|
45 |
+
'whisper_lang_code': 'ca',
|
46 |
+
'voice_id': 'Arlet',
|
47 |
+
'gender': 'Female',
|
48 |
+
'neural': 'Yes',
|
49 |
+
'standard': 'No'},
|
50 |
+
{'language': 'Chinese (Cantonese)',
|
51 |
+
'lang_code': 'yue-CN',
|
52 |
+
'whisper_lang_code': 'zh',
|
53 |
+
'voice_id': 'Hiujin',
|
54 |
+
'gender': 'Female',
|
55 |
+
'neural': 'Yes',
|
56 |
+
'standard': 'No'},
|
57 |
+
{'language': 'Chinese (Mandarin)',
|
58 |
+
'lang_code': 'cmn-CN',
|
59 |
+
'whisper_lang_code': 'zh',
|
60 |
+
'voice_id': 'Zhiyu',
|
61 |
+
'gender': 'Female',
|
62 |
+
'neural': 'Yes',
|
63 |
+
'standard': 'No'},
|
64 |
+
{'language': 'Danish',
|
65 |
+
'lang_code': 'da-DK',
|
66 |
+
'whisper_lang_code': 'da',
|
67 |
+
'voice_id': 'Naja',
|
68 |
+
'gender': 'Female',
|
69 |
+
'neural': 'No',
|
70 |
+
'standard': 'Yes'},
|
71 |
+
{'language': 'Danish',
|
72 |
+
'lang_code': 'da-DK',
|
73 |
+
'whisper_lang_code': 'da',
|
74 |
+
'voice_id': 'Mads',
|
75 |
+
'gender': 'Male',
|
76 |
+
'neural': 'No',
|
77 |
+
'standard': 'Yes'},
|
78 |
+
{'language': 'Dutch',
|
79 |
+
'lang_code': 'nl-NL',
|
80 |
+
'whisper_lang_code': 'nl',
|
81 |
+
'voice_id': 'Laura',
|
82 |
+
'gender': 'Female',
|
83 |
+
'neural': 'Yes',
|
84 |
+
'standard': 'No'},
|
85 |
+
{'language': 'Dutch',
|
86 |
+
'lang_code': 'nl-NL',
|
87 |
+
'whisper_lang_code': 'nl',
|
88 |
+
'voice_id': 'Lotte',
|
89 |
+
'gender': 'Female',
|
90 |
+
'neural': 'No',
|
91 |
+
'standard': 'Yes'},
|
92 |
+
{'language': 'Dutch',
|
93 |
+
'lang_code': 'nl-NL',
|
94 |
+
'whisper_lang_code': 'nl',
|
95 |
+
'voice_id': 'Ruben',
|
96 |
+
'gender': 'Male',
|
97 |
+
'neural': 'No',
|
98 |
+
'standard': 'Yes'},
|
99 |
+
{'language': 'English (Australian)',
|
100 |
+
'lang_code': 'en-AU',
|
101 |
+
'whisper_lang_code': 'en',
|
102 |
+
'voice_id': 'Nicole',
|
103 |
+
'gender': 'Female',
|
104 |
+
'neural': 'No',
|
105 |
+
'standard': 'Yes'},
|
106 |
+
{'language': 'English (Australian)',
|
107 |
+
'lang_code': 'en-AU',
|
108 |
+
'whisper_lang_code': 'en',
|
109 |
+
'voice_id': 'Olivia',
|
110 |
+
'gender': 'Female',
|
111 |
+
'neural': 'Yes',
|
112 |
+
'standard': 'No'},
|
113 |
+
{'language': 'English (Australian)',
|
114 |
+
'lang_code': 'en-AU',
|
115 |
+
'whisper_lang_code': 'en',
|
116 |
+
'voice_id': 'Russell',
|
117 |
+
'gender': 'Male',
|
118 |
+
'neural': 'No',
|
119 |
+
'standard': 'Yes'},
|
120 |
+
{'language': 'English (British)',
|
121 |
+
'lang_code': 'en-GB',
|
122 |
+
'whisper_lang_code': 'en',
|
123 |
+
'voice_id': 'Amy',
|
124 |
+
'gender': 'Female',
|
125 |
+
'neural': 'Yes',
|
126 |
+
'standard': 'Yes'},
|
127 |
+
{'language': 'English (British)',
|
128 |
+
'lang_code': 'en-GB',
|
129 |
+
'whisper_lang_code': 'en',
|
130 |
+
'voice_id': 'Emma',
|
131 |
+
'gender': 'Female',
|
132 |
+
'neural': 'Yes',
|
133 |
+
'standard': 'Yes'},
|
134 |
+
{'language': 'English (British)',
|
135 |
+
'lang_code': 'en-GB',
|
136 |
+
'whisper_lang_code': 'en',
|
137 |
+
'voice_id': 'Brian',
|
138 |
+
'gender': 'Male',
|
139 |
+
'neural': 'Yes',
|
140 |
+
'standard': 'Yes'},
|
141 |
+
{'language': 'English (British)',
|
142 |
+
'lang_code': 'en-GB',
|
143 |
+
'whisper_lang_code': 'en',
|
144 |
+
'voice_id': 'Arthur',
|
145 |
+
'gender': 'Male',
|
146 |
+
'neural': 'Yes',
|
147 |
+
'standard': 'No'},
|
148 |
+
{'language': 'English (Indian)',
|
149 |
+
'lang_code': 'en-IN',
|
150 |
+
'whisper_lang_code': 'en',
|
151 |
+
'voice_id': 'Aditi',
|
152 |
+
'gender': 'Female',
|
153 |
+
'neural': 'No',
|
154 |
+
'standard': 'Yes'},
|
155 |
+
{'language': 'English (Indian)',
|
156 |
+
'lang_code': 'en-IN',
|
157 |
+
'whisper_lang_code': 'en',
|
158 |
+
'voice_id': 'Raveena',
|
159 |
+
'gender': 'Female',
|
160 |
+
'neural': 'No',
|
161 |
+
'standard': 'Yes'},
|
162 |
+
{'language': 'English (Indian)',
|
163 |
+
'lang_code': 'en-IN',
|
164 |
+
'whisper_lang_code': 'en',
|
165 |
+
'voice_id': 'Kajal',
|
166 |
+
'gender': 'Female',
|
167 |
+
'neural': 'Yes',
|
168 |
+
'standard': 'No'},
|
169 |
+
{'language': 'English (New Zealand)',
|
170 |
+
'lang_code': 'en-NZ',
|
171 |
+
'whisper_lang_code': 'en',
|
172 |
+
'voice_id': 'Aria',
|
173 |
+
'gender': 'Female',
|
174 |
+
'neural': 'Yes',
|
175 |
+
'standard': 'No'},
|
176 |
+
{'language': 'English (South African)',
|
177 |
+
'lang_code': 'en-ZA',
|
178 |
+
'whisper_lang_code': 'en',
|
179 |
+
'voice_id': 'Ayanda',
|
180 |
+
'gender': 'Female',
|
181 |
+
'neural': 'Yes',
|
182 |
+
'standard': 'No'},
|
183 |
+
{'language': 'English (US)',
|
184 |
+
'lang_code': 'en-US',
|
185 |
+
'whisper_lang_code': 'en',
|
186 |
+
'voice_id': 'Ivy',
|
187 |
+
'gender': 'Female (child)',
|
188 |
+
'neural': 'Yes',
|
189 |
+
'standard': 'Yes'},
|
190 |
+
{'language': 'English (US)',
|
191 |
+
'lang_code': 'en-US',
|
192 |
+
'whisper_lang_code': 'en',
|
193 |
+
'voice_id': 'Joanna',
|
194 |
+
'gender': 'Female',
|
195 |
+
'neural': 'Yes',
|
196 |
+
'standard': 'Yes'},
|
197 |
+
{'language': 'English (US)',
|
198 |
+
'lang_code': 'en-US',
|
199 |
+
'whisper_lang_code': 'en',
|
200 |
+
'voice_id': 'Kendra',
|
201 |
+
'gender': 'Female',
|
202 |
+
'neural': 'Yes',
|
203 |
+
'standard': 'Yes'},
|
204 |
+
{'language': 'English (US)',
|
205 |
+
'lang_code': 'en-US',
|
206 |
+
'whisper_lang_code': 'en',
|
207 |
+
'voice_id': 'Kimberly',
|
208 |
+
'gender': 'Female',
|
209 |
+
'neural': 'Yes',
|
210 |
+
'standard': 'Yes'},
|
211 |
+
{'language': 'English (US)',
|
212 |
+
'lang_code': 'en-US',
|
213 |
+
'whisper_lang_code': 'en',
|
214 |
+
'voice_id': 'Salli',
|
215 |
+
'gender': 'Female',
|
216 |
+
'neural': 'Yes',
|
217 |
+
'standard': 'Yes'},
|
218 |
+
{'language': 'English (US)',
|
219 |
+
'lang_code': 'en-US',
|
220 |
+
'whisper_lang_code': 'en',
|
221 |
+
'voice_id': 'Joey',
|
222 |
+
'gender': 'Male',
|
223 |
+
'neural': 'Yes',
|
224 |
+
'standard': 'Yes'},
|
225 |
+
{'language': 'English (US)',
|
226 |
+
'lang_code': 'en-US',
|
227 |
+
'whisper_lang_code': 'en',
|
228 |
+
'voice_id': 'Justin',
|
229 |
+
'gender': 'Male (child)',
|
230 |
+
'neural': 'Yes',
|
231 |
+
'standard': 'Yes'},
|
232 |
+
{'language': 'English (US)',
|
233 |
+
'lang_code': 'en-US',
|
234 |
+
'whisper_lang_code': 'en',
|
235 |
+
'voice_id': 'Kevin',
|
236 |
+
'gender': 'Male (child)',
|
237 |
+
'neural': 'Yes',
|
238 |
+
'standard': 'No'},
|
239 |
+
{'language': 'English (US)',
|
240 |
+
'lang_code': 'en-US',
|
241 |
+
'whisper_lang_code': 'en',
|
242 |
+
'voice_id': 'Matthew',
|
243 |
+
'gender': 'Male',
|
244 |
+
'neural': 'Yes',
|
245 |
+
'standard': 'Yes'},
|
246 |
+
{'language': 'English (Welsh)',
|
247 |
+
'lang_code': 'en-GB-WLS',
|
248 |
+
'whisper_lang_code': 'en',
|
249 |
+
'voice_id': 'Geraint',
|
250 |
+
'gender': 'Male',
|
251 |
+
'neural': 'No',
|
252 |
+
'standard': 'Yes'},
|
253 |
+
{'language': 'Finnish',
|
254 |
+
'lang_code': 'fi-FI',
|
255 |
+
'whisper_lang_code': 'fi',
|
256 |
+
'voice_id': 'Suvi',
|
257 |
+
'gender': 'Female',
|
258 |
+
'neural': 'Yes',
|
259 |
+
'standard': 'No'},
|
260 |
+
{'language': 'French',
|
261 |
+
'lang_code': 'fr-FR',
|
262 |
+
'whisper_lang_code': 'fr',
|
263 |
+
'voice_id': 'Celine',
|
264 |
+
'gender': 'Female',
|
265 |
+
'neural': 'No',
|
266 |
+
'standard': 'Yes'},
|
267 |
+
{'language': 'French',
|
268 |
+
'lang_code': 'fr-FR',
|
269 |
+
'whisper_lang_code': 'fr',
|
270 |
+
'voice_id': 'Lea',
|
271 |
+
'gender': 'Female',
|
272 |
+
'neural': 'Yes',
|
273 |
+
'standard': 'Yes'},
|
274 |
+
{'language': 'French',
|
275 |
+
'lang_code': 'fr-FR',
|
276 |
+
'whisper_lang_code': 'fr',
|
277 |
+
'voice_id': 'Mathieu',
|
278 |
+
'gender': 'Male',
|
279 |
+
'neural': 'No',
|
280 |
+
'standard': 'Yes'},
|
281 |
+
{'language': 'French (Canadian)',
|
282 |
+
'lang_code': 'fr-CA',
|
283 |
+
'whisper_lang_code': 'fr',
|
284 |
+
'voice_id': 'Chantal',
|
285 |
+
'gender': 'Female',
|
286 |
+
'neural': 'No',
|
287 |
+
'standard': 'Yes'},
|
288 |
+
{'language': 'French (Canadian)',
|
289 |
+
'lang_code': 'fr-CA',
|
290 |
+
'whisper_lang_code': 'fr',
|
291 |
+
'voice_id': 'Gabrielle',
|
292 |
+
'gender': 'Female',
|
293 |
+
'neural': 'Yes',
|
294 |
+
'standard': 'No'},
|
295 |
+
{'language': 'French (Canadian)',
|
296 |
+
'lang_code': 'fr-CA',
|
297 |
+
'whisper_lang_code': 'fr',
|
298 |
+
'voice_id': 'Liam',
|
299 |
+
'gender': 'Male',
|
300 |
+
'neural': 'Yes',
|
301 |
+
'standard': 'No'},
|
302 |
+
{'language': 'German',
|
303 |
+
'lang_code': 'de-DE',
|
304 |
+
'whisper_lang_code': 'de',
|
305 |
+
'voice_id': 'Marlene',
|
306 |
+
'gender': 'Female',
|
307 |
+
'neural': 'No',
|
308 |
+
'standard': 'Yes'},
|
309 |
+
{'language': 'German',
|
310 |
+
'lang_code': 'de-DE',
|
311 |
+
'whisper_lang_code': 'de',
|
312 |
+
'voice_id': 'Vicki',
|
313 |
+
'gender': 'Female',
|
314 |
+
'neural': 'Yes',
|
315 |
+
'standard': 'Yes'},
|
316 |
+
{'language': 'German',
|
317 |
+
'lang_code': 'de-DE',
|
318 |
+
'whisper_lang_code': 'de',
|
319 |
+
'voice_id': 'Hans',
|
320 |
+
'gender': 'Male',
|
321 |
+
'neural': 'No',
|
322 |
+
'standard': 'Yes'},
|
323 |
+
{'language': 'German',
|
324 |
+
'lang_code': 'de-DE',
|
325 |
+
'whisper_lang_code': 'de',
|
326 |
+
'voice_id': 'Daniel',
|
327 |
+
'gender': 'Male',
|
328 |
+
'neural': 'Yes',
|
329 |
+
'standard': 'No'},
|
330 |
+
{'language': 'German (Austrian)',
|
331 |
+
'lang_code': 'de-AT',
|
332 |
+
'whisper_lang_code': 'de',
|
333 |
+
'voice_id': 'Hannah',
|
334 |
+
'gender': 'Female',
|
335 |
+
'neural': 'Yes',
|
336 |
+
'standard': 'No'},
|
337 |
+
{'language': 'Hindi',
|
338 |
+
'lang_code': 'hi-IN',
|
339 |
+
'whisper_lang_code': 'hi',
|
340 |
+
'voice_id': 'Aditi',
|
341 |
+
'gender': 'Female',
|
342 |
+
'neural': 'No',
|
343 |
+
'standard': 'Yes'},
|
344 |
+
{'language': 'Hindi',
|
345 |
+
'lang_code': 'hi-IN',
|
346 |
+
'whisper_lang_code': 'hi',
|
347 |
+
'voice_id': 'Kajal',
|
348 |
+
'gender': 'Female',
|
349 |
+
'neural': 'Yes',
|
350 |
+
'standard': 'No'},
|
351 |
+
{'language': 'Icelandic',
|
352 |
+
'lang_code': 'is-IS',
|
353 |
+
'whisper_lang_code': 'is',
|
354 |
+
'voice_id': 'Dora',
|
355 |
+
'gender': 'Female',
|
356 |
+
'neural': 'No',
|
357 |
+
'standard': 'Yes'},
|
358 |
+
{'language': 'Icelandic',
|
359 |
+
'lang_code': 'is-IS',
|
360 |
+
'whisper_lang_code': 'is',
|
361 |
+
'voice_id': 'Karl',
|
362 |
+
'gender': 'Male',
|
363 |
+
'neural': 'No',
|
364 |
+
'standard': 'Yes'},
|
365 |
+
{'language': 'Italian',
|
366 |
+
'lang_code': 'it-IT',
|
367 |
+
'whisper_lang_code': 'it',
|
368 |
+
'voice_id': 'Carla',
|
369 |
+
'gender': 'Female',
|
370 |
+
'neural': 'No',
|
371 |
+
'standard': 'Yes'},
|
372 |
+
{'language': 'Italian',
|
373 |
+
'lang_code': 'it-IT',
|
374 |
+
'whisper_lang_code': 'it',
|
375 |
+
'voice_id': 'Bianca',
|
376 |
+
'gender': 'Female',
|
377 |
+
'neural': 'Yes',
|
378 |
+
'standard': 'Yes'},
|
379 |
+
{'language': 'Japanese',
|
380 |
+
'lang_code': 'ja-JP',
|
381 |
+
'whisper_lang_code': 'ja',
|
382 |
+
'voice_id': 'Mizuki',
|
383 |
+
'gender': 'Female',
|
384 |
+
'neural': 'No',
|
385 |
+
'standard': 'Yes'},
|
386 |
+
{'language': 'Japanese',
|
387 |
+
'lang_code': 'ja-JP',
|
388 |
+
'whisper_lang_code': 'ja',
|
389 |
+
'voice_id': 'Takumi',
|
390 |
+
'gender': 'Male',
|
391 |
+
'neural': 'Yes',
|
392 |
+
'standard': 'Yes'},
|
393 |
+
{'language': 'Korean',
|
394 |
+
'lang_code': 'ko-KR',
|
395 |
+
'whisper_lang_code': 'ko',
|
396 |
+
'voice_id': 'Seoyeon',
|
397 |
+
'gender': 'Female',
|
398 |
+
'neural': 'Yes',
|
399 |
+
'standard': 'Yes'},
|
400 |
+
{'language': 'Norwegian',
|
401 |
+
'lang_code': 'nb-NO',
|
402 |
+
'whisper_lang_code': 'no',
|
403 |
+
'voice_id': 'Liv',
|
404 |
+
'gender': 'Female',
|
405 |
+
'neural': 'No',
|
406 |
+
'standard': 'Yes'},
|
407 |
+
{'language': 'Norwegian',
|
408 |
+
'lang_code': 'nb-NO',
|
409 |
+
'whisper_lang_code': 'no',
|
410 |
+
'voice_id': 'Ida',
|
411 |
+
'gender': 'Female',
|
412 |
+
'neural': 'Yes',
|
413 |
+
'standard': 'No'},
|
414 |
+
{'language': 'Polish',
|
415 |
+
'lang_code': 'pl-PL',
|
416 |
+
'whisper_lang_code': 'pl',
|
417 |
+
'voice_id': 'Ewa',
|
418 |
+
'gender': 'Female',
|
419 |
+
'neural': 'No',
|
420 |
+
'standard': 'Yes'},
|
421 |
+
{'language': 'Polish',
|
422 |
+
'lang_code': 'pl-PL',
|
423 |
+
'whisper_lang_code': 'pl',
|
424 |
+
'voice_id': 'Maja',
|
425 |
+
'gender': 'Female',
|
426 |
+
'neural': 'No',
|
427 |
+
'standard': 'Yes'},
|
428 |
+
{'language': 'Polish',
|
429 |
+
'lang_code': 'pl-PL',
|
430 |
+
'whisper_lang_code': 'pl',
|
431 |
+
'voice_id': 'Jacek',
|
432 |
+
'gender': 'Male',
|
433 |
+
'neural': 'No',
|
434 |
+
'standard': 'Yes'},
|
435 |
+
{'language': 'Polish',
|
436 |
+
'lang_code': 'pl-PL',
|
437 |
+
'whisper_lang_code': 'pl',
|
438 |
+
'voice_id': 'Jan',
|
439 |
+
'gender': 'Male',
|
440 |
+
'neural': 'No',
|
441 |
+
'standard': 'Yes'},
|
442 |
+
{'language': 'Polish',
|
443 |
+
'lang_code': 'pl-PL',
|
444 |
+
'whisper_lang_code': 'pl',
|
445 |
+
'voice_id': 'Ola',
|
446 |
+
'gender': 'Female',
|
447 |
+
'neural': 'Yes',
|
448 |
+
'standard': 'No'},
|
449 |
+
{'language': 'Portuguese (Brazilian)',
|
450 |
+
'lang_code': 'pt-BR',
|
451 |
+
'whisper_lang_code': 'pt',
|
452 |
+
'voice_id': 'Camila',
|
453 |
+
'gender': 'Female',
|
454 |
+
'neural': 'Yes',
|
455 |
+
'standard': 'Yes'},
|
456 |
+
{'language': 'Portuguese (Brazilian)',
|
457 |
+
'lang_code': 'pt-BR',
|
458 |
+
'whisper_lang_code': 'pt',
|
459 |
+
'voice_id': 'Vitoria',
|
460 |
+
'gender': 'Female',
|
461 |
+
'neural': 'Yes',
|
462 |
+
'standard': 'Yes'},
|
463 |
+
{'language': 'Portuguese (Brazilian)',
|
464 |
+
'lang_code': 'pt-BR',
|
465 |
+
'whisper_lang_code': 'pt',
|
466 |
+
'voice_id': 'Ricardo',
|
467 |
+
'gender': 'Male',
|
468 |
+
'neural': 'No',
|
469 |
+
'standard': 'Yes'},
|
470 |
+
{'language': 'Portuguese (European)',
|
471 |
+
'lang_code': 'pt-PT',
|
472 |
+
'whisper_lang_code': 'pt',
|
473 |
+
'voice_id': 'Ines',
|
474 |
+
'gender': 'Female',
|
475 |
+
'neural': 'Yes',
|
476 |
+
'standard': 'Yes'},
|
477 |
+
{'language': 'Portuguese (European)',
|
478 |
+
'lang_code': 'pt-PT',
|
479 |
+
'whisper_lang_code': 'pt',
|
480 |
+
'voice_id': 'Cristiano',
|
481 |
+
'gender': 'Male',
|
482 |
+
'neural': 'No',
|
483 |
+
'standard': 'Yes'},
|
484 |
+
{'language': 'Romanian',
|
485 |
+
'lang_code': 'ro-RO',
|
486 |
+
'whisper_lang_code': 'ro',
|
487 |
+
'voice_id': 'Carmen',
|
488 |
+
'gender': 'Female',
|
489 |
+
'neural': 'No',
|
490 |
+
'standard': 'Yes'},
|
491 |
+
{'language': 'Russian',
|
492 |
+
'lang_code': 'ru-RU',
|
493 |
+
'whisper_lang_code': 'ru',
|
494 |
+
'voice_id': 'Tatyana',
|
495 |
+
'gender': 'Female',
|
496 |
+
'neural': 'No',
|
497 |
+
'standard': 'Yes'},
|
498 |
+
{'language': 'Russian',
|
499 |
+
'lang_code': 'ru-RU',
|
500 |
+
'whisper_lang_code': 'ru',
|
501 |
+
'voice_id': 'Maxim',
|
502 |
+
'gender': 'Male',
|
503 |
+
'neural': 'No',
|
504 |
+
'standard': 'Yes'},
|
505 |
+
{'language': 'Spanish (European)',
|
506 |
+
'lang_code': 'es-ES',
|
507 |
+
'whisper_lang_code': 'es',
|
508 |
+
'voice_id': 'Conchita',
|
509 |
+
'gender': 'Female',
|
510 |
+
'neural': 'No',
|
511 |
+
'standard': 'Yes'},
|
512 |
+
{'language': 'Spanish (European)',
|
513 |
+
'lang_code': 'es-ES',
|
514 |
+
'whisper_lang_code': 'es',
|
515 |
+
'voice_id': 'Lucia',
|
516 |
+
'gender': 'Female',
|
517 |
+
'neural': 'Yes',
|
518 |
+
'standard': 'Yes'},
|
519 |
+
{'language': 'Spanish (European)',
|
520 |
+
'lang_code': 'es-ES',
|
521 |
+
'whisper_lang_code': 'es',
|
522 |
+
'voice_id': 'Enrique',
|
523 |
+
'gender': 'Male',
|
524 |
+
'neural': 'No',
|
525 |
+
'standard': 'Yes'},
|
526 |
+
{'language': 'Spanish (Mexican)',
|
527 |
+
'lang_code': 'es-MX',
|
528 |
+
'whisper_lang_code': 'es',
|
529 |
+
'voice_id': 'Mia',
|
530 |
+
'gender': 'Female',
|
531 |
+
'neural': 'Yes',
|
532 |
+
'standard': 'Yes'},
|
533 |
+
{'language': 'Spanish (US)',
|
534 |
+
'lang_code': 'es-US',
|
535 |
+
'whisper_lang_code': 'es',
|
536 |
+
'voice_id': 'Lupe',
|
537 |
+
'gender': 'Female',
|
538 |
+
'neural': 'Yes',
|
539 |
+
'standard': 'Yes'},
|
540 |
+
{'language': 'Spanish (US)',
|
541 |
+
'lang_code': 'es-US',
|
542 |
+
'whisper_lang_code': 'es',
|
543 |
+
'voice_id': 'Penelope',
|
544 |
+
'gender': 'Female',
|
545 |
+
'neural': 'No',
|
546 |
+
'standard': 'Yes'},
|
547 |
+
{'language': 'Spanish (US)',
|
548 |
+
'lang_code': 'es-US',
|
549 |
+
'whisper_lang_code': 'es',
|
550 |
+
'voice_id': 'Miguel',
|
551 |
+
'gender': 'Male',
|
552 |
+
'neural': 'No',
|
553 |
+
'standard': 'Yes'},
|
554 |
+
{'language': 'Spanish (US)',
|
555 |
+
'lang_code': 'es-US',
|
556 |
+
'whisper_lang_code': 'es',
|
557 |
+
'voice_id': 'Pedro',
|
558 |
+
'gender': 'Male',
|
559 |
+
'neural': 'Yes',
|
560 |
+
'standard': 'No'},
|
561 |
+
{'language': 'Swedish',
|
562 |
+
'lang_code': 'sv-SE',
|
563 |
+
'whisper_lang_code': 'sv',
|
564 |
+
'voice_id': 'Astrid',
|
565 |
+
'gender': 'Female',
|
566 |
+
'neural': 'No',
|
567 |
+
'standard': 'Yes'},
|
568 |
+
{'language': 'Swedish',
|
569 |
+
'lang_code': 'sv-SE',
|
570 |
+
'whisper_lang_code': 'sv',
|
571 |
+
'voice_id': 'Elin',
|
572 |
+
'gender': 'Female',
|
573 |
+
'neural': 'Yes',
|
574 |
+
'standard': 'No'},
|
575 |
+
{'language': 'Turkish',
|
576 |
+
'lang_code': 'tr-TR',
|
577 |
+
'whisper_lang_code': 'tr',
|
578 |
+
'voice_id': 'Filiz',
|
579 |
+
'gender': 'Female',
|
580 |
+
'neural': 'No',
|
581 |
+
'standard': 'Yes'},
|
582 |
+
{'language': 'Welsh',
|
583 |
+
'lang_code': 'cy-GB',
|
584 |
+
'whisper_lang_code': 'cy',
|
585 |
+
'voice_id': 'Gwyneth',
|
586 |
+
'gender': 'Female',
|
587 |
+
'neural': 'No',
|
588 |
+
'standard': 'Yes'}
|
589 |
+
]
|
590 |
+
|
591 |
+
|
592 |
+
# Run from the command-line
|
593 |
+
if __name__ == '__main__':
|
594 |
+
polly_voice_data = PollyVoiceData()
|
595 |
+
|
596 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Male')
|
597 |
+
print('English (US)', 'Male', voice_id, language_code, engine)
|
598 |
+
|
599 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Female')
|
600 |
+
print('English (US)', 'Female', voice_id, language_code, engine)
|
601 |
+
|
602 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Female')
|
603 |
+
print('French', 'Female', voice_id, language_code, engine)
|
604 |
+
|
605 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Male')
|
606 |
+
print('French', 'Male', voice_id, language_code, engine)
|
607 |
+
|
608 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Female')
|
609 |
+
print('Japanese', 'Female', voice_id, language_code, engine)
|
610 |
+
|
611 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Male')
|
612 |
+
print('Japanese', 'Male', voice_id, language_code, engine)
|
613 |
+
|
614 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Female')
|
615 |
+
print('Hindi', 'Female', voice_id, language_code, engine)
|
616 |
+
|
617 |
+
voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Male')
|
618 |
+
print('Hindi', 'Male', voice_id, language_code, engine)
|
619 |
+
|
620 |
+
whisper_lang_code = polly_voice_data.get_whisper_lang_code('English (US)')
|
621 |
+
print('English (US) whisper_lang_code:', whisper_lang_code)
|
622 |
+
|
623 |
+
whisper_lang_code = polly_voice_data.get_whisper_lang_code('Chinese (Mandarin)')
|
624 |
+
print('Chinese (Mandarin) whisper_lang_code:', whisper_lang_code)
|
625 |
+
|
626 |
+
whisper_lang_code = polly_voice_data.get_whisper_lang_code('Norwegian')
|
627 |
+
print('Norwegian whisper_lang_code:', whisper_lang_code)
|
628 |
+
|
629 |
+
whisper_lang_code = polly_voice_data.get_whisper_lang_code('Dutch')
|
630 |
+
print('Dutch whisper_lang_code:', whisper_lang_code)
|
631 |
+
|
632 |
+
whisper_lang_code = polly_voice_data.get_whisper_lang_code('Foo')
|
633 |
+
print('Foo whisper_lang_code:', whisper_lang_code)
|
634 |
+
|
635 |
+
|
chat_anything/sad_talker/__init__.py
ADDED
File without changes
|
chat_anything/sad_talker/audio2exp_models/audio2exp.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class Audio2Exp(nn.Module):
|
7 |
+
def __init__(self, netG, cfg, device, prepare_training_loss=False):
|
8 |
+
super(Audio2Exp, self).__init__()
|
9 |
+
self.cfg = cfg
|
10 |
+
self.device = device
|
11 |
+
self.netG = netG.to(device)
|
12 |
+
|
13 |
+
def test(self, batch):
|
14 |
+
|
15 |
+
mel_input = batch['indiv_mels'] # bs T 1 80 16
|
16 |
+
bs = mel_input.shape[0]
|
17 |
+
T = mel_input.shape[1]
|
18 |
+
|
19 |
+
exp_coeff_pred = []
|
20 |
+
|
21 |
+
for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
|
22 |
+
|
23 |
+
current_mel_input = mel_input[:,i:i+10]
|
24 |
+
|
25 |
+
#ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
|
26 |
+
ref = batch['ref'][:, :, :64][:, i:i+10]
|
27 |
+
ratio = batch['ratio_gt'][:, i:i+10] #bs T
|
28 |
+
|
29 |
+
audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
|
30 |
+
|
31 |
+
curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
|
32 |
+
|
33 |
+
exp_coeff_pred += [curr_exp_coeff_pred]
|
34 |
+
|
35 |
+
# BS x T x 64
|
36 |
+
results_dict = {
|
37 |
+
'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
|
38 |
+
}
|
39 |
+
return results_dict
|
40 |
+
|
41 |
+
|
chat_anything/sad_talker/audio2exp_models/networks.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class Conv2d(nn.Module):
|
6 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.conv_block = nn.Sequential(
|
9 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
10 |
+
nn.BatchNorm2d(cout)
|
11 |
+
)
|
12 |
+
self.act = nn.ReLU()
|
13 |
+
self.residual = residual
|
14 |
+
self.use_act = use_act
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
out = self.conv_block(x)
|
18 |
+
if self.residual:
|
19 |
+
out += x
|
20 |
+
|
21 |
+
if self.use_act:
|
22 |
+
return self.act(out)
|
23 |
+
else:
|
24 |
+
return out
|
25 |
+
|
26 |
+
class SimpleWrapperV2(nn.Module):
|
27 |
+
def __init__(self) -> None:
|
28 |
+
super().__init__()
|
29 |
+
self.audio_encoder = nn.Sequential(
|
30 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
31 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
32 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
33 |
+
|
34 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
35 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
36 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
37 |
+
|
38 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
39 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
40 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
41 |
+
|
42 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
43 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
44 |
+
|
45 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
46 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
47 |
+
)
|
48 |
+
|
49 |
+
#### load the pre-trained audio_encoder
|
50 |
+
#self.audio_encoder = self.audio_encoder.to(device)
|
51 |
+
'''
|
52 |
+
wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
|
53 |
+
state_dict = self.audio_encoder.state_dict()
|
54 |
+
|
55 |
+
for k,v in wav2lip_state_dict.items():
|
56 |
+
if 'audio_encoder' in k:
|
57 |
+
print('init:', k)
|
58 |
+
state_dict[k.replace('module.audio_encoder.', '')] = v
|
59 |
+
self.audio_encoder.load_state_dict(state_dict)
|
60 |
+
'''
|
61 |
+
|
62 |
+
self.mapping1 = nn.Linear(512+64+1, 64)
|
63 |
+
#self.mapping2 = nn.Linear(30, 64)
|
64 |
+
#nn.init.constant_(self.mapping1.weight, 0.)
|
65 |
+
nn.init.constant_(self.mapping1.bias, 0.)
|
66 |
+
|
67 |
+
def forward(self, x, ref, ratio):
|
68 |
+
x = self.audio_encoder(x).view(x.size(0), -1)
|
69 |
+
ref_reshape = ref.reshape(x.size(0), -1)
|
70 |
+
ratio = ratio.reshape(x.size(0), -1)
|
71 |
+
|
72 |
+
y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
|
73 |
+
out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
|
74 |
+
return out
|
chat_anything/sad_talker/audio2pose_models/audio2pose.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from chat_anything.sad_talker.audio2pose_models.cvae import CVAE
|
4 |
+
from chat_anything.sad_talker.audio2pose_models.discriminator import PoseSequenceDiscriminator
|
5 |
+
from chat_anything.sad_talker.audio2pose_models.audio_encoder import AudioEncoder
|
6 |
+
|
7 |
+
class Audio2Pose(nn.Module):
|
8 |
+
def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
|
9 |
+
super().__init__()
|
10 |
+
self.cfg = cfg
|
11 |
+
self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
|
12 |
+
self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
|
13 |
+
self.device = device
|
14 |
+
|
15 |
+
self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
|
16 |
+
self.audio_encoder.eval()
|
17 |
+
for param in self.audio_encoder.parameters():
|
18 |
+
param.requires_grad = False
|
19 |
+
|
20 |
+
self.netG = CVAE(cfg)
|
21 |
+
self.netD_motion = PoseSequenceDiscriminator(cfg)
|
22 |
+
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
|
26 |
+
batch = {}
|
27 |
+
coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
|
28 |
+
batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
|
29 |
+
batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
|
30 |
+
batch['class'] = x['class'].squeeze(0).cuda() # bs
|
31 |
+
indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
|
32 |
+
|
33 |
+
# forward
|
34 |
+
audio_emb_list = []
|
35 |
+
audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
|
36 |
+
batch['audio_emb'] = audio_emb
|
37 |
+
batch = self.netG(batch)
|
38 |
+
|
39 |
+
pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
|
40 |
+
pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
|
41 |
+
pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
|
42 |
+
|
43 |
+
batch['pose_pred'] = pose_pred
|
44 |
+
batch['pose_gt'] = pose_gt
|
45 |
+
|
46 |
+
return batch
|
47 |
+
|
48 |
+
def test(self, x):
|
49 |
+
|
50 |
+
batch = {}
|
51 |
+
ref = x['ref'] #bs 1 70
|
52 |
+
batch['ref'] = x['ref'][:,0,-6:]
|
53 |
+
batch['class'] = x['class']
|
54 |
+
bs = ref.shape[0]
|
55 |
+
|
56 |
+
indiv_mels= x['indiv_mels'] # bs T 1 80 16
|
57 |
+
indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
|
58 |
+
num_frames = x['num_frames']
|
59 |
+
num_frames = int(num_frames) - 1
|
60 |
+
|
61 |
+
#
|
62 |
+
div = num_frames//self.seq_len
|
63 |
+
re = num_frames%self.seq_len
|
64 |
+
audio_emb_list = []
|
65 |
+
pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
|
66 |
+
device=batch['ref'].device)]
|
67 |
+
|
68 |
+
for i in range(div):
|
69 |
+
z = torch.randn(bs, self.latent_dim).to(ref.device)
|
70 |
+
batch['z'] = z
|
71 |
+
audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
|
72 |
+
batch['audio_emb'] = audio_emb
|
73 |
+
batch = self.netG.test(batch)
|
74 |
+
pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
|
75 |
+
|
76 |
+
if re != 0:
|
77 |
+
z = torch.randn(bs, self.latent_dim).to(ref.device)
|
78 |
+
batch['z'] = z
|
79 |
+
audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
|
80 |
+
if audio_emb.shape[1] != self.seq_len:
|
81 |
+
pad_dim = self.seq_len-audio_emb.shape[1]
|
82 |
+
pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
|
83 |
+
audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
|
84 |
+
batch['audio_emb'] = audio_emb
|
85 |
+
batch = self.netG.test(batch)
|
86 |
+
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
|
87 |
+
|
88 |
+
pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
|
89 |
+
batch['pose_motion_pred'] = pose_motion_pred
|
90 |
+
|
91 |
+
pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
|
92 |
+
|
93 |
+
batch['pose_pred'] = pose_pred
|
94 |
+
return batch
|
chat_anything/sad_talker/audio2pose_models/audio_encoder.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class Conv2d(nn.Module):
|
6 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.conv_block = nn.Sequential(
|
9 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
10 |
+
nn.BatchNorm2d(cout)
|
11 |
+
)
|
12 |
+
self.act = nn.ReLU()
|
13 |
+
self.residual = residual
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
out = self.conv_block(x)
|
17 |
+
if self.residual:
|
18 |
+
out += x
|
19 |
+
return self.act(out)
|
20 |
+
|
21 |
+
class AudioEncoder(nn.Module):
|
22 |
+
def __init__(self, wav2lip_checkpoint, device):
|
23 |
+
super(AudioEncoder, self).__init__()
|
24 |
+
|
25 |
+
self.audio_encoder = nn.Sequential(
|
26 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
27 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
28 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
29 |
+
|
30 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
31 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
32 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
33 |
+
|
34 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
35 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
36 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
37 |
+
|
38 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
39 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
40 |
+
|
41 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
42 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
43 |
+
|
44 |
+
#### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
|
45 |
+
# wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
|
46 |
+
# state_dict = self.audio_encoder.state_dict()
|
47 |
+
|
48 |
+
# for k,v in wav2lip_state_dict.items():
|
49 |
+
# if 'audio_encoder' in k:
|
50 |
+
# state_dict[k.replace('module.audio_encoder.', '')] = v
|
51 |
+
# self.audio_encoder.load_state_dict(state_dict)
|
52 |
+
|
53 |
+
|
54 |
+
def forward(self, audio_sequences):
|
55 |
+
# audio_sequences = (B, T, 1, 80, 16)
|
56 |
+
B = audio_sequences.size(0)
|
57 |
+
|
58 |
+
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
|
59 |
+
|
60 |
+
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
|
61 |
+
dim = audio_embedding.shape[1]
|
62 |
+
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
|
63 |
+
|
64 |
+
return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
|
chat_anything/sad_talker/audio2pose_models/cvae.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from chat_anything.sad_talker.audio2pose_models.res_unet import ResUnet
|
5 |
+
|
6 |
+
def class2onehot(idx, class_num):
|
7 |
+
|
8 |
+
assert torch.max(idx).item() < class_num
|
9 |
+
onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
|
10 |
+
onehot.scatter_(1, idx, 1)
|
11 |
+
return onehot
|
12 |
+
|
13 |
+
class CVAE(nn.Module):
|
14 |
+
def __init__(self, cfg):
|
15 |
+
super().__init__()
|
16 |
+
encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
|
17 |
+
decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
|
18 |
+
latent_size = cfg.MODEL.CVAE.LATENT_SIZE
|
19 |
+
num_classes = cfg.DATASET.NUM_CLASSES
|
20 |
+
audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
|
21 |
+
audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
|
22 |
+
seq_len = cfg.MODEL.CVAE.SEQ_LEN
|
23 |
+
|
24 |
+
self.latent_size = latent_size
|
25 |
+
|
26 |
+
self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
|
27 |
+
audio_emb_in_size, audio_emb_out_size, seq_len)
|
28 |
+
self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
|
29 |
+
audio_emb_in_size, audio_emb_out_size, seq_len)
|
30 |
+
def reparameterize(self, mu, logvar):
|
31 |
+
std = torch.exp(0.5 * logvar)
|
32 |
+
eps = torch.randn_like(std)
|
33 |
+
return mu + eps * std
|
34 |
+
|
35 |
+
def forward(self, batch):
|
36 |
+
batch = self.encoder(batch)
|
37 |
+
mu = batch['mu']
|
38 |
+
logvar = batch['logvar']
|
39 |
+
z = self.reparameterize(mu, logvar)
|
40 |
+
batch['z'] = z
|
41 |
+
return self.decoder(batch)
|
42 |
+
|
43 |
+
def test(self, batch):
|
44 |
+
'''
|
45 |
+
class_id = batch['class']
|
46 |
+
z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
|
47 |
+
batch['z'] = z
|
48 |
+
'''
|
49 |
+
return self.decoder(batch)
|
50 |
+
|
51 |
+
class ENCODER(nn.Module):
|
52 |
+
def __init__(self, layer_sizes, latent_size, num_classes,
|
53 |
+
audio_emb_in_size, audio_emb_out_size, seq_len):
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
self.resunet = ResUnet()
|
57 |
+
self.num_classes = num_classes
|
58 |
+
self.seq_len = seq_len
|
59 |
+
|
60 |
+
self.MLP = nn.Sequential()
|
61 |
+
layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
|
62 |
+
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
|
63 |
+
self.MLP.add_module(
|
64 |
+
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
|
65 |
+
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
|
66 |
+
|
67 |
+
self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
|
68 |
+
self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
|
69 |
+
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
|
70 |
+
|
71 |
+
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
|
72 |
+
|
73 |
+
def forward(self, batch):
|
74 |
+
class_id = batch['class']
|
75 |
+
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
|
76 |
+
ref = batch['ref'] #bs 6
|
77 |
+
bs = pose_motion_gt.shape[0]
|
78 |
+
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
|
79 |
+
|
80 |
+
#pose encode
|
81 |
+
pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
|
82 |
+
pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
|
83 |
+
|
84 |
+
#audio mapping
|
85 |
+
print(audio_in.shape)
|
86 |
+
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
|
87 |
+
audio_out = audio_out.reshape(bs, -1)
|
88 |
+
|
89 |
+
class_bias = self.classbias[class_id] #bs latent_size
|
90 |
+
x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
|
91 |
+
x_out = self.MLP(x_in)
|
92 |
+
|
93 |
+
mu = self.linear_means(x_out)
|
94 |
+
logvar = self.linear_means(x_out) #bs latent_size
|
95 |
+
|
96 |
+
batch.update({'mu':mu, 'logvar':logvar})
|
97 |
+
return batch
|
98 |
+
|
99 |
+
class DECODER(nn.Module):
|
100 |
+
def __init__(self, layer_sizes, latent_size, num_classes,
|
101 |
+
audio_emb_in_size, audio_emb_out_size, seq_len):
|
102 |
+
super().__init__()
|
103 |
+
|
104 |
+
self.resunet = ResUnet()
|
105 |
+
self.num_classes = num_classes
|
106 |
+
self.seq_len = seq_len
|
107 |
+
|
108 |
+
self.MLP = nn.Sequential()
|
109 |
+
input_size = latent_size + seq_len*audio_emb_out_size + 6
|
110 |
+
for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
|
111 |
+
self.MLP.add_module(
|
112 |
+
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
|
113 |
+
if i+1 < len(layer_sizes):
|
114 |
+
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
|
115 |
+
else:
|
116 |
+
self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
|
117 |
+
|
118 |
+
self.pose_linear = nn.Linear(6, 6)
|
119 |
+
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
|
120 |
+
|
121 |
+
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
|
122 |
+
|
123 |
+
def forward(self, batch):
|
124 |
+
|
125 |
+
z = batch['z'] #bs latent_size
|
126 |
+
bs = z.shape[0]
|
127 |
+
class_id = batch['class']
|
128 |
+
ref = batch['ref'] #bs 6
|
129 |
+
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
|
130 |
+
#print('audio_in: ', audio_in[:, :, :10])
|
131 |
+
|
132 |
+
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
|
133 |
+
#print('audio_out: ', audio_out[:, :, :10])
|
134 |
+
audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
|
135 |
+
class_bias = self.classbias[class_id] #bs latent_size
|
136 |
+
|
137 |
+
z = z + class_bias
|
138 |
+
x_in = torch.cat([ref, z, audio_out], dim=-1)
|
139 |
+
x_out = self.MLP(x_in) # bs layer_sizes[-1]
|
140 |
+
x_out = x_out.reshape((bs, self.seq_len, -1))
|
141 |
+
|
142 |
+
#print('x_out: ', x_out)
|
143 |
+
|
144 |
+
pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
|
145 |
+
|
146 |
+
pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
|
147 |
+
|
148 |
+
batch.update({'pose_motion_pred':pose_motion_pred})
|
149 |
+
return batch
|
chat_anything/sad_talker/audio2pose_models/discriminator.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class ConvNormRelu(nn.Module):
|
6 |
+
def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
|
7 |
+
kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
|
8 |
+
super().__init__()
|
9 |
+
if kernel_size is None:
|
10 |
+
if downsample:
|
11 |
+
kernel_size, stride, padding = 4, 2, 1
|
12 |
+
else:
|
13 |
+
kernel_size, stride, padding = 3, 1, 1
|
14 |
+
|
15 |
+
if conv_type == '2d':
|
16 |
+
self.conv = nn.Conv2d(
|
17 |
+
in_channels,
|
18 |
+
out_channels,
|
19 |
+
kernel_size,
|
20 |
+
stride,
|
21 |
+
padding,
|
22 |
+
bias=False,
|
23 |
+
)
|
24 |
+
if norm == 'BN':
|
25 |
+
self.norm = nn.BatchNorm2d(out_channels)
|
26 |
+
elif norm == 'IN':
|
27 |
+
self.norm = nn.InstanceNorm2d(out_channels)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
elif conv_type == '1d':
|
31 |
+
self.conv = nn.Conv1d(
|
32 |
+
in_channels,
|
33 |
+
out_channels,
|
34 |
+
kernel_size,
|
35 |
+
stride,
|
36 |
+
padding,
|
37 |
+
bias=False,
|
38 |
+
)
|
39 |
+
if norm == 'BN':
|
40 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
41 |
+
elif norm == 'IN':
|
42 |
+
self.norm = nn.InstanceNorm1d(out_channels)
|
43 |
+
else:
|
44 |
+
raise NotImplementedError
|
45 |
+
nn.init.kaiming_normal_(self.conv.weight)
|
46 |
+
|
47 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = self.conv(x)
|
51 |
+
if isinstance(self.norm, nn.InstanceNorm1d):
|
52 |
+
x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
|
53 |
+
else:
|
54 |
+
x = self.norm(x)
|
55 |
+
x = self.act(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class PoseSequenceDiscriminator(nn.Module):
|
60 |
+
def __init__(self, cfg):
|
61 |
+
super().__init__()
|
62 |
+
self.cfg = cfg
|
63 |
+
leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
|
64 |
+
|
65 |
+
self.seq = nn.Sequential(
|
66 |
+
ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
|
67 |
+
ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
|
68 |
+
ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
|
69 |
+
nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
|
70 |
+
)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
|
74 |
+
x = self.seq(x)
|
75 |
+
x = x.squeeze(1)
|
76 |
+
return x
|
chat_anything/sad_talker/audio2pose_models/networks.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ResidualConv(nn.Module):
|
6 |
+
def __init__(self, input_dim, output_dim, stride, padding):
|
7 |
+
super(ResidualConv, self).__init__()
|
8 |
+
|
9 |
+
self.conv_block = nn.Sequential(
|
10 |
+
nn.BatchNorm2d(input_dim),
|
11 |
+
nn.ReLU(),
|
12 |
+
nn.Conv2d(
|
13 |
+
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
|
14 |
+
),
|
15 |
+
nn.BatchNorm2d(output_dim),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
|
18 |
+
)
|
19 |
+
self.conv_skip = nn.Sequential(
|
20 |
+
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
|
21 |
+
nn.BatchNorm2d(output_dim),
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
|
26 |
+
return self.conv_block(x) + self.conv_skip(x)
|
27 |
+
|
28 |
+
|
29 |
+
class Upsample(nn.Module):
|
30 |
+
def __init__(self, input_dim, output_dim, kernel, stride):
|
31 |
+
super(Upsample, self).__init__()
|
32 |
+
|
33 |
+
self.upsample = nn.ConvTranspose2d(
|
34 |
+
input_dim, output_dim, kernel_size=kernel, stride=stride
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.upsample(x)
|
39 |
+
|
40 |
+
|
41 |
+
class Squeeze_Excite_Block(nn.Module):
|
42 |
+
def __init__(self, channel, reduction=16):
|
43 |
+
super(Squeeze_Excite_Block, self).__init__()
|
44 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
45 |
+
self.fc = nn.Sequential(
|
46 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
47 |
+
nn.ReLU(inplace=True),
|
48 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
49 |
+
nn.Sigmoid(),
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
b, c, _, _ = x.size()
|
54 |
+
y = self.avg_pool(x).view(b, c)
|
55 |
+
y = self.fc(y).view(b, c, 1, 1)
|
56 |
+
return x * y.expand_as(x)
|
57 |
+
|
58 |
+
|
59 |
+
class ASPP(nn.Module):
|
60 |
+
def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
|
61 |
+
super(ASPP, self).__init__()
|
62 |
+
|
63 |
+
self.aspp_block1 = nn.Sequential(
|
64 |
+
nn.Conv2d(
|
65 |
+
in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
|
66 |
+
),
|
67 |
+
nn.ReLU(inplace=True),
|
68 |
+
nn.BatchNorm2d(out_dims),
|
69 |
+
)
|
70 |
+
self.aspp_block2 = nn.Sequential(
|
71 |
+
nn.Conv2d(
|
72 |
+
in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
|
73 |
+
),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
nn.BatchNorm2d(out_dims),
|
76 |
+
)
|
77 |
+
self.aspp_block3 = nn.Sequential(
|
78 |
+
nn.Conv2d(
|
79 |
+
in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
|
80 |
+
),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
nn.BatchNorm2d(out_dims),
|
83 |
+
)
|
84 |
+
|
85 |
+
self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
|
86 |
+
self._init_weights()
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
x1 = self.aspp_block1(x)
|
90 |
+
x2 = self.aspp_block2(x)
|
91 |
+
x3 = self.aspp_block3(x)
|
92 |
+
out = torch.cat([x1, x2, x3], dim=1)
|
93 |
+
return self.output(out)
|
94 |
+
|
95 |
+
def _init_weights(self):
|
96 |
+
for m in self.modules():
|
97 |
+
if isinstance(m, nn.Conv2d):
|
98 |
+
nn.init.kaiming_normal_(m.weight)
|
99 |
+
elif isinstance(m, nn.BatchNorm2d):
|
100 |
+
m.weight.data.fill_(1)
|
101 |
+
m.bias.data.zero_()
|
102 |
+
|
103 |
+
|
104 |
+
class Upsample_(nn.Module):
|
105 |
+
def __init__(self, scale=2):
|
106 |
+
super(Upsample_, self).__init__()
|
107 |
+
|
108 |
+
self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
return self.upsample(x)
|
112 |
+
|
113 |
+
|
114 |
+
class AttentionBlock(nn.Module):
|
115 |
+
def __init__(self, input_encoder, input_decoder, output_dim):
|
116 |
+
super(AttentionBlock, self).__init__()
|
117 |
+
|
118 |
+
self.conv_encoder = nn.Sequential(
|
119 |
+
nn.BatchNorm2d(input_encoder),
|
120 |
+
nn.ReLU(),
|
121 |
+
nn.Conv2d(input_encoder, output_dim, 3, padding=1),
|
122 |
+
nn.MaxPool2d(2, 2),
|
123 |
+
)
|
124 |
+
|
125 |
+
self.conv_decoder = nn.Sequential(
|
126 |
+
nn.BatchNorm2d(input_decoder),
|
127 |
+
nn.ReLU(),
|
128 |
+
nn.Conv2d(input_decoder, output_dim, 3, padding=1),
|
129 |
+
)
|
130 |
+
|
131 |
+
self.conv_attn = nn.Sequential(
|
132 |
+
nn.BatchNorm2d(output_dim),
|
133 |
+
nn.ReLU(),
|
134 |
+
nn.Conv2d(output_dim, 1, 1),
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, x1, x2):
|
138 |
+
out = self.conv_encoder(x1) + self.conv_decoder(x2)
|
139 |
+
out = self.conv_attn(out)
|
140 |
+
return out * x2
|
chat_anything/sad_talker/audio2pose_models/res_unet.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from chat_anything.sad_talker.audio2pose_models.networks import ResidualConv, Upsample
|
4 |
+
|
5 |
+
|
6 |
+
class ResUnet(nn.Module):
|
7 |
+
def __init__(self, channel=1, filters=[32, 64, 128, 256]):
|
8 |
+
super(ResUnet, self).__init__()
|
9 |
+
|
10 |
+
self.input_layer = nn.Sequential(
|
11 |
+
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
|
12 |
+
nn.BatchNorm2d(filters[0]),
|
13 |
+
nn.ReLU(),
|
14 |
+
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
|
15 |
+
)
|
16 |
+
self.input_skip = nn.Sequential(
|
17 |
+
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
|
18 |
+
)
|
19 |
+
|
20 |
+
self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
|
21 |
+
self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
|
22 |
+
|
23 |
+
self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
|
24 |
+
|
25 |
+
self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
|
26 |
+
self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
|
27 |
+
|
28 |
+
self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
|
29 |
+
self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
|
30 |
+
|
31 |
+
self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
|
32 |
+
self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
|
33 |
+
|
34 |
+
self.output_layer = nn.Sequential(
|
35 |
+
nn.Conv2d(filters[0], 1, 1, 1),
|
36 |
+
nn.Sigmoid(),
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
# Encode
|
41 |
+
x1 = self.input_layer(x) + self.input_skip(x)
|
42 |
+
x2 = self.residual_conv_1(x1)
|
43 |
+
x3 = self.residual_conv_2(x2)
|
44 |
+
# Bridge
|
45 |
+
x4 = self.bridge(x3)
|
46 |
+
|
47 |
+
# Decode
|
48 |
+
x4 = self.upsample_1(x4)
|
49 |
+
x5 = torch.cat([x4, x3], dim=1)
|
50 |
+
|
51 |
+
x6 = self.up_residual_conv1(x5)
|
52 |
+
|
53 |
+
x6 = self.upsample_2(x6)
|
54 |
+
x7 = torch.cat([x6, x2], dim=1)
|
55 |
+
|
56 |
+
x8 = self.up_residual_conv2(x7)
|
57 |
+
|
58 |
+
x8 = self.upsample_3(x8)
|
59 |
+
x9 = torch.cat([x8, x1], dim=1)
|
60 |
+
|
61 |
+
x10 = self.up_residual_conv3(x9)
|
62 |
+
|
63 |
+
output = self.output_layer(x10)
|
64 |
+
|
65 |
+
return output
|
chat_anything/sad_talker/config/auido2exp.yaml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET:
|
2 |
+
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
|
3 |
+
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
|
4 |
+
TRAIN_BATCH_SIZE: 32
|
5 |
+
EVAL_BATCH_SIZE: 32
|
6 |
+
EXP: True
|
7 |
+
EXP_DIM: 64
|
8 |
+
FRAME_LEN: 32
|
9 |
+
COEFF_LEN: 73
|
10 |
+
NUM_CLASSES: 46
|
11 |
+
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
12 |
+
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
|
13 |
+
LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
14 |
+
DEBUG: True
|
15 |
+
NUM_REPEATS: 2
|
16 |
+
T: 40
|
17 |
+
|
18 |
+
|
19 |
+
MODEL:
|
20 |
+
FRAMEWORK: V2
|
21 |
+
AUDIOENCODER:
|
22 |
+
LEAKY_RELU: True
|
23 |
+
NORM: 'IN'
|
24 |
+
DISCRIMINATOR:
|
25 |
+
LEAKY_RELU: False
|
26 |
+
INPUT_CHANNELS: 6
|
27 |
+
CVAE:
|
28 |
+
AUDIO_EMB_IN_SIZE: 512
|
29 |
+
AUDIO_EMB_OUT_SIZE: 128
|
30 |
+
SEQ_LEN: 32
|
31 |
+
LATENT_SIZE: 256
|
32 |
+
ENCODER_LAYER_SIZES: [192, 1024]
|
33 |
+
DECODER_LAYER_SIZES: [1024, 192]
|
34 |
+
|
35 |
+
|
36 |
+
TRAIN:
|
37 |
+
MAX_EPOCH: 300
|
38 |
+
GENERATOR:
|
39 |
+
LR: 2.0e-5
|
40 |
+
DISCRIMINATOR:
|
41 |
+
LR: 1.0e-5
|
42 |
+
LOSS:
|
43 |
+
W_FEAT: 0
|
44 |
+
W_COEFF_EXP: 2
|
45 |
+
W_LM: 1.0e-2
|
46 |
+
W_LM_MOUTH: 0
|
47 |
+
W_REG: 0
|
48 |
+
W_SYNC: 0
|
49 |
+
W_COLOR: 0
|
50 |
+
W_EXPRESSION: 0
|
51 |
+
W_LIPREADING: 0.01
|
52 |
+
W_LIPREADING_VV: 0
|
53 |
+
W_EYE_BLINK: 4
|
54 |
+
|
55 |
+
TAG:
|
56 |
+
NAME: small_dataset
|
57 |
+
|
58 |
+
|
chat_anything/sad_talker/config/auido2pose.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DATASET:
|
2 |
+
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
|
3 |
+
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
|
4 |
+
TRAIN_BATCH_SIZE: 64
|
5 |
+
EVAL_BATCH_SIZE: 1
|
6 |
+
EXP: True
|
7 |
+
EXP_DIM: 64
|
8 |
+
FRAME_LEN: 32
|
9 |
+
COEFF_LEN: 73
|
10 |
+
NUM_CLASSES: 46
|
11 |
+
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
12 |
+
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
13 |
+
DEBUG: True
|
14 |
+
|
15 |
+
|
16 |
+
MODEL:
|
17 |
+
AUDIOENCODER:
|
18 |
+
LEAKY_RELU: True
|
19 |
+
NORM: 'IN'
|
20 |
+
DISCRIMINATOR:
|
21 |
+
LEAKY_RELU: False
|
22 |
+
INPUT_CHANNELS: 6
|
23 |
+
CVAE:
|
24 |
+
AUDIO_EMB_IN_SIZE: 512
|
25 |
+
AUDIO_EMB_OUT_SIZE: 6
|
26 |
+
SEQ_LEN: 32
|
27 |
+
LATENT_SIZE: 64
|
28 |
+
ENCODER_LAYER_SIZES: [192, 128]
|
29 |
+
DECODER_LAYER_SIZES: [128, 192]
|
30 |
+
|
31 |
+
|
32 |
+
TRAIN:
|
33 |
+
MAX_EPOCH: 150
|
34 |
+
GENERATOR:
|
35 |
+
LR: 1.0e-4
|
36 |
+
DISCRIMINATOR:
|
37 |
+
LR: 1.0e-4
|
38 |
+
LOSS:
|
39 |
+
LAMBDA_REG: 1
|
40 |
+
LAMBDA_LANDMARKS: 0
|
41 |
+
LAMBDA_VERTICES: 0
|
42 |
+
LAMBDA_GAN_MOTION: 0.7
|
43 |
+
LAMBDA_GAN_COEFF: 0
|
44 |
+
LAMBDA_KL: 1
|
45 |
+
|
46 |
+
TAG:
|
47 |
+
NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
|
48 |
+
|
49 |
+
|
chat_anything/sad_talker/config/facerender.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_params:
|
2 |
+
common_params:
|
3 |
+
num_kp: 15
|
4 |
+
image_channel: 3
|
5 |
+
feature_channel: 32
|
6 |
+
estimate_jacobian: False # True
|
7 |
+
kp_detector_params:
|
8 |
+
temperature: 0.1
|
9 |
+
block_expansion: 32
|
10 |
+
max_features: 1024
|
11 |
+
scale_factor: 0.25 # 0.25
|
12 |
+
num_blocks: 5
|
13 |
+
reshape_channel: 16384 # 16384 = 1024 * 16
|
14 |
+
reshape_depth: 16
|
15 |
+
he_estimator_params:
|
16 |
+
block_expansion: 64
|
17 |
+
max_features: 2048
|
18 |
+
num_bins: 66
|
19 |
+
generator_params:
|
20 |
+
block_expansion: 64
|
21 |
+
max_features: 512
|
22 |
+
num_down_blocks: 2
|
23 |
+
reshape_channel: 32
|
24 |
+
reshape_depth: 16 # 512 = 32 * 16
|
25 |
+
num_resblocks: 6
|
26 |
+
estimate_occlusion_map: True
|
27 |
+
dense_motion_params:
|
28 |
+
block_expansion: 32
|
29 |
+
max_features: 1024
|
30 |
+
num_blocks: 5
|
31 |
+
reshape_depth: 16
|
32 |
+
compress: 4
|
33 |
+
discriminator_params:
|
34 |
+
scales: [1]
|
35 |
+
block_expansion: 32
|
36 |
+
max_features: 512
|
37 |
+
num_blocks: 4
|
38 |
+
sn: True
|
39 |
+
mapping_params:
|
40 |
+
coeff_nc: 70
|
41 |
+
descriptor_nc: 1024
|
42 |
+
layer: 3
|
43 |
+
num_kp: 15
|
44 |
+
num_bins: 66
|
45 |
+
|
chat_anything/sad_talker/config/facerender_still.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_params:
|
2 |
+
common_params:
|
3 |
+
num_kp: 15
|
4 |
+
image_channel: 3
|
5 |
+
feature_channel: 32
|
6 |
+
estimate_jacobian: False # True
|
7 |
+
kp_detector_params:
|
8 |
+
temperature: 0.1
|
9 |
+
block_expansion: 32
|
10 |
+
max_features: 1024
|
11 |
+
scale_factor: 0.25 # 0.25
|
12 |
+
num_blocks: 5
|
13 |
+
reshape_channel: 16384 # 16384 = 1024 * 16
|
14 |
+
reshape_depth: 16
|
15 |
+
he_estimator_params:
|
16 |
+
block_expansion: 64
|
17 |
+
max_features: 2048
|
18 |
+
num_bins: 66
|
19 |
+
generator_params:
|
20 |
+
block_expansion: 64
|
21 |
+
max_features: 512
|
22 |
+
num_down_blocks: 2
|
23 |
+
reshape_channel: 32
|
24 |
+
reshape_depth: 16 # 512 = 32 * 16
|
25 |
+
num_resblocks: 6
|
26 |
+
estimate_occlusion_map: True
|
27 |
+
dense_motion_params:
|
28 |
+
block_expansion: 32
|
29 |
+
max_features: 1024
|
30 |
+
num_blocks: 5
|
31 |
+
reshape_depth: 16
|
32 |
+
compress: 4
|
33 |
+
discriminator_params:
|
34 |
+
scales: [1]
|
35 |
+
block_expansion: 32
|
36 |
+
max_features: 512
|
37 |
+
num_blocks: 4
|
38 |
+
sn: True
|
39 |
+
mapping_params:
|
40 |
+
coeff_nc: 73
|
41 |
+
descriptor_nc: 1024
|
42 |
+
layer: 3
|
43 |
+
num_kp: 15
|
44 |
+
num_bins: 66
|
45 |
+
|
chat_anything/sad_talker/config/similarity_Lm3D_all.mat
ADDED
Binary file (994 Bytes). View file
|
|
chat_anything/sad_talker/face3d/data/__init__.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package includes all the modules related to data loading and preprocessing
|
2 |
+
|
3 |
+
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
4 |
+
You need to implement four functions:
|
5 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
6 |
+
-- <__len__>: return the size of dataset.
|
7 |
+
-- <__getitem__>: get a data point from data loader.
|
8 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
9 |
+
|
10 |
+
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
11 |
+
See our template dataset class 'template_dataset.py' for more details.
|
12 |
+
"""
|
13 |
+
import numpy as np
|
14 |
+
import importlib
|
15 |
+
import torch.utils.data
|
16 |
+
from face3d.data.base_dataset import BaseDataset
|
17 |
+
|
18 |
+
|
19 |
+
def find_dataset_using_name(dataset_name):
|
20 |
+
"""Import the module "data/[dataset_name]_dataset.py".
|
21 |
+
|
22 |
+
In the file, the class called DatasetNameDataset() will
|
23 |
+
be instantiated. It has to be a subclass of BaseDataset,
|
24 |
+
and it is case-insensitive.
|
25 |
+
"""
|
26 |
+
dataset_filename = "data." + dataset_name + "_dataset"
|
27 |
+
datasetlib = importlib.import_module(dataset_filename)
|
28 |
+
|
29 |
+
dataset = None
|
30 |
+
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
31 |
+
for name, cls in datasetlib.__dict__.items():
|
32 |
+
if name.lower() == target_dataset_name.lower() \
|
33 |
+
and issubclass(cls, BaseDataset):
|
34 |
+
dataset = cls
|
35 |
+
|
36 |
+
if dataset is None:
|
37 |
+
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
|
38 |
+
|
39 |
+
return dataset
|
40 |
+
|
41 |
+
|
42 |
+
def get_option_setter(dataset_name):
|
43 |
+
"""Return the static method <modify_commandline_options> of the dataset class."""
|
44 |
+
dataset_class = find_dataset_using_name(dataset_name)
|
45 |
+
return dataset_class.modify_commandline_options
|
46 |
+
|
47 |
+
|
48 |
+
def create_dataset(opt, rank=0):
|
49 |
+
"""Create a dataset given the option.
|
50 |
+
|
51 |
+
This function wraps the class CustomDatasetDataLoader.
|
52 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
53 |
+
|
54 |
+
Example:
|
55 |
+
>>> from data import create_dataset
|
56 |
+
>>> dataset = create_dataset(opt)
|
57 |
+
"""
|
58 |
+
data_loader = CustomDatasetDataLoader(opt, rank=rank)
|
59 |
+
dataset = data_loader.load_data()
|
60 |
+
return dataset
|
61 |
+
|
62 |
+
class CustomDatasetDataLoader():
|
63 |
+
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
|
64 |
+
|
65 |
+
def __init__(self, opt, rank=0):
|
66 |
+
"""Initialize this class
|
67 |
+
|
68 |
+
Step 1: create a dataset instance given the name [dataset_mode]
|
69 |
+
Step 2: create a multi-threaded data loader.
|
70 |
+
"""
|
71 |
+
self.opt = opt
|
72 |
+
dataset_class = find_dataset_using_name(opt.dataset_mode)
|
73 |
+
self.dataset = dataset_class(opt)
|
74 |
+
self.sampler = None
|
75 |
+
print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
|
76 |
+
if opt.use_ddp and opt.isTrain:
|
77 |
+
world_size = opt.world_size
|
78 |
+
self.sampler = torch.utils.data.distributed.DistributedSampler(
|
79 |
+
self.dataset,
|
80 |
+
num_replicas=world_size,
|
81 |
+
rank=rank,
|
82 |
+
shuffle=not opt.serial_batches
|
83 |
+
)
|
84 |
+
self.dataloader = torch.utils.data.DataLoader(
|
85 |
+
self.dataset,
|
86 |
+
sampler=self.sampler,
|
87 |
+
num_workers=int(opt.num_threads / world_size),
|
88 |
+
batch_size=int(opt.batch_size / world_size),
|
89 |
+
drop_last=True)
|
90 |
+
else:
|
91 |
+
self.dataloader = torch.utils.data.DataLoader(
|
92 |
+
self.dataset,
|
93 |
+
batch_size=opt.batch_size,
|
94 |
+
shuffle=(not opt.serial_batches) and opt.isTrain,
|
95 |
+
num_workers=int(opt.num_threads),
|
96 |
+
drop_last=True
|
97 |
+
)
|
98 |
+
|
99 |
+
def set_epoch(self, epoch):
|
100 |
+
self.dataset.current_epoch = epoch
|
101 |
+
if self.sampler is not None:
|
102 |
+
self.sampler.set_epoch(epoch)
|
103 |
+
|
104 |
+
def load_data(self):
|
105 |
+
return self
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
"""Return the number of data in the dataset"""
|
109 |
+
return min(len(self.dataset), self.opt.max_dataset_size)
|
110 |
+
|
111 |
+
def __iter__(self):
|
112 |
+
"""Return a batch of data"""
|
113 |
+
for i, data in enumerate(self.dataloader):
|
114 |
+
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
115 |
+
break
|
116 |
+
yield data
|
chat_anything/sad_talker/face3d/data/base_dataset.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
2 |
+
|
3 |
+
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
4 |
+
"""
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import torch.utils.data as data
|
8 |
+
from PIL import Image
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
|
12 |
+
|
13 |
+
class BaseDataset(data.Dataset, ABC):
|
14 |
+
"""This class is an abstract base class (ABC) for datasets.
|
15 |
+
|
16 |
+
To create a subclass, you need to implement the following four functions:
|
17 |
+
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
18 |
+
-- <__len__>: return the size of dataset.
|
19 |
+
-- <__getitem__>: get a data point.
|
20 |
+
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
"""Initialize the class; save the options in the class
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
28 |
+
"""
|
29 |
+
self.opt = opt
|
30 |
+
# self.root = opt.dataroot
|
31 |
+
self.current_epoch = 0
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
def modify_commandline_options(parser, is_train):
|
35 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
36 |
+
|
37 |
+
Parameters:
|
38 |
+
parser -- original option parser
|
39 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
the modified parser.
|
43 |
+
"""
|
44 |
+
return parser
|
45 |
+
|
46 |
+
@abstractmethod
|
47 |
+
def __len__(self):
|
48 |
+
"""Return the total number of images in the dataset."""
|
49 |
+
return 0
|
50 |
+
|
51 |
+
@abstractmethod
|
52 |
+
def __getitem__(self, index):
|
53 |
+
"""Return a data point and its metadata information.
|
54 |
+
|
55 |
+
Parameters:
|
56 |
+
index - - a random integer for data indexing
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
|
60 |
+
"""
|
61 |
+
pass
|
62 |
+
|
63 |
+
|
64 |
+
def get_transform(grayscale=False):
|
65 |
+
transform_list = []
|
66 |
+
if grayscale:
|
67 |
+
transform_list.append(transforms.Grayscale(1))
|
68 |
+
transform_list += [transforms.ToTensor()]
|
69 |
+
return transforms.Compose(transform_list)
|
70 |
+
|
71 |
+
def get_affine_mat(opt, size):
|
72 |
+
shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
|
73 |
+
w, h = size
|
74 |
+
|
75 |
+
if 'shift' in opt.preprocess:
|
76 |
+
shift_pixs = int(opt.shift_pixs)
|
77 |
+
shift_x = random.randint(-shift_pixs, shift_pixs)
|
78 |
+
shift_y = random.randint(-shift_pixs, shift_pixs)
|
79 |
+
if 'scale' in opt.preprocess:
|
80 |
+
scale = 1 + opt.scale_delta * (2 * random.random() - 1)
|
81 |
+
if 'rot' in opt.preprocess:
|
82 |
+
rot_angle = opt.rot_angle * (2 * random.random() - 1)
|
83 |
+
rot_rad = -rot_angle * np.pi/180
|
84 |
+
if 'flip' in opt.preprocess:
|
85 |
+
flip = random.random() > 0.5
|
86 |
+
|
87 |
+
shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
|
88 |
+
flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
|
89 |
+
shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
|
90 |
+
rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
|
91 |
+
scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
|
92 |
+
shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
|
93 |
+
|
94 |
+
affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
|
95 |
+
affine_inv = np.linalg.inv(affine)
|
96 |
+
return affine, affine_inv, flip
|
97 |
+
|
98 |
+
def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
|
99 |
+
return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
|
100 |
+
|
101 |
+
def apply_lm_affine(landmark, affine, flip, size):
|
102 |
+
_, h = size
|
103 |
+
lm = landmark.copy()
|
104 |
+
lm[:, 1] = h - 1 - lm[:, 1]
|
105 |
+
lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
|
106 |
+
lm = lm @ np.transpose(affine)
|
107 |
+
lm[:, :2] = lm[:, :2] / lm[:, 2:]
|
108 |
+
lm = lm[:, :2]
|
109 |
+
lm[:, 1] = h - 1 - lm[:, 1]
|
110 |
+
if flip:
|
111 |
+
lm_ = lm.copy()
|
112 |
+
lm_[:17] = lm[16::-1]
|
113 |
+
lm_[17:22] = lm[26:21:-1]
|
114 |
+
lm_[22:27] = lm[21:16:-1]
|
115 |
+
lm_[31:36] = lm[35:30:-1]
|
116 |
+
lm_[36:40] = lm[45:41:-1]
|
117 |
+
lm_[40:42] = lm[47:45:-1]
|
118 |
+
lm_[42:46] = lm[39:35:-1]
|
119 |
+
lm_[46:48] = lm[41:39:-1]
|
120 |
+
lm_[48:55] = lm[54:47:-1]
|
121 |
+
lm_[55:60] = lm[59:54:-1]
|
122 |
+
lm_[60:65] = lm[64:59:-1]
|
123 |
+
lm_[65:68] = lm[67:64:-1]
|
124 |
+
lm = lm_
|
125 |
+
return lm
|
chat_anything/sad_talker/face3d/data/flist_dataset.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
|
2 |
+
"""
|
3 |
+
|
4 |
+
import os.path
|
5 |
+
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
|
6 |
+
from data.image_folder import make_dataset
|
7 |
+
from PIL import Image
|
8 |
+
import random
|
9 |
+
import util.util as util
|
10 |
+
import numpy as np
|
11 |
+
import json
|
12 |
+
import torch
|
13 |
+
from scipy.io import loadmat, savemat
|
14 |
+
import pickle
|
15 |
+
from util.preprocess import align_img, estimate_norm
|
16 |
+
from util.load_mats import load_lm3d
|
17 |
+
|
18 |
+
|
19 |
+
def default_flist_reader(flist):
|
20 |
+
"""
|
21 |
+
flist format: impath label\nimpath label\n ...(same to caffe's filelist)
|
22 |
+
"""
|
23 |
+
imlist = []
|
24 |
+
with open(flist, 'r') as rf:
|
25 |
+
for line in rf.readlines():
|
26 |
+
impath = line.strip()
|
27 |
+
imlist.append(impath)
|
28 |
+
|
29 |
+
return imlist
|
30 |
+
|
31 |
+
def jason_flist_reader(flist):
|
32 |
+
with open(flist, 'r') as fp:
|
33 |
+
info = json.load(fp)
|
34 |
+
return info
|
35 |
+
|
36 |
+
def parse_label(label):
|
37 |
+
return torch.tensor(np.array(label).astype(np.float32))
|
38 |
+
|
39 |
+
|
40 |
+
class FlistDataset(BaseDataset):
|
41 |
+
"""
|
42 |
+
It requires one directories to host training images '/path/to/data/train'
|
43 |
+
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, opt):
|
47 |
+
"""Initialize this dataset class.
|
48 |
+
|
49 |
+
Parameters:
|
50 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
51 |
+
"""
|
52 |
+
BaseDataset.__init__(self, opt)
|
53 |
+
|
54 |
+
self.lm3d_std = load_lm3d(opt.bfm_folder)
|
55 |
+
|
56 |
+
msk_names = default_flist_reader(opt.flist)
|
57 |
+
self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
|
58 |
+
|
59 |
+
self.size = len(self.msk_paths)
|
60 |
+
self.opt = opt
|
61 |
+
|
62 |
+
self.name = 'train' if opt.isTrain else 'val'
|
63 |
+
if '_' in opt.flist:
|
64 |
+
self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
|
65 |
+
|
66 |
+
|
67 |
+
def __getitem__(self, index):
|
68 |
+
"""Return a data point and its metadata information.
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
index (int) -- a random integer for data indexing
|
72 |
+
|
73 |
+
Returns a dictionary that contains A, B, A_paths and B_paths
|
74 |
+
img (tensor) -- an image in the input domain
|
75 |
+
msk (tensor) -- its corresponding attention mask
|
76 |
+
lm (tensor) -- its corresponding 3d landmarks
|
77 |
+
im_paths (str) -- image paths
|
78 |
+
aug_flag (bool) -- a flag used to tell whether its raw or augmented
|
79 |
+
"""
|
80 |
+
msk_path = self.msk_paths[index % self.size] # make sure index is within then range
|
81 |
+
img_path = msk_path.replace('mask/', '')
|
82 |
+
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
|
83 |
+
|
84 |
+
raw_img = Image.open(img_path).convert('RGB')
|
85 |
+
raw_msk = Image.open(msk_path).convert('RGB')
|
86 |
+
raw_lm = np.loadtxt(lm_path).astype(np.float32)
|
87 |
+
|
88 |
+
_, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
|
89 |
+
|
90 |
+
aug_flag = self.opt.use_aug and self.opt.isTrain
|
91 |
+
if aug_flag:
|
92 |
+
img, lm, msk = self._augmentation(img, lm, self.opt, msk)
|
93 |
+
|
94 |
+
_, H = img.size
|
95 |
+
M = estimate_norm(lm, H)
|
96 |
+
transform = get_transform()
|
97 |
+
img_tensor = transform(img)
|
98 |
+
msk_tensor = transform(msk)[:1, ...]
|
99 |
+
lm_tensor = parse_label(lm)
|
100 |
+
M_tensor = parse_label(M)
|
101 |
+
|
102 |
+
|
103 |
+
return {'imgs': img_tensor,
|
104 |
+
'lms': lm_tensor,
|
105 |
+
'msks': msk_tensor,
|
106 |
+
'M': M_tensor,
|
107 |
+
'im_paths': img_path,
|
108 |
+
'aug_flag': aug_flag,
|
109 |
+
'dataset': self.name}
|
110 |
+
|
111 |
+
def _augmentation(self, img, lm, opt, msk=None):
|
112 |
+
affine, affine_inv, flip = get_affine_mat(opt, img.size)
|
113 |
+
img = apply_img_affine(img, affine_inv)
|
114 |
+
lm = apply_lm_affine(lm, affine, flip, img.size)
|
115 |
+
if msk is not None:
|
116 |
+
msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
|
117 |
+
return img, lm, msk
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
def __len__(self):
|
123 |
+
"""Return the total number of images in the dataset.
|
124 |
+
"""
|
125 |
+
return self.size
|
chat_anything/sad_talker/face3d/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A modified image folder class
|
2 |
+
|
3 |
+
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
4 |
+
so that this class can load images from both current directory and its subdirectories.
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import torch.utils.data as data
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
import os
|
11 |
+
import os.path
|
12 |
+
|
13 |
+
IMG_EXTENSIONS = [
|
14 |
+
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
15 |
+
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
16 |
+
'.tif', '.TIF', '.tiff', '.TIFF',
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
def is_image_file(filename):
|
21 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
22 |
+
|
23 |
+
|
24 |
+
def make_dataset(dir, max_dataset_size=float("inf")):
|
25 |
+
images = []
|
26 |
+
assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
|
27 |
+
|
28 |
+
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
|
29 |
+
for fname in fnames:
|
30 |
+
if is_image_file(fname):
|
31 |
+
path = os.path.join(root, fname)
|
32 |
+
images.append(path)
|
33 |
+
return images[:min(max_dataset_size, len(images))]
|
34 |
+
|
35 |
+
|
36 |
+
def default_loader(path):
|
37 |
+
return Image.open(path).convert('RGB')
|
38 |
+
|
39 |
+
|
40 |
+
class ImageFolder(data.Dataset):
|
41 |
+
|
42 |
+
def __init__(self, root, transform=None, return_paths=False,
|
43 |
+
loader=default_loader):
|
44 |
+
imgs = make_dataset(root)
|
45 |
+
if len(imgs) == 0:
|
46 |
+
raise(RuntimeError("Found 0 images in: " + root + "\n"
|
47 |
+
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
|
48 |
+
|
49 |
+
self.root = root
|
50 |
+
self.imgs = imgs
|
51 |
+
self.transform = transform
|
52 |
+
self.return_paths = return_paths
|
53 |
+
self.loader = loader
|
54 |
+
|
55 |
+
def __getitem__(self, index):
|
56 |
+
path = self.imgs[index]
|
57 |
+
img = self.loader(path)
|
58 |
+
if self.transform is not None:
|
59 |
+
img = self.transform(img)
|
60 |
+
if self.return_paths:
|
61 |
+
return img, path
|
62 |
+
else:
|
63 |
+
return img
|
64 |
+
|
65 |
+
def __len__(self):
|
66 |
+
return len(self.imgs)
|
chat_anything/sad_talker/face3d/data/template_dataset.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Dataset class template
|
2 |
+
|
3 |
+
This module provides a template for users to implement custom datasets.
|
4 |
+
You can specify '--dataset_mode template' to use this dataset.
|
5 |
+
The class name should be consistent with both the filename and its dataset_mode option.
|
6 |
+
The filename should be <dataset_mode>_dataset.py
|
7 |
+
The class name should be <Dataset_mode>Dataset.py
|
8 |
+
You need to implement the following functions:
|
9 |
+
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
|
10 |
+
-- <__init__>: Initialize this dataset class.
|
11 |
+
-- <__getitem__>: Return a data point and its metadata information.
|
12 |
+
-- <__len__>: Return the number of images.
|
13 |
+
"""
|
14 |
+
from data.base_dataset import BaseDataset, get_transform
|
15 |
+
# from data.image_folder import make_dataset
|
16 |
+
# from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
class TemplateDataset(BaseDataset):
|
20 |
+
"""A template dataset class for you to implement custom datasets."""
|
21 |
+
@staticmethod
|
22 |
+
def modify_commandline_options(parser, is_train):
|
23 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
parser -- original option parser
|
27 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
the modified parser.
|
31 |
+
"""
|
32 |
+
parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
|
33 |
+
parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
|
34 |
+
return parser
|
35 |
+
|
36 |
+
def __init__(self, opt):
|
37 |
+
"""Initialize this dataset class.
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
41 |
+
|
42 |
+
A few things can be done here.
|
43 |
+
- save the options (have been done in BaseDataset)
|
44 |
+
- get image paths and meta information of the dataset.
|
45 |
+
- define the image transformation.
|
46 |
+
"""
|
47 |
+
# save the option and dataset root
|
48 |
+
BaseDataset.__init__(self, opt)
|
49 |
+
# get the image paths of your dataset;
|
50 |
+
self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
|
51 |
+
# define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
|
52 |
+
self.transform = get_transform(opt)
|
53 |
+
|
54 |
+
def __getitem__(self, index):
|
55 |
+
"""Return a data point and its metadata information.
|
56 |
+
|
57 |
+
Parameters:
|
58 |
+
index -- a random integer for data indexing
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
a dictionary of data with their names. It usually contains the data itself and its metadata information.
|
62 |
+
|
63 |
+
Step 1: get a random image path: e.g., path = self.image_paths[index]
|
64 |
+
Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
|
65 |
+
Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
|
66 |
+
Step 4: return a data point as a dictionary.
|
67 |
+
"""
|
68 |
+
path = 'temp' # needs to be a string
|
69 |
+
data_A = None # needs to be a tensor
|
70 |
+
data_B = None # needs to be a tensor
|
71 |
+
return {'data_A': data_A, 'data_B': data_B, 'path': path}
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
"""Return the total number of images."""
|
75 |
+
return len(self.image_paths)
|
chat_anything/sad_talker/face3d/extract_kp_videos.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import glob
|
5 |
+
import argparse
|
6 |
+
import face_alignment
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
from itertools import cycle
|
11 |
+
|
12 |
+
from torch.multiprocessing import Pool, Process, set_start_method
|
13 |
+
|
14 |
+
class KeypointExtractor():
|
15 |
+
def __init__(self, device):
|
16 |
+
self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
|
17 |
+
device=device)
|
18 |
+
|
19 |
+
def extract_keypoint(self, images, name=None, info=True):
|
20 |
+
if isinstance(images, list):
|
21 |
+
keypoints = []
|
22 |
+
if info:
|
23 |
+
i_range = tqdm(images,desc='landmark Det:')
|
24 |
+
else:
|
25 |
+
i_range = images
|
26 |
+
|
27 |
+
for image in i_range:
|
28 |
+
current_kp = self.extract_keypoint(image)
|
29 |
+
if np.mean(current_kp) == -1 and keypoints:
|
30 |
+
keypoints.append(keypoints[-1])
|
31 |
+
else:
|
32 |
+
keypoints.append(current_kp[None])
|
33 |
+
|
34 |
+
keypoints = np.concatenate(keypoints, 0)
|
35 |
+
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
36 |
+
return keypoints
|
37 |
+
else:
|
38 |
+
while True:
|
39 |
+
try:
|
40 |
+
keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
|
41 |
+
break
|
42 |
+
except RuntimeError as e:
|
43 |
+
if str(e).startswith('CUDA'):
|
44 |
+
print("Warning: out of memory, sleep for 1s")
|
45 |
+
time.sleep(1)
|
46 |
+
else:
|
47 |
+
print(e)
|
48 |
+
break
|
49 |
+
except TypeError:
|
50 |
+
print('No face detected in this image')
|
51 |
+
shape = [68, 2]
|
52 |
+
keypoints = -1. * np.ones(shape)
|
53 |
+
break
|
54 |
+
if name is not None:
|
55 |
+
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
56 |
+
return keypoints
|
57 |
+
|
58 |
+
def read_video(filename):
|
59 |
+
frames = []
|
60 |
+
cap = cv2.VideoCapture(filename)
|
61 |
+
while cap.isOpened():
|
62 |
+
ret, frame = cap.read()
|
63 |
+
if ret:
|
64 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
65 |
+
frame = Image.fromarray(frame)
|
66 |
+
frames.append(frame)
|
67 |
+
else:
|
68 |
+
break
|
69 |
+
cap.release()
|
70 |
+
return frames
|
71 |
+
|
72 |
+
def run(data):
|
73 |
+
filename, opt, device = data
|
74 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
75 |
+
kp_extractor = KeypointExtractor()
|
76 |
+
images = read_video(filename)
|
77 |
+
name = filename.split('/')[-2:]
|
78 |
+
os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
|
79 |
+
kp_extractor.extract_keypoint(
|
80 |
+
images,
|
81 |
+
name=os.path.join(opt.output_dir, name[-2], name[-1])
|
82 |
+
)
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
set_start_method('spawn')
|
86 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
87 |
+
parser.add_argument('--input_dir', type=str, help='the folder of the input files')
|
88 |
+
parser.add_argument('--output_dir', type=str, help='the folder of the output files')
|
89 |
+
parser.add_argument('--device_ids', type=str, default='0,1')
|
90 |
+
parser.add_argument('--workers', type=int, default=4)
|
91 |
+
|
92 |
+
opt = parser.parse_args()
|
93 |
+
filenames = list()
|
94 |
+
VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
|
95 |
+
VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
|
96 |
+
extensions = VIDEO_EXTENSIONS
|
97 |
+
|
98 |
+
for ext in extensions:
|
99 |
+
os.listdir(f'{opt.input_dir}')
|
100 |
+
print(f'{opt.input_dir}/*.{ext}')
|
101 |
+
filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
|
102 |
+
print('Total number of videos:', len(filenames))
|
103 |
+
pool = Pool(opt.workers)
|
104 |
+
args_list = cycle([opt])
|
105 |
+
device_ids = opt.device_ids.split(",")
|
106 |
+
device_ids = cycle(device_ids)
|
107 |
+
for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
|
108 |
+
None
|
chat_anything/sad_talker/face3d/extract_kp_videos_safe.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import glob
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
from itertools import cycle
|
11 |
+
from torch.multiprocessing import Pool, Process, set_start_method
|
12 |
+
|
13 |
+
from facexlib.alignment import landmark_98_to_68
|
14 |
+
from facexlib.detection import init_detection_model
|
15 |
+
|
16 |
+
from facexlib.utils import load_file_from_url
|
17 |
+
from chat_anything.sad_talker.face3d.util.my_awing_arch import FAN
|
18 |
+
|
19 |
+
def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
|
20 |
+
if model_name == 'awing_fan':
|
21 |
+
model = FAN(num_modules=4, num_landmarks=98, device=device)
|
22 |
+
model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
|
23 |
+
else:
|
24 |
+
raise NotImplementedError(f'{model_name} is not implemented.')
|
25 |
+
|
26 |
+
model_path = load_file_from_url(
|
27 |
+
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
|
28 |
+
model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
|
29 |
+
model.eval()
|
30 |
+
model = model.to(device)
|
31 |
+
return model
|
32 |
+
|
33 |
+
|
34 |
+
class KeypointExtractor():
|
35 |
+
def __init__(self, device='cuda'):
|
36 |
+
|
37 |
+
### gfpgan/weights
|
38 |
+
try:
|
39 |
+
import webui # in webui
|
40 |
+
root_path = 'extensions/SadTalker/gfpgan/weights'
|
41 |
+
|
42 |
+
except:
|
43 |
+
# root_path = 'gfpgan/weights'
|
44 |
+
root_path = 'MODELS/gfpgan/weights'
|
45 |
+
|
46 |
+
self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
|
47 |
+
self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
|
48 |
+
|
49 |
+
def extract_keypoint(self, images, name=None, info=True):
|
50 |
+
if isinstance(images, list):
|
51 |
+
keypoints = []
|
52 |
+
if info:
|
53 |
+
i_range = tqdm(images,desc='landmark Det:')
|
54 |
+
else:
|
55 |
+
i_range = images
|
56 |
+
|
57 |
+
for image in i_range:
|
58 |
+
print("detect landmarks")
|
59 |
+
current_kp = self.extract_keypoint(image)
|
60 |
+
# current_kp = self.detector.get_landmarks(np.array(image))
|
61 |
+
if np.mean(current_kp) == -1 and keypoints:
|
62 |
+
keypoints.append(keypoints[-1])
|
63 |
+
else:
|
64 |
+
keypoints.append(current_kp[None])
|
65 |
+
|
66 |
+
keypoints = np.concatenate(keypoints, 0)
|
67 |
+
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
68 |
+
return keypoints
|
69 |
+
else:
|
70 |
+
print("here")
|
71 |
+
while True:
|
72 |
+
try:
|
73 |
+
with torch.no_grad():
|
74 |
+
# face detection -> face alignment.
|
75 |
+
img = np.array(images)
|
76 |
+
bboxes = self.det_net.detect_faces(images, 0.97)
|
77 |
+
|
78 |
+
bboxes = bboxes[0]
|
79 |
+
img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
|
80 |
+
|
81 |
+
landmarks=self.detector.get_landmarks(img)
|
82 |
+
print(landmarks.shape)
|
83 |
+
start_time=time.time()
|
84 |
+
keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
|
85 |
+
end_time=time.time()
|
86 |
+
print(type(keypoints))
|
87 |
+
print(keypoints.shape)
|
88 |
+
|
89 |
+
elapsed_time = end_time - start_time # 计算时间差
|
90 |
+
print("landmark检测时间:%.4f秒" % elapsed_time)
|
91 |
+
#### keypoints to the original location
|
92 |
+
keypoints[:,0] += int(bboxes[0])
|
93 |
+
keypoints[:,1] += int(bboxes[1])
|
94 |
+
|
95 |
+
break
|
96 |
+
except RuntimeError as e:
|
97 |
+
if str(e).startswith('CUDA'):
|
98 |
+
print("Warning: out of memory, sleep for 1s")
|
99 |
+
time.sleep(1)
|
100 |
+
else:
|
101 |
+
print(e)
|
102 |
+
break
|
103 |
+
except TypeError:
|
104 |
+
print('No face detected in this image')
|
105 |
+
shape = [68, 2]
|
106 |
+
keypoints = -1. * np.ones(shape)
|
107 |
+
break
|
108 |
+
if name is not None:
|
109 |
+
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
110 |
+
return keypoints
|
111 |
+
|
112 |
+
def read_video(filename):
|
113 |
+
frames = []
|
114 |
+
cap = cv2.VideoCapture(filename)
|
115 |
+
while cap.isOpened():
|
116 |
+
ret, frame = cap.read()
|
117 |
+
if ret:
|
118 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
119 |
+
frame = Image.fromarray(frame)
|
120 |
+
frames.append(frame)
|
121 |
+
else:
|
122 |
+
break
|
123 |
+
cap.release()
|
124 |
+
return frames
|
125 |
+
|
126 |
+
def run(data):
|
127 |
+
filename, opt, device = data
|
128 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
129 |
+
kp_extractor = KeypointExtractor()
|
130 |
+
images = read_video(filename)
|
131 |
+
name = filename.split('/')[-2:]
|
132 |
+
os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
|
133 |
+
kp_extractor.extract_keypoint(
|
134 |
+
images,
|
135 |
+
name=os.path.join(opt.output_dir, name[-2], name[-1])
|
136 |
+
)
|
137 |
+
|
138 |
+
if __name__ == '__main__':
|
139 |
+
set_start_method('spawn')
|
140 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
141 |
+
parser.add_argument('--input_dir', type=str, help='the folder of the input files')
|
142 |
+
parser.add_argument('--output_dir', type=str, help='the folder of the output files')
|
143 |
+
parser.add_argument('--device_ids', type=str, default='0,1')
|
144 |
+
parser.add_argument('--workers', type=int, default=4)
|
145 |
+
|
146 |
+
opt = parser.parse_args()
|
147 |
+
filenames = list()
|
148 |
+
VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
|
149 |
+
VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
|
150 |
+
extensions = VIDEO_EXTENSIONS
|
151 |
+
|
152 |
+
for ext in extensions:
|
153 |
+
os.listdir(f'{opt.input_dir}')
|
154 |
+
print(f'{opt.input_dir}/*.{ext}')
|
155 |
+
filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
|
156 |
+
print('Total number of videos:', len(filenames))
|
157 |
+
pool = Pool(opt.workers)
|
158 |
+
args_list = cycle([opt])
|
159 |
+
device_ids = opt.device_ids.split(",")
|
160 |
+
device_ids = cycle(device_ids)
|
161 |
+
for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
|
162 |
+
None
|
chat_anything/sad_talker/face3d/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
|
3 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
4 |
+
You need to implement the following five functions:
|
5 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
6 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
7 |
+
-- <forward>: produce intermediate results.
|
8 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
9 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
10 |
+
|
11 |
+
In the function <__init__>, you need to define four lists:
|
12 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
13 |
+
-- self.model_names (str list): define networks used in our training.
|
14 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
15 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
|
16 |
+
|
17 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
18 |
+
See our template model class 'template_model.py' for more details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import importlib
|
22 |
+
from chat_anything.sad_talker.face3d.models.base_model import BaseModel
|
23 |
+
|
24 |
+
|
25 |
+
def find_model_using_name(model_name):
|
26 |
+
"""Import the module "models/[model_name]_model.py".
|
27 |
+
|
28 |
+
In the file, the class called DatasetNameModel() will
|
29 |
+
be instantiated. It has to be a subclass of BaseModel,
|
30 |
+
and it is case-insensitive.
|
31 |
+
"""
|
32 |
+
model_filename = "face3d.models." + model_name + "_model"
|
33 |
+
modellib = importlib.import_module(model_filename)
|
34 |
+
model = None
|
35 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
36 |
+
for name, cls in modellib.__dict__.items():
|
37 |
+
if name.lower() == target_model_name.lower() \
|
38 |
+
and issubclass(cls, BaseModel):
|
39 |
+
model = cls
|
40 |
+
|
41 |
+
if model is None:
|
42 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
43 |
+
exit(0)
|
44 |
+
|
45 |
+
return model
|
46 |
+
|
47 |
+
|
48 |
+
def get_option_setter(model_name):
|
49 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
50 |
+
model_class = find_model_using_name(model_name)
|
51 |
+
return model_class.modify_commandline_options
|
52 |
+
|
53 |
+
|
54 |
+
def create_model(opt):
|
55 |
+
"""Create a model given the option.
|
56 |
+
|
57 |
+
This function warps the class CustomDatasetDataLoader.
|
58 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
59 |
+
|
60 |
+
Example:
|
61 |
+
>>> from models import create_model
|
62 |
+
>>> model = create_model(opt)
|
63 |
+
"""
|
64 |
+
model = find_model_using_name(opt.model)
|
65 |
+
instance = model(opt)
|
66 |
+
print("model [%s] was created" % type(instance).__name__)
|
67 |
+
return instance
|
chat_anything/sad_talker/face3d/models/arcface_torch/README.md
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Distributed Arcface Training in Pytorch
|
2 |
+
|
3 |
+
This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
|
4 |
+
identity on a single server.
|
5 |
+
|
6 |
+
## Requirements
|
7 |
+
|
8 |
+
- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
|
9 |
+
- `pip install -r requirements.txt`.
|
10 |
+
- Download the dataset
|
11 |
+
from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
|
12 |
+
.
|
13 |
+
|
14 |
+
## How to Training
|
15 |
+
|
16 |
+
To train a model, run `train.py` with the path to the configs:
|
17 |
+
|
18 |
+
### 1. Single node, 8 GPUs:
|
19 |
+
|
20 |
+
```shell
|
21 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
|
22 |
+
```
|
23 |
+
|
24 |
+
### 2. Multiple nodes, each node 8 GPUs:
|
25 |
+
|
26 |
+
Node 0:
|
27 |
+
|
28 |
+
```shell
|
29 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
30 |
+
```
|
31 |
+
|
32 |
+
Node 1:
|
33 |
+
|
34 |
+
```shell
|
35 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
|
36 |
+
```
|
37 |
+
|
38 |
+
### 3.Training resnet2060 with 8 GPUs:
|
39 |
+
|
40 |
+
```shell
|
41 |
+
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
|
42 |
+
```
|
43 |
+
|
44 |
+
## Model Zoo
|
45 |
+
|
46 |
+
- The models are available for non-commercial research purposes only.
|
47 |
+
- All models can be found in here.
|
48 |
+
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
49 |
+
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
50 |
+
|
51 |
+
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
|
52 |
+
|
53 |
+
ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
|
54 |
+
recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
|
55 |
+
As the result, we can evaluate the FAIR performance for different algorithms.
|
56 |
+
|
57 |
+
For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The
|
58 |
+
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
|
59 |
+
|
60 |
+
For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4).
|
61 |
+
Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
|
62 |
+
There are totally 13,928 positive pairs and 96,983,824 negative pairs.
|
63 |
+
|
64 |
+
| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
|
65 |
+
| :---: | :--- | :--- | :--- |:--- |:--- |
|
66 |
+
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
|
67 |
+
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
|
68 |
+
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
|
69 |
+
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
|
70 |
+
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
|
71 |
+
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
|
72 |
+
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
|
73 |
+
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
|
74 |
+
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
|
75 |
+
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
|
76 |
+
|
77 |
+
### Performance on IJB-C and Verification Datasets
|
78 |
+
|
79 |
+
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
|
80 |
+
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
|
81 |
+
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
|
82 |
+
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
|
83 |
+
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
|
84 |
+
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
|
85 |
+
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
|
86 |
+
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
|
87 |
+
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
|
88 |
+
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
|
89 |
+
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
|
90 |
+
|
91 |
+
[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
|
92 |
+
|
93 |
+
|
94 |
+
## [Speed Benchmark](docs/speed_benchmark.md)
|
95 |
+
|
96 |
+
**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
|
97 |
+
classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
|
98 |
+
accuracy with several times faster training performance and smaller GPU memory.
|
99 |
+
Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a
|
100 |
+
sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a
|
101 |
+
sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC,
|
102 |
+
we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed
|
103 |
+
training and mixed precision training.
|
104 |
+
|
105 |
+
![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
|
106 |
+
|
107 |
+
More details see
|
108 |
+
[speed_benchmark.md](docs/speed_benchmark.md) in docs.
|
109 |
+
|
110 |
+
### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
|
111 |
+
|
112 |
+
`-` means training failed because of gpu memory limitations.
|
113 |
+
|
114 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
115 |
+
| :--- | :--- | :--- | :--- |
|
116 |
+
|125000 | 4681 | 4824 | 5004 |
|
117 |
+
|1400000 | **1672** | 3043 | 4738 |
|
118 |
+
|5500000 | **-** | **1389** | 3975 |
|
119 |
+
|8000000 | **-** | **-** | 3565 |
|
120 |
+
|16000000 | **-** | **-** | 2679 |
|
121 |
+
|29000000 | **-** | **-** | **1855** |
|
122 |
+
|
123 |
+
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
|
124 |
+
|
125 |
+
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
126 |
+
| :--- | :--- | :--- | :--- |
|
127 |
+
|125000 | 7358 | 5306 | 4868 |
|
128 |
+
|1400000 | 32252 | 11178 | 6056 |
|
129 |
+
|5500000 | **-** | 32188 | 9854 |
|
130 |
+
|8000000 | **-** | **-** | 12310 |
|
131 |
+
|16000000 | **-** | **-** | 19950 |
|
132 |
+
|29000000 | **-** | **-** | 32324 |
|
133 |
+
|
134 |
+
## Evaluation ICCV2021-MFR and IJB-C
|
135 |
+
|
136 |
+
More details see [eval.md](docs/eval.md) in docs.
|
137 |
+
|
138 |
+
## Test
|
139 |
+
|
140 |
+
We tested many versions of PyTorch. Please create an issue if you are having trouble.
|
141 |
+
|
142 |
+
- [x] torch 1.6.0
|
143 |
+
- [x] torch 1.7.1
|
144 |
+
- [x] torch 1.8.0
|
145 |
+
- [x] torch 1.9.0
|
146 |
+
|
147 |
+
## Citation
|
148 |
+
|
149 |
+
```
|
150 |
+
@inproceedings{deng2019arcface,
|
151 |
+
title={Arcface: Additive angular margin loss for deep face recognition},
|
152 |
+
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
|
153 |
+
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
154 |
+
pages={4690--4699},
|
155 |
+
year={2019}
|
156 |
+
}
|
157 |
+
@inproceedings{an2020partical_fc,
|
158 |
+
title={Partial FC: Training 10 Million Identities on a Single Machine},
|
159 |
+
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
|
160 |
+
Zhang, Debing and Fu Ying},
|
161 |
+
booktitle={Arxiv 2010.05222},
|
162 |
+
year={2020}
|
163 |
+
}
|
164 |
+
```
|
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
|
2 |
+
from .mobilefacenet import get_mbf
|
3 |
+
|
4 |
+
|
5 |
+
def get_model(name, **kwargs):
|
6 |
+
# resnet
|
7 |
+
if name == "r18":
|
8 |
+
return iresnet18(False, **kwargs)
|
9 |
+
elif name == "r34":
|
10 |
+
return iresnet34(False, **kwargs)
|
11 |
+
elif name == "r50":
|
12 |
+
return iresnet50(False, **kwargs)
|
13 |
+
elif name == "r100":
|
14 |
+
return iresnet100(False, **kwargs)
|
15 |
+
elif name == "r200":
|
16 |
+
return iresnet200(False, **kwargs)
|
17 |
+
elif name == "r2060":
|
18 |
+
from .iresnet2060 import iresnet2060
|
19 |
+
return iresnet2060(False, **kwargs)
|
20 |
+
elif name == "mbf":
|
21 |
+
fp16 = kwargs.get("fp16", False)
|
22 |
+
num_features = kwargs.get("num_features", 512)
|
23 |
+
return get_mbf(fp16=fp16, num_features=num_features)
|
24 |
+
else:
|
25 |
+
raise ValueError()
|
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
8 |
+
"""3x3 convolution with padding"""
|
9 |
+
return nn.Conv2d(in_planes,
|
10 |
+
out_planes,
|
11 |
+
kernel_size=3,
|
12 |
+
stride=stride,
|
13 |
+
padding=dilation,
|
14 |
+
groups=groups,
|
15 |
+
bias=False,
|
16 |
+
dilation=dilation)
|
17 |
+
|
18 |
+
|
19 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
20 |
+
"""1x1 convolution"""
|
21 |
+
return nn.Conv2d(in_planes,
|
22 |
+
out_planes,
|
23 |
+
kernel_size=1,
|
24 |
+
stride=stride,
|
25 |
+
bias=False)
|
26 |
+
|
27 |
+
|
28 |
+
class IBasicBlock(nn.Module):
|
29 |
+
expansion = 1
|
30 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
31 |
+
groups=1, base_width=64, dilation=1):
|
32 |
+
super(IBasicBlock, self).__init__()
|
33 |
+
if groups != 1 or base_width != 64:
|
34 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
35 |
+
if dilation > 1:
|
36 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
37 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
|
38 |
+
self.conv1 = conv3x3(inplanes, planes)
|
39 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
|
40 |
+
self.prelu = nn.PReLU(planes)
|
41 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
42 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
|
43 |
+
self.downsample = downsample
|
44 |
+
self.stride = stride
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
identity = x
|
48 |
+
out = self.bn1(x)
|
49 |
+
out = self.conv1(out)
|
50 |
+
out = self.bn2(out)
|
51 |
+
out = self.prelu(out)
|
52 |
+
out = self.conv2(out)
|
53 |
+
out = self.bn3(out)
|
54 |
+
if self.downsample is not None:
|
55 |
+
identity = self.downsample(x)
|
56 |
+
out += identity
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class IResNet(nn.Module):
|
61 |
+
fc_scale = 7 * 7
|
62 |
+
def __init__(self,
|
63 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
64 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
65 |
+
super(IResNet, self).__init__()
|
66 |
+
self.fp16 = fp16
|
67 |
+
self.inplanes = 64
|
68 |
+
self.dilation = 1
|
69 |
+
if replace_stride_with_dilation is None:
|
70 |
+
replace_stride_with_dilation = [False, False, False]
|
71 |
+
if len(replace_stride_with_dilation) != 3:
|
72 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
73 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
74 |
+
self.groups = groups
|
75 |
+
self.base_width = width_per_group
|
76 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
77 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
78 |
+
self.prelu = nn.PReLU(self.inplanes)
|
79 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
80 |
+
self.layer2 = self._make_layer(block,
|
81 |
+
128,
|
82 |
+
layers[1],
|
83 |
+
stride=2,
|
84 |
+
dilate=replace_stride_with_dilation[0])
|
85 |
+
self.layer3 = self._make_layer(block,
|
86 |
+
256,
|
87 |
+
layers[2],
|
88 |
+
stride=2,
|
89 |
+
dilate=replace_stride_with_dilation[1])
|
90 |
+
self.layer4 = self._make_layer(block,
|
91 |
+
512,
|
92 |
+
layers[3],
|
93 |
+
stride=2,
|
94 |
+
dilate=replace_stride_with_dilation[2])
|
95 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
|
96 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
97 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
98 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
99 |
+
nn.init.constant_(self.features.weight, 1.0)
|
100 |
+
self.features.weight.requires_grad = False
|
101 |
+
|
102 |
+
for m in self.modules():
|
103 |
+
if isinstance(m, nn.Conv2d):
|
104 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
105 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
106 |
+
nn.init.constant_(m.weight, 1)
|
107 |
+
nn.init.constant_(m.bias, 0)
|
108 |
+
|
109 |
+
if zero_init_residual:
|
110 |
+
for m in self.modules():
|
111 |
+
if isinstance(m, IBasicBlock):
|
112 |
+
nn.init.constant_(m.bn2.weight, 0)
|
113 |
+
|
114 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
115 |
+
downsample = None
|
116 |
+
previous_dilation = self.dilation
|
117 |
+
if dilate:
|
118 |
+
self.dilation *= stride
|
119 |
+
stride = 1
|
120 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
121 |
+
downsample = nn.Sequential(
|
122 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
123 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
124 |
+
)
|
125 |
+
layers = []
|
126 |
+
layers.append(
|
127 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
128 |
+
self.base_width, previous_dilation))
|
129 |
+
self.inplanes = planes * block.expansion
|
130 |
+
for _ in range(1, blocks):
|
131 |
+
layers.append(
|
132 |
+
block(self.inplanes,
|
133 |
+
planes,
|
134 |
+
groups=self.groups,
|
135 |
+
base_width=self.base_width,
|
136 |
+
dilation=self.dilation))
|
137 |
+
|
138 |
+
return nn.Sequential(*layers)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
with torch.cuda.amp.autocast(self.fp16):
|
142 |
+
x = self.conv1(x)
|
143 |
+
x = self.bn1(x)
|
144 |
+
x = self.prelu(x)
|
145 |
+
x = self.layer1(x)
|
146 |
+
x = self.layer2(x)
|
147 |
+
x = self.layer3(x)
|
148 |
+
x = self.layer4(x)
|
149 |
+
x = self.bn2(x)
|
150 |
+
x = torch.flatten(x, 1)
|
151 |
+
x = self.dropout(x)
|
152 |
+
x = self.fc(x.float() if self.fp16 else x)
|
153 |
+
x = self.features(x)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
158 |
+
model = IResNet(block, layers, **kwargs)
|
159 |
+
if pretrained:
|
160 |
+
raise ValueError()
|
161 |
+
return model
|
162 |
+
|
163 |
+
|
164 |
+
def iresnet18(pretrained=False, progress=True, **kwargs):
|
165 |
+
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
|
166 |
+
progress, **kwargs)
|
167 |
+
|
168 |
+
|
169 |
+
def iresnet34(pretrained=False, progress=True, **kwargs):
|
170 |
+
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
|
171 |
+
progress, **kwargs)
|
172 |
+
|
173 |
+
|
174 |
+
def iresnet50(pretrained=False, progress=True, **kwargs):
|
175 |
+
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
|
176 |
+
progress, **kwargs)
|
177 |
+
|
178 |
+
|
179 |
+
def iresnet100(pretrained=False, progress=True, **kwargs):
|
180 |
+
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
|
181 |
+
progress, **kwargs)
|
182 |
+
|
183 |
+
|
184 |
+
def iresnet200(pretrained=False, progress=True, **kwargs):
|
185 |
+
return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
|
186 |
+
progress, **kwargs)
|
187 |
+
|
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
assert torch.__version__ >= "1.8.1"
|
5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
6 |
+
|
7 |
+
__all__ = ['iresnet2060']
|
8 |
+
|
9 |
+
|
10 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
11 |
+
"""3x3 convolution with padding"""
|
12 |
+
return nn.Conv2d(in_planes,
|
13 |
+
out_planes,
|
14 |
+
kernel_size=3,
|
15 |
+
stride=stride,
|
16 |
+
padding=dilation,
|
17 |
+
groups=groups,
|
18 |
+
bias=False,
|
19 |
+
dilation=dilation)
|
20 |
+
|
21 |
+
|
22 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
23 |
+
"""1x1 convolution"""
|
24 |
+
return nn.Conv2d(in_planes,
|
25 |
+
out_planes,
|
26 |
+
kernel_size=1,
|
27 |
+
stride=stride,
|
28 |
+
bias=False)
|
29 |
+
|
30 |
+
|
31 |
+
class IBasicBlock(nn.Module):
|
32 |
+
expansion = 1
|
33 |
+
|
34 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
35 |
+
groups=1, base_width=64, dilation=1):
|
36 |
+
super(IBasicBlock, self).__init__()
|
37 |
+
if groups != 1 or base_width != 64:
|
38 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
39 |
+
if dilation > 1:
|
40 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
41 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
|
42 |
+
self.conv1 = conv3x3(inplanes, planes)
|
43 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
|
44 |
+
self.prelu = nn.PReLU(planes)
|
45 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
46 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
|
47 |
+
self.downsample = downsample
|
48 |
+
self.stride = stride
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
identity = x
|
52 |
+
out = self.bn1(x)
|
53 |
+
out = self.conv1(out)
|
54 |
+
out = self.bn2(out)
|
55 |
+
out = self.prelu(out)
|
56 |
+
out = self.conv2(out)
|
57 |
+
out = self.bn3(out)
|
58 |
+
if self.downsample is not None:
|
59 |
+
identity = self.downsample(x)
|
60 |
+
out += identity
|
61 |
+
return out
|
62 |
+
|
63 |
+
|
64 |
+
class IResNet(nn.Module):
|
65 |
+
fc_scale = 7 * 7
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
69 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
70 |
+
super(IResNet, self).__init__()
|
71 |
+
self.fp16 = fp16
|
72 |
+
self.inplanes = 64
|
73 |
+
self.dilation = 1
|
74 |
+
if replace_stride_with_dilation is None:
|
75 |
+
replace_stride_with_dilation = [False, False, False]
|
76 |
+
if len(replace_stride_with_dilation) != 3:
|
77 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
78 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
79 |
+
self.groups = groups
|
80 |
+
self.base_width = width_per_group
|
81 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
82 |
+
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
83 |
+
self.prelu = nn.PReLU(self.inplanes)
|
84 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
85 |
+
self.layer2 = self._make_layer(block,
|
86 |
+
128,
|
87 |
+
layers[1],
|
88 |
+
stride=2,
|
89 |
+
dilate=replace_stride_with_dilation[0])
|
90 |
+
self.layer3 = self._make_layer(block,
|
91 |
+
256,
|
92 |
+
layers[2],
|
93 |
+
stride=2,
|
94 |
+
dilate=replace_stride_with_dilation[1])
|
95 |
+
self.layer4 = self._make_layer(block,
|
96 |
+
512,
|
97 |
+
layers[3],
|
98 |
+
stride=2,
|
99 |
+
dilate=replace_stride_with_dilation[2])
|
100 |
+
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
|
101 |
+
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
102 |
+
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
103 |
+
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
104 |
+
nn.init.constant_(self.features.weight, 1.0)
|
105 |
+
self.features.weight.requires_grad = False
|
106 |
+
|
107 |
+
for m in self.modules():
|
108 |
+
if isinstance(m, nn.Conv2d):
|
109 |
+
nn.init.normal_(m.weight, 0, 0.1)
|
110 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
111 |
+
nn.init.constant_(m.weight, 1)
|
112 |
+
nn.init.constant_(m.bias, 0)
|
113 |
+
|
114 |
+
if zero_init_residual:
|
115 |
+
for m in self.modules():
|
116 |
+
if isinstance(m, IBasicBlock):
|
117 |
+
nn.init.constant_(m.bn2.weight, 0)
|
118 |
+
|
119 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
120 |
+
downsample = None
|
121 |
+
previous_dilation = self.dilation
|
122 |
+
if dilate:
|
123 |
+
self.dilation *= stride
|
124 |
+
stride = 1
|
125 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
126 |
+
downsample = nn.Sequential(
|
127 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
128 |
+
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
129 |
+
)
|
130 |
+
layers = []
|
131 |
+
layers.append(
|
132 |
+
block(self.inplanes, planes, stride, downsample, self.groups,
|
133 |
+
self.base_width, previous_dilation))
|
134 |
+
self.inplanes = planes * block.expansion
|
135 |
+
for _ in range(1, blocks):
|
136 |
+
layers.append(
|
137 |
+
block(self.inplanes,
|
138 |
+
planes,
|
139 |
+
groups=self.groups,
|
140 |
+
base_width=self.base_width,
|
141 |
+
dilation=self.dilation))
|
142 |
+
|
143 |
+
return nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def checkpoint(self, func, num_seg, x):
|
146 |
+
if self.training:
|
147 |
+
return checkpoint_sequential(func, num_seg, x)
|
148 |
+
else:
|
149 |
+
return func(x)
|
150 |
+
|
151 |
+
def forward(self, x):
|
152 |
+
with torch.cuda.amp.autocast(self.fp16):
|
153 |
+
x = self.conv1(x)
|
154 |
+
x = self.bn1(x)
|
155 |
+
x = self.prelu(x)
|
156 |
+
x = self.layer1(x)
|
157 |
+
x = self.checkpoint(self.layer2, 20, x)
|
158 |
+
x = self.checkpoint(self.layer3, 100, x)
|
159 |
+
x = self.layer4(x)
|
160 |
+
x = self.bn2(x)
|
161 |
+
x = torch.flatten(x, 1)
|
162 |
+
x = self.dropout(x)
|
163 |
+
x = self.fc(x.float() if self.fp16 else x)
|
164 |
+
x = self.features(x)
|
165 |
+
return x
|
166 |
+
|
167 |
+
|
168 |
+
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
169 |
+
model = IResNet(block, layers, **kwargs)
|
170 |
+
if pretrained:
|
171 |
+
raise ValueError()
|
172 |
+
return model
|
173 |
+
|
174 |
+
|
175 |
+
def iresnet2060(pretrained=False, progress=True, **kwargs):
|
176 |
+
return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
|
chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
|
3 |
+
Original author cavalleria
|
4 |
+
'''
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class Flatten(Module):
|
12 |
+
def forward(self, x):
|
13 |
+
return x.view(x.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
class ConvBlock(Module):
|
17 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
18 |
+
super(ConvBlock, self).__init__()
|
19 |
+
self.layers = nn.Sequential(
|
20 |
+
Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
|
21 |
+
BatchNorm2d(num_features=out_c),
|
22 |
+
PReLU(num_parameters=out_c)
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
return self.layers(x)
|
27 |
+
|
28 |
+
|
29 |
+
class LinearBlock(Module):
|
30 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
31 |
+
super(LinearBlock, self).__init__()
|
32 |
+
self.layers = nn.Sequential(
|
33 |
+
Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
|
34 |
+
BatchNorm2d(num_features=out_c)
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.layers(x)
|
39 |
+
|
40 |
+
|
41 |
+
class DepthWise(Module):
|
42 |
+
def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
|
43 |
+
super(DepthWise, self).__init__()
|
44 |
+
self.residual = residual
|
45 |
+
self.layers = nn.Sequential(
|
46 |
+
ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
|
47 |
+
ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
|
48 |
+
LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
short_cut = None
|
53 |
+
if self.residual:
|
54 |
+
short_cut = x
|
55 |
+
x = self.layers(x)
|
56 |
+
if self.residual:
|
57 |
+
output = short_cut + x
|
58 |
+
else:
|
59 |
+
output = x
|
60 |
+
return output
|
61 |
+
|
62 |
+
|
63 |
+
class Residual(Module):
|
64 |
+
def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
|
65 |
+
super(Residual, self).__init__()
|
66 |
+
modules = []
|
67 |
+
for _ in range(num_block):
|
68 |
+
modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
|
69 |
+
self.layers = Sequential(*modules)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return self.layers(x)
|
73 |
+
|
74 |
+
|
75 |
+
class GDC(Module):
|
76 |
+
def __init__(self, embedding_size):
|
77 |
+
super(GDC, self).__init__()
|
78 |
+
self.layers = nn.Sequential(
|
79 |
+
LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
|
80 |
+
Flatten(),
|
81 |
+
Linear(512, embedding_size, bias=False),
|
82 |
+
BatchNorm1d(embedding_size))
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
return self.layers(x)
|
86 |
+
|
87 |
+
|
88 |
+
class MobileFaceNet(Module):
|
89 |
+
def __init__(self, fp16=False, num_features=512):
|
90 |
+
super(MobileFaceNet, self).__init__()
|
91 |
+
scale = 2
|
92 |
+
self.fp16 = fp16
|
93 |
+
self.layers = nn.Sequential(
|
94 |
+
ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
|
95 |
+
ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
|
96 |
+
DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
|
97 |
+
Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
98 |
+
DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
|
99 |
+
Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
100 |
+
DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
|
101 |
+
Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
|
102 |
+
)
|
103 |
+
self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
|
104 |
+
self.features = GDC(num_features)
|
105 |
+
self._initialize_weights()
|
106 |
+
|
107 |
+
def _initialize_weights(self):
|
108 |
+
for m in self.modules():
|
109 |
+
if isinstance(m, nn.Conv2d):
|
110 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
111 |
+
if m.bias is not None:
|
112 |
+
m.bias.data.zero_()
|
113 |
+
elif isinstance(m, nn.BatchNorm2d):
|
114 |
+
m.weight.data.fill_(1)
|
115 |
+
m.bias.data.zero_()
|
116 |
+
elif isinstance(m, nn.Linear):
|
117 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
118 |
+
if m.bias is not None:
|
119 |
+
m.bias.data.zero_()
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
with torch.cuda.amp.autocast(self.fp16):
|
123 |
+
x = self.layers(x)
|
124 |
+
x = self.conv_sep(x.float() if self.fp16 else x)
|
125 |
+
x = self.features(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
def get_mbf(fp16, num_features):
|
130 |
+
return MobileFaceNet(fp16, num_features)
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.loss = "arcface"
|
7 |
+
config.network = "r50"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 1.0
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 128
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 300 * 10000
|
20 |
+
config.num_epoch = 30
|
21 |
+
config.warmup_epoch = -1
|
22 |
+
config.decay_epoch = [10, 16, 22]
|
23 |
+
config.val_targets = []
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# configs for test speed
|
4 |
+
|
5 |
+
config = edict()
|
6 |
+
config.loss = "arcface"
|
7 |
+
config.network = "r50"
|
8 |
+
config.resume = False
|
9 |
+
config.output = None
|
10 |
+
config.embedding_size = 512
|
11 |
+
config.sample_rate = 0.1
|
12 |
+
config.fp16 = True
|
13 |
+
config.momentum = 0.9
|
14 |
+
config.weight_decay = 5e-4
|
15 |
+
config.batch_size = 128
|
16 |
+
config.lr = 0.1 # batch size is 512
|
17 |
+
|
18 |
+
config.rec = "synthetic"
|
19 |
+
config.num_classes = 300 * 10000
|
20 |
+
config.num_epoch = 30
|
21 |
+
config.warmup_epoch = -1
|
22 |
+
config.decay_epoch = [10, 16, 22]
|
23 |
+
config.val_targets = []
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py
ADDED
File without changes
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "arcface"
|
9 |
+
config.network = "r50"
|
10 |
+
config.resume = False
|
11 |
+
config.output = "ms1mv3_arcface_r50"
|
12 |
+
|
13 |
+
config.dataset = "ms1m-retinaface-t1"
|
14 |
+
config.embedding_size = 512
|
15 |
+
config.sample_rate = 1
|
16 |
+
config.fp16 = False
|
17 |
+
config.momentum = 0.9
|
18 |
+
config.weight_decay = 5e-4
|
19 |
+
config.batch_size = 128
|
20 |
+
config.lr = 0.1 # batch size is 512
|
21 |
+
|
22 |
+
if config.dataset == "emore":
|
23 |
+
config.rec = "/train_tmp/faces_emore"
|
24 |
+
config.num_classes = 85742
|
25 |
+
config.num_image = 5822653
|
26 |
+
config.num_epoch = 16
|
27 |
+
config.warmup_epoch = -1
|
28 |
+
config.decay_epoch = [8, 14, ]
|
29 |
+
config.val_targets = ["lfw", ]
|
30 |
+
|
31 |
+
elif config.dataset == "ms1m-retinaface-t1":
|
32 |
+
config.rec = "/train_tmp/ms1m-retinaface-t1"
|
33 |
+
config.num_classes = 93431
|
34 |
+
config.num_image = 5179510
|
35 |
+
config.num_epoch = 25
|
36 |
+
config.warmup_epoch = -1
|
37 |
+
config.decay_epoch = [11, 17, 22]
|
38 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
39 |
+
|
40 |
+
elif config.dataset == "glint360k":
|
41 |
+
config.rec = "/train_tmp/glint360k"
|
42 |
+
config.num_classes = 360232
|
43 |
+
config.num_image = 17091657
|
44 |
+
config.num_epoch = 20
|
45 |
+
config.warmup_epoch = -1
|
46 |
+
config.decay_epoch = [8, 12, 15, 18]
|
47 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
48 |
+
|
49 |
+
elif config.dataset == "webface":
|
50 |
+
config.rec = "/train_tmp/faces_webface_112x112"
|
51 |
+
config.num_classes = 10572
|
52 |
+
config.num_image = "forget"
|
53 |
+
config.num_epoch = 34
|
54 |
+
config.warmup_epoch = -1
|
55 |
+
config.decay_epoch = [20, 28, 32]
|
56 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "mbf"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 0.1
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 2e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|
chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict as edict
|
2 |
+
|
3 |
+
# make training faster
|
4 |
+
# our RAM is 256G
|
5 |
+
# mount -t tmpfs -o size=140G tmpfs /train_tmp
|
6 |
+
|
7 |
+
config = edict()
|
8 |
+
config.loss = "cosface"
|
9 |
+
config.network = "r100"
|
10 |
+
config.resume = False
|
11 |
+
config.output = None
|
12 |
+
config.embedding_size = 512
|
13 |
+
config.sample_rate = 1.0
|
14 |
+
config.fp16 = True
|
15 |
+
config.momentum = 0.9
|
16 |
+
config.weight_decay = 5e-4
|
17 |
+
config.batch_size = 128
|
18 |
+
config.lr = 0.1 # batch size is 512
|
19 |
+
|
20 |
+
config.rec = "/train_tmp/glint360k"
|
21 |
+
config.num_classes = 360232
|
22 |
+
config.num_image = 17091657
|
23 |
+
config.num_epoch = 20
|
24 |
+
config.warmup_epoch = -1
|
25 |
+
config.decay_epoch = [8, 12, 15, 18]
|
26 |
+
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
|