Migrated from GitHub
- 000000000285.jpg +0 -0
- 000000000724.jpg +0 -0
- 000000007991.jpg +0 -0
- 000000018837.jpg +0 -0
- 000000122962.jpg +0 -0
- 000000295478.jpg +0 -0
- ORIGINAL_README.md +128 -0
- eval_controlnet.py +148 -0
- eval_controlnet.sh +19 -0
- eval_controlnet_sdxl_light.py +284 -0
- eval_controlnet_sdxl_light.sh +44 -0
- eval_controlnet_sdxl_light_single.py +390 -0
- eval_controlnet_sdxl_light_single.sh +20 -0
- example/UUColor_results/Hollywood-Sign.jpeg +0 -0
- example/legacy_images/Big-Ben-vintage.jpg +0 -0
- example/legacy_images/Central-Park.jpg +0 -0
- example/legacy_images/Hollywood-Sign.jpg +0 -0
- example/legacy_images/Little-Mermaid.jpg +0 -0
- example/legacy_images/Migrant-Mother.jpg +0 -0
- example/legacy_images/Mount-Everest.jpg +0 -0
- example/legacy_images/Tower-of-Pisa.jpg +0 -0
- example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg +0 -0
- gradio_ui.py +356 -0
- images/000000022935_gray.jpg +0 -0
- images/000000022935_green_shirt_on_right_girl.jpeg +0 -0
- images/000000022935_purple_shirt_on_right_girl.jpeg +0 -0
- images/000000022935_red_shirt_on_right_girl.jpeg +0 -0
- images/000000025560_color.jpg +0 -0
- images/000000025560_gray.jpg +0 -0
- images/000000025560_gt.jpg +0 -0
- images/000000041633_black_car.jpeg +0 -0
- images/000000041633_bright_red_car.jpeg +0 -0
- images/000000041633_dark_blue_car.jpeg +0 -0
- images/000000041633_gray.jpg +0 -0
- images/000000065736_color.jpg +0 -0
- images/000000065736_gray.jpg +0 -0
- images/000000065736_gt.jpg +0 -0
- images/000000091779_color.jpg +0 -0
- images/000000091779_gray.jpg +0 -0
- images/000000091779_gt.jpg +0 -0
- images/000000092177_color.jpg +0 -0
- images/000000092177_gray.jpg +0 -0
- images/000000092177_gt.jpg +0 -0
- images/000000166426_color.jpg +0 -0
- images/000000166426_gray.jpg +0 -0
- images/000000166426_gt.jpg +0 -0
- images/000000286708_gray.jpg +0 -0
- images/000000286708_orange_hat.jpeg +0 -0
- images/000000286708_pink_hat.jpeg +0 -0
- images/000000286708_yellow_hat.jpeg +0 -0
000000000285.jpg
ADDED
000000000724.jpg
ADDED
000000007991.jpg
ADDED
000000018837.jpg
ADDED
000000122962.jpg
ADDED
000000295478.jpg
ADDED
ORIGINAL_README.md
ADDED
@@ -0,0 +1,128 @@
# Text-Guided-Image-Colorization

This project utilizes the power of **Stable Diffusion (SDXL/SDXL-Light)** and the **BLIP (Bootstrapping Language-Image Pre-training)** captioning model to provide an interactive image colorization experience. Users can influence the generated colors of objects within images, making the colorization process more personalized and creative.

## Table of Contents
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [Dataset Usage](#dataset-usage)
- [Training](#training)
- [Evaluation](#evaluation)
- [Results](#results)
- [License](#license)

## Features

- **Interactive Colorization**: Users can specify desired colors for different objects in the image.
- **ControlNet Approach**: Enhanced colorization capabilities through retraining with ControlNet, allowing SDXL to better adapt to the image colorization task.
- **High-Quality Outputs**: Leverage the latest advancements in diffusion models to generate vibrant and realistic colorizations.
- **User-Friendly Interface**: Easy-to-use interface for seamless interaction with the model.

## Installation

To set up the project locally, follow these steps:

1. **Clone the Repository**:

   ```bash
   git clone https://github.com/nick8592/text-guided-image-colorization.git
   cd text-guided-image-colorization
   ```

2. **Install Dependencies**:
   Make sure you have Python 3.7 or higher installed. Then, install the required packages:

   ```bash
   pip install -r requirements.txt
   ```
   Install `torch` and `torchvision` matching your CUDA version:
   ```bash
   pip install torch torchvision --index-url https://download.pytorch.org/whl/cuXXX
   ```
   Replace `XXX` with your CUDA version (e.g., `118` for CUDA 11.8). For more info, see [PyTorch Get Started](https://pytorch.org/get-started/locally/).

3. **Download Pre-trained Models**:

   | Models | Hugging Face (Recommended) | Other |
   |:---:|:---:|:---:|
   |SDXL-Lightning Caption|[link](https://huggingface.co/nickpai/sdxl_light_caption_output)|[link](https://gofile.me/7uE8s/FlEhfpWPw) (2kNJfV)|
   |SDXL-Lightning Custom Caption (Recommended)|[link](https://huggingface.co/nickpai/sdxl_light_custom_caption_output)|[link](https://gofile.me/7uE8s/AKmRq5sLR) (KW7Fpi)|

   Place the checkpoint under the project root with the following layout:

   ```bash
   text-guided-image-colorization/sdxl_light_caption_output
   └── checkpoint-30000
       ├── controlnet
       │   ├── diffusion_pytorch_model.safetensors
       │   └── config.json
       ├── optimizer.bin
       ├── random_states_0.pkl
       ├── scaler.pt
       └── scheduler.bin
   ```
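   If you prefer to script the download, the sketch below pulls the recommended checkpoint from Hugging Face into the layout shown above. This is a minimal sketch using `huggingface_hub`; the target directory name is an assumption and should match whatever path you pass to the evaluation scripts.

   ```python
   # Hedged sketch: fetch an SDXL-Lightning caption checkpoint into the expected folder.
   # Assumes the Hugging Face repo mirrors the "checkpoint-30000/controlnet/..." layout above.
   from huggingface_hub import snapshot_download

   snapshot_download(
       repo_id="nickpai/sdxl_light_caption_output",  # or nickpai/sdxl_light_custom_caption_output
       local_dir="sdxl_light_caption_output",        # assumed target folder under the project root
   )
   ```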

## Quick Start

1. Run the `gradio_ui.py` script:

   ```bash
   python gradio_ui.py
   ```

2. Open the provided URL in your web browser to access the Gradio-based user interface.

3. Upload an image and use the interface to control the colors of specific objects in the image. The model can also colorize images without a specific prompt.

4. The model will generate a colorized version of the image based on your input (or automatically, when no prompt is given). See the [demo video](https://x.com/weichenpai/status/1829513077588631987).
![Gradio UI](images/gradio_ui.png)

## Dataset Usage

You can find more details about the dataset usage in the [Dataset-for-Image-Colorization](https://github.com/nick8592/Dataset-for-Image-Colorization) repository.

## Training

For training, you can use one of the following scripts:

- `train_controlnet.sh`: Trains a model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1)
- `train_controlnet_sdxl.sh`: Trains a model using [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- `train_controlnet_sdxl_light.sh`: Trains a model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)

Although the training code for SDXL is provided, I wasn't able to train that model myself due to a lack of GPU resources, so you may run into errors when trying to train it.

## Evaluation

For evaluation, you can use one of the following scripts:

- `eval_controlnet.sh`: Evaluates the model using [Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2-1) on a folder of images.
- `eval_controlnet_sdxl_light.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) on a folder of images.
- `eval_controlnet_sdxl_light_single.sh`: Evaluates the model using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) on a single image.

## Results
### Prompt-Guided
| Caption | Condition 1 | Condition 2 | Condition 3 |
|:---:|:---:|:---:|:---:|
| ![000000022935_gray.jpg](images/000000022935_gray.jpg) | ![000000022935_green_shirt_on_right_girl.jpeg](images/000000022935_green_shirt_on_right_girl.jpeg) | ![000000022935_purple_shirt_on_right_girl.jpeg](images/000000022935_purple_shirt_on_right_girl.jpeg) | ![000000022935_red_shirt_on_right_girl.jpeg](images/000000022935_red_shirt_on_right_girl.jpeg) |
| a photography of a woman in a soccer uniform kicking a soccer ball | + "green shirt" | + "purple shirt" | + "red shirt" |
| ![000000041633_gray.jpg](images/000000041633_gray.jpg) | ![000000041633_bright_red_car.jpeg](images/000000041633_bright_red_car.jpeg) | ![000000041633_dark_blue_car.jpeg](images/000000041633_dark_blue_car.jpeg) | ![000000041633_black_car.jpeg](images/000000041633_black_car.jpeg) |
| a photography of a photo of a truck | + "bright red car" | + "dark blue car" | + "black car" |
| ![000000286708_gray.jpg](images/000000286708_gray.jpg) | ![000000286708_orange_hat.jpeg](images/000000286708_orange_hat.jpeg) | ![000000286708_pink_hat.jpeg](images/000000286708_pink_hat.jpeg) | ![000000286708_yellow_hat.jpeg](images/000000286708_yellow_hat.jpeg) |
| a photography of a cat wearing a hat on his head | + "orange hat" | + "pink hat" | + "yellow hat" |

### Prompt-Free
Ground truth images are provided solely for reference purposes in the image colorization task.
| Grayscale Image | Colorized Result | Ground Truth |
|:---:|:---:|:---:|
| ![000000025560_gray.jpg](images/000000025560_gray.jpg) | ![000000025560_color.jpg](images/000000025560_color.jpg) | ![000000025560_gt.jpg](images/000000025560_gt.jpg) |
| ![000000065736_gray.jpg](images/000000065736_gray.jpg) | ![000000065736_color.jpg](images/000000065736_color.jpg) | ![000000065736_gt.jpg](images/000000065736_gt.jpg) |
| ![000000091779_gray.jpg](images/000000091779_gray.jpg) | ![000000091779_color.jpg](images/000000091779_color.jpg) | ![000000091779_gt.jpg](images/000000091779_gt.jpg) |
| ![000000092177_gray.jpg](images/000000092177_gray.jpg) | ![000000092177_color.jpg](images/000000092177_color.jpg) | ![000000092177_gt.jpg](images/000000092177_gt.jpg) |
| ![000000166426_gray.jpg](images/000000166426_gray.jpg) | ![000000166426_color.jpg](images/000000166426_color.jpg) | ![000000166426_gt.jpg](images/000000166426_gt.jpg) |

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.
eval_controlnet.py
ADDED
@@ -0,0 +1,148 @@
import os
import time
import torch
import shutil
import argparse
import numpy as np

from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from diffusers.utils import load_image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

# Define the function to parse arguments
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")

    parser.add_argument("--model_dir", type=str, default="sd_v2_caption_free_output/checkpoint-22500",
                        help="Directory of the model checkpoint")
    parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2-base",
                        help="ID of the model (tested with runwayml/stable-diffusion-v1-5 and stabilityai/stable-diffusion-2-base)")
    parser.add_argument("--dataset", type=str, default="nickpai/coco2017-colorization",
                        help="Dataset used")
    parser.add_argument("--revision", type=str, default="caption-free",
                        choices=["main", "caption-free"],
                        help="Revision option (main/caption-free)")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args

def apply_color(image, color_map):
    # Convert input images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Split LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    # Merge LAB channels with color map
    merged_lab = Image.merge('LAB', (l, a_map, b_map))

    # Convert merged LAB image back to RGB color space
    result_rgb = merged_lab.convert('RGB')

    return result_rgb

def main(args):
    generator = torch.manual_seed(0)

    # MODEL_DIR = "sd_v2_caption_free_output/checkpoint-22500"
    # # MODEL_ID="runwayml/stable-diffusion-v1-5"
    # MODEL_ID="stabilityai/stable-diffusion-2-base"
    # DATASET = "nickpai/coco2017-colorization"
    # REVISION = "caption-free" # option: main/caption-free

    # Path to the eval_results folder
    eval_results_folder = os.path.join(args.model_dir, "results")

    # Remove eval_results folder if it exists
    if os.path.exists(eval_results_folder):
        shutil.rmtree(eval_results_folder)

    # Create directory for eval_results
    os.makedirs(eval_results_folder)

    # Create subfolders for compare and colorized images
    compare_folder = os.path.join(eval_results_folder, "compare")
    colorized_folder = os.path.join(eval_results_folder, "colorized")
    os.makedirs(compare_folder)
    os.makedirs(colorized_folder)

    # Load the validation split of the colorization dataset
    val_dataset = load_dataset(args.dataset, split="validation", revision=args.revision)

    controlnet = ControlNetModel.from_pretrained(f"{args.model_dir}/controlnet", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        args.model_id, controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    pipe.safety_checker = None

    # Counter for processed images
    processed_images = 0

    # Record start time
    start_time = time.time()

    # Iterate through the validation dataset
    for example in tqdm(val_dataset, desc="Processing Images"):
        image_path = example["file_name"]

        prompt = []
        for caption in example["captions"]:
            if isinstance(caption, str):
                prompt.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take the first caption if there are multiple
                prompt.append(caption[0])
            else:
                raise ValueError(
                    "Caption column `captions` should contain either strings or lists of strings."
                )

        # Generate image
        ground_truth_image = load_image(image_path).resize((512, 512))
        control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
        image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]

        # Apply color mapping
        image = apply_color(ground_truth_image, image)

        # Concatenate images into a row
        row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
        row_image = Image.fromarray(row_image)

        # Save row image in the compare folder
        compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}")
        row_image.save(compare_output_path)

        # Save colorized image in the colorized folder
        colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}")
        image.save(colorized_output_path)

        # Increment processed images counter
        processed_images += 1

    # Record end time
    end_time = time.time()

    # Calculate total time taken
    total_time = end_time - start_time

    # Calculate FPS
    fps = processed_images / total_time

    print("All images processed.")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"FPS: {fps:.2f}")

# Entry point of the script
if __name__ == "__main__":
    args = parse_args()
    main(args)
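The `apply_color` step above is the same LAB-space trick used by every script in this diff: only the a/b chroma channels are taken from the diffusion output, while the L (lightness) channel is kept from the reference image, which preserves the original structure. A minimal sketch of exercising it on its own (the file names are placeholders, not files from this repo):

```python
# Hedged sketch: reuse apply_color from eval_controlnet.py on two local files.
# "gray.jpg" and "colorized.jpg" are placeholder paths.
from PIL import Image
from eval_controlnet import apply_color

source = Image.open("gray.jpg").convert("RGB").resize((512, 512))          # supplies the L channel
color_map = Image.open("colorized.jpg").convert("RGB").resize((512, 512))  # supplies the a/b channels
apply_color(source, color_map).save("recolored.jpg")
```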
eval_controlnet.sh
ADDED
@@ -0,0 +1,19 @@
# Define default values for parameters

# # sdv2 with BCE loss
# MODEL_DIR="sd_v2_caption_bce_output/checkpoint-22500"
# MODEL_ID="stabilityai/stable-diffusion-2-base"
# DATASET="nickpai/coco2017-colorization"
# REVISION="main"

# sdv2 with KL loss
MODEL_DIR="sd_v2_caption_kl_output/checkpoint-22500"
MODEL_ID="stabilityai/stable-diffusion-2-base"
DATASET="nickpai/coco2017-colorization"
REVISION="main"

accelerate launch eval_controlnet.py \
  --model_dir=$MODEL_DIR \
  --model_id=$MODEL_ID \
  --dataset=$DATASET \
  --revision=$REVISION
eval_controlnet_sdxl_light.py
ADDED
@@ -0,0 +1,284 @@
import os
import time
import torch
import shutil
import argparse
import numpy as np

from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from accelerate import Accelerator
from diffusers.utils import load_image
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Define the function to parse arguments
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")

    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained controlnet model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Path to output results.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nickpai/coco2017-colorization",
        help="Dataset used",
    )
    parser.add_argument(
        "--dataset_revision",
        type=str,
        default="caption-free",
        choices=["main", "caption-free", "custom-caption"],
        help="Revision option (main/caption-free/custom-caption)",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=8,
        help="1-step, 2-step, 4-step, or 8-step distilled models",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default="ByteDance/SDXL-Lightning",
        required=True,
        help="Repository from huggingface.co",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="sdxl_lightning_4step_unet.safetensors",
        required=True,
        help="Available checkpoints from the repository",
    )
    parser.add_argument(
        "--negative_prompt",
        action="store_true",
        help="The prompt or prompts not to guide the image generation",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args

def apply_color(image, color_map):
    # Convert input images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Split LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    # Merge LAB channels with color map
    merged_lab = Image.merge('LAB', (l, a_map, b_map))

    # Convert merged LAB image back to RGB color space
    result_rgb = merged_lab.convert('RGB')

    return result_rgb

def main(args):
    generator = torch.manual_seed(0)

    # Path to the eval_results folder
    eval_results_folder = os.path.join(args.output_dir, "results")

    # Remove eval_results folder if it exists
    if os.path.exists(eval_results_folder):
        shutil.rmtree(eval_results_folder)

    # Create directory for eval_results
    os.makedirs(eval_results_folder)

    # Create subfolders for compare and colorized images
    compare_folder = os.path.join(eval_results_folder, "compare")
    colorized_folder = os.path.join(eval_results_folder, "colorized")
    os.makedirs(compare_folder)
    os.makedirs(colorized_folder)

    # Load the validation split of the colorization dataset
    val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
    )
    unet = UNet2DConditionModel.from_config(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        variant=args.variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt)))

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    # The VAE is kept in float32 to avoid NaN losses.
    if args.pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    # Prepare everything with our `accelerator`.
    pipe, val_dataset = accelerator.prepare(pipe, val_dataset)

    pipe.safety_checker = None

    # Counter for processed images
    processed_images = 0

    # Record start time
    start_time = time.time()

    # Iterate through the validation dataset
    for example in tqdm(val_dataset, desc="Processing Images"):
        image_path = example["file_name"]

        prompt = []
        for caption in example["captions"]:
            if isinstance(caption, str):
                prompt.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take the first caption if there are multiple
                prompt.append(caption[0])
            else:
                raise ValueError(
                    "Caption column `captions` should contain either strings or lists of strings."
                )

        negative_prompt = None
        if args.negative_prompt:
            negative_prompt = [
                "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"
            ]

        # Generate image
        ground_truth_image = load_image(image_path).resize((512, 512))
        control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
        image = pipe(prompt=prompt,
                     negative_prompt=negative_prompt,
                     num_inference_steps=args.num_inference_steps,
                     generator=generator,
                     image=control_image).images[0]

        # Apply color mapping
        image = apply_color(ground_truth_image, image)

        # Concatenate images into a row
        row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
        row_image = Image.fromarray(row_image)

        # Save row image in the compare folder
        compare_output_path = os.path.join(compare_folder, f"{image_path.split('/')[-1]}")
        row_image.save(compare_output_path)

        # Save colorized image in the colorized folder
        colorized_output_path = os.path.join(colorized_folder, f"{image_path.split('/')[-1]}")
        image.save(colorized_output_path)

        # Increment processed images counter
        processed_images += 1

    # Record end time
    end_time = time.time()

    # Calculate total time taken
    total_time = end_time - start_time

    # Calculate FPS
    fps = processed_images / total_time

    print("All images processed.")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"FPS: {fps:.2f}")

# Entry point of the script
if __name__ == "__main__":
    args = parse_args()
    main(args)
eval_controlnet_sdxl_light.sh
ADDED
@@ -0,0 +1,44 @@
# Define default values for parameters

# # sdxl light without negative prompt
# export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
# export REPO="ByteDance/SDXL-Lightning"
# export INFERENCE_STEP=8
# export CKPT="sdxl_lightning_8step_unet.safetensors" # caution: the checkpoint's N-step must match INFERENCE_STEP
# export CONTROLNET_MODEL="sdxl_light_custom_caption_output/checkpoint-12500/controlnet"
# export DATASET="nickpai/coco2017-colorization"
# export DATASET_REVISION="custom-caption"
# export OUTPUT_DIR="sdxl_light_custom_caption_output/checkpoint-12500"

# accelerate launch eval_controlnet_sdxl_light.py \
#   --pretrained_model_name_or_path=$BASE_MODEL \
#   --repo=$REPO \
#   --ckpt=$CKPT \
#   --num_inference_steps=$INFERENCE_STEP \
#   --controlnet_model_name_or_path=$CONTROLNET_MODEL \
#   --dataset=$DATASET \
#   --dataset_revision=$DATASET_REVISION \
#   --mixed_precision="fp16" \
#   --output_dir=$OUTPUT_DIR

# sdxl light with negative prompt
export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
export REPO="ByteDance/SDXL-Lightning"
export INFERENCE_STEP=8
export CKPT="sdxl_lightning_8step_unet.safetensors" # caution: the checkpoint's N-step must match INFERENCE_STEP
export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-22500/controlnet"
export DATASET="nickpai/coco2017-colorization"
export DATASET_REVISION="custom-caption"
export OUTPUT_DIR="sdxl_light_caption_output/checkpoint-22500"

accelerate launch eval_controlnet_sdxl_light.py \
  --pretrained_model_name_or_path=$BASE_MODEL \
  --repo=$REPO \
  --ckpt=$CKPT \
  --num_inference_steps=$INFERENCE_STEP \
  --controlnet_model_name_or_path=$CONTROLNET_MODEL \
  --dataset=$DATASET \
  --dataset_revision=$DATASET_REVISION \
  --mixed_precision="fp16" \
  --output_dir=$OUTPUT_DIR \
  --negative_prompt
eval_controlnet_sdxl_light_single.py
ADDED
@@ -0,0 +1,390 @@
import os
import PIL
import time
import torch
import argparse

from typing import Optional, Union
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Define the function to parse arguments
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
    parser.add_argument(
        "--image_path",
        type=str,
        default="example/legacy_images/Hollywood-Sign.jpg",
        required=True,
        help="Path to the image",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained controlnet model.",
    )
    parser.add_argument(
        "--caption_model_name",
        type=str,
        default="blip-image-captioning-large",
        choices=["blip-image-captioning-large", "blip-image-captioning-base"],
        help="Name of the BLIP captioning model.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=8,
        help="1-step, 2-step, 4-step, or 8-step distilled models",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default="ByteDance/SDXL-Lightning",
        required=True,
        help="Repository from huggingface.co",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="sdxl_lightning_4step_unet.safetensors",
        required=True,
        help="Available checkpoints from the repository",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed",
    )
    parser.add_argument(
        "--positive_prompt",
        type=str,
        help="Text for positive prompt",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
        help="Text for negative prompt",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    return args

def apply_color(image, color_map):
    # Convert input images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Split LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    # Merge LAB channels with color map
    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

    # Convert merged LAB image back to RGB color space
    result_rgb = merged_lab.convert('RGB')

    return result_rgb

def remove_unlikely_words(prompt: str) -> str:
    """
    Removes unlikely words from a prompt.

    Args:
        prompt: The text prompt to be cleaned.

    Returns:
        The cleaned prompt with unlikely words removed.
    """
    unlikely_words = []

    a1_list = [f'{i}s' for i in range(1900, 2000)]
    a2_list = [f'{i}' for i in range(1900, 2000)]
    a3_list = [f'year {i}' for i in range(1900, 2000)]
    a4_list = [f'circa {i}' for i in range(1900, 2000)]
    b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
    b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]

    words_list = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    unlikely_words.extend(a1_list)
    unlikely_words.extend(a2_list)
    unlikely_words.extend(a3_list)
    unlikely_words.extend(a4_list)
    unlikely_words.extend(b1_list)
    unlikely_words.extend(b2_list)
    unlikely_words.extend(b3_list)
    unlikely_words.extend(b4_list)
    unlikely_words.extend(words_list)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt

def blip_image_captioning(image: PIL.Image.Image,
                          model_backbone: str,
                          weight_dtype: type,
                          device: str,
                          conditional: bool) -> str:
    # https://huggingface.co/Salesforce/blip-image-captioning-large
    # https://huggingface.co/Salesforce/blip-image-captioning-base
    if weight_dtype == torch.bfloat16:  # in case the model does not accept the bfloat16 data type
        weight_dtype = torch.float16

    processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
    model = BlipForConditionalGeneration.from_pretrained(
        f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)

    valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
    if model_backbone not in valid_backbones:
        raise ValueError(f"Invalid model backbone '{model_backbone}'. "
                         f"Valid options are: {', '.join(valid_backbones)}")

    if conditional:
        text = "a photography of"
        inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
    else:
        inputs = processor(image, return_tensors="pt").to(device)
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

import matplotlib.pyplot as plt

def display_images(input_image, output_image, ground_truth):
    """
    Displays a grid of input, output, and ground truth images with a caption at the bottom.

    Args:
        input_image: A grayscale image as a NumPy array.
        output_image: A grayscale image (result) as a NumPy array.
        ground_truth: A grayscale image (ground truth) as a NumPy array.
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 8))

    axes[0].imshow(input_image, cmap='gray')
    axes[0].set_title('Input')
    axes[0].axis('off')

    axes[1].imshow(output_image)
    axes[1].set_title('Output')
    axes[1].axis('off')

    axes[2].imshow(ground_truth)
    axes[2].set_title('Ground Truth')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Define a function to process the image with the loaded model
def process_image(image_path: str,
                  controlnet_model_name_or_path: str,
                  caption_model_name: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int,
                  num_inference_steps: int,
                  mixed_precision: str,
                  pretrained_model_name_or_path: str,
                  pretrained_vae_model_name_or_path: Optional[str],
                  revision: Optional[str],
                  variant: Optional[str],
                  repo: str,
                  ckpt: str) -> PIL.Image.Image:
    # Seed
    generator = torch.manual_seed(seed)

    # Accelerator setting
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        pretrained_model_name_or_path
        if pretrained_vae_model_name_or_path is None
        else pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
        revision=revision,
        variant=variant,
    )
    unet = UNet2DConditionModel.from_config(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
        variant=variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    # The VAE is kept in float32 to avoid NaN losses.
    if pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    image = PIL.Image.open(image_path)

    # Prepare everything with our `accelerator`.
    pipe, image = accelerator.prepare(pipe, image)
    pipe.safety_checker = None

    # Convert image into grayscale
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning
    if caption_model_name in ("blip-image-captioning-large", "blip-image-captioning-base"):
        caption = blip_image_captioning(control_image, caption_model_name,
                                        weight_dtype, accelerator.device, conditional=True)
    # elif caption_model_name == "ViT-L-14/openai" or "ViT-H-14/laion2b_s32b_b79k":
    #     caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
    # elif caption_model_name == "vit-gpt2-image-captioning":
    #     caption = vit_gpt2_image_captioning(control_image, accelerator.device)
    caption = remove_unlikely_words(caption)

    print("================================================================")
    print(f"Positive prompt: \n>>> {positive_prompt}")
    print(f"Negative prompt: \n>>> {negative_prompt}")
    print(f"Caption results: \n>>> {caption}")
    print("================================================================")

    # Combine positive prompt and captioning result
    prompt = [positive_prompt + ", " + caption]

    # Image colorization
    image = pipe(prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
                 image=control_image).images[0]

    # Apply color mapping
    result_image = apply_color(control_image, image)
    result_image = result_image.resize(original_size)
    return result_image, caption

def main(args):
    output_image, output_caption = process_image(image_path=args.image_path,
                                                 controlnet_model_name_or_path=args.controlnet_model_name_or_path,
                                                 caption_model_name=args.caption_model_name,
                                                 positive_prompt=args.positive_prompt,
                                                 negative_prompt=args.negative_prompt,
                                                 seed=args.seed,
                                                 num_inference_steps=args.num_inference_steps,
                                                 mixed_precision=args.mixed_precision,
                                                 pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                                                 pretrained_vae_model_name_or_path=args.pretrained_vae_model_name_or_path,
                                                 revision=args.revision,
                                                 variant=args.variant,
                                                 repo=args.repo,
                                                 ckpt=args.ckpt)
    input_image = PIL.Image.open(args.image_path)
    display_images(input_image.convert("L"), output_image, input_image)
    return output_image, output_caption

# Entry point of the script
if __name__ == "__main__":
    args = parse_args()
    main(args)
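As a quick illustration of the caption-cleaning step above, the sketch below runs `remove_unlikely_words` on a BLIP-style caption. The input string is a made-up example, and the import assumes the script's directory is on the Python path.

```python
# Hedged sketch: see what the caption filter strips from a typical BLIP-style caption.
from eval_controlnet_sdxl_light_single import remove_unlikely_words

caption = "a photography of a black and white photo of a dog, grainy photo, circa 1950"
print(remove_unlikely_words(caption))
# The year, "circa", "black and white", and "grainy photo" fragments are removed;
# leftover spaces and commas are not collapsed.
```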
eval_controlnet_sdxl_light_single.sh
ADDED
@@ -0,0 +1,20 @@
# sdxl light for single image
export BASE_MODEL="stabilityai/stable-diffusion-xl-base-1.0"
export REPO="ByteDance/SDXL-Lightning"
export INFERENCE_STEP=8
export CKPT="sdxl_lightning_8step_unet.safetensors" # caution: the checkpoint's N-step must match INFERENCE_STEP
export CONTROLNET_MODEL="sdxl_light_caption_output/checkpoint-30000/controlnet"
export CAPTION_MODEL="blip-image-captioning-large"
export IMAGE_PATH="example/legacy_images/Hollywood-Sign.jpg"
# export POSITIVE_PROMPT="blue shirt"

accelerate launch eval_controlnet_sdxl_light_single.py \
  --pretrained_model_name_or_path=$BASE_MODEL \
  --repo=$REPO \
  --ckpt=$CKPT \
  --num_inference_steps=$INFERENCE_STEP \
  --controlnet_model_name_or_path=$CONTROLNET_MODEL \
  --caption_model_name=$CAPTION_MODEL \
  --mixed_precision="fp16" \
  --image_path=$IMAGE_PATH \
  --positive_prompt="red car"
example/UUColor_results/Hollywood-Sign.jpeg
ADDED
example/legacy_images/Big-Ben-vintage.jpg
ADDED
example/legacy_images/Central-Park.jpg
ADDED
example/legacy_images/Hollywood-Sign.jpg
ADDED
example/legacy_images/Little-Mermaid.jpg
ADDED
example/legacy_images/Migrant-Mother.jpg
ADDED
example/legacy_images/Mount-Everest.jpg
ADDED
example/legacy_images/Tower-of-Pisa.jpg
ADDED
example/legacy_images/Wasatch-Mountains-Summit-County-Utah.jpg
ADDED
gradio_ui.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import PIL
|
2 |
+
import torch
|
3 |
+
import subprocess
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from typing import Optional
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from diffusers import (
|
9 |
+
AutoencoderKL,
|
10 |
+
StableDiffusionXLControlNetPipeline,
|
11 |
+
ControlNetModel,
|
12 |
+
UNet2DConditionModel,
|
13 |
+
)
|
14 |
+
from transformers import (
|
15 |
+
BlipProcessor, BlipForConditionalGeneration,
|
16 |
+
VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
|
17 |
+
)
|
18 |
+
from huggingface_hub import hf_hub_download
|
19 |
+
from safetensors.torch import load_file
|
20 |
+
from clip_interrogator import Interrogator, Config, list_clip_models
|
21 |
+
|
22 |
+
def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
|
23 |
+
# Convert input images to LAB color space
|
24 |
+
image_lab = image.convert('LAB')
|
25 |
+
color_map_lab = color_map.convert('LAB')
|
26 |
+
|
27 |
+
# Split LAB channels
|
28 |
+
l, a , b = image_lab.split()
|
29 |
+
_, a_map, b_map = color_map_lab.split()
|
30 |
+
|
31 |
+
# Merge LAB channels with color map
|
32 |
+
merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
|
33 |
+
|
34 |
+
# Convert merged LAB image back to RGB color space
|
35 |
+
result_rgb = merged_lab.convert('RGB')
|
36 |
+
return result_rgb
|
37 |
+
|
38 |
+
def remove_unlikely_words(prompt: str) -> str:
    """
    Removes unlikely words from a prompt.

    "Unlikely" words are grayscale- and era-related phrases (e.g. "black and white",
    "monochrome", decades such as "1950s") that caption models tend to produce for
    legacy photos and that would bias the colorization toward desaturated results.

    Args:
        prompt: The text prompt to be cleaned.

    Returns:
        The cleaned prompt with unlikely words removed.
    """
    unlikely_words = []

    a1_list = [f'{i}s' for i in range(1900, 2000)]
    a2_list = [f'{i}' for i in range(1900, 2000)]
    a3_list = [f'year {i}' for i in range(1900, 2000)]
    a4_list = [f'circa {i}' for i in range(1900, 2000)]
    b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
    b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]

    words_list = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    unlikely_words.extend(a1_list)
    unlikely_words.extend(a2_list)
    unlikely_words.extend(a3_list)
    unlikely_words.extend(a4_list)
    unlikely_words.extend(b1_list)
    unlikely_words.extend(b2_list)
    unlikely_words.extend(b3_list)
    unlikely_words.extend(b4_list)
    unlikely_words.extend(words_list)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt

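# Caption the (grayscale) input with a BLIP captioning model; the cleaned caption is later
# appended to the user's positive prompt to guide colorization.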
def blip_image_captioning(image: PIL.Image.Image,
                          model_backbone: str,
                          weight_dtype: torch.dtype,
                          device: str,
                          conditional: bool) -> str:
    # https://huggingface.co/Salesforce/blip-image-captioning-large
    # https://huggingface.co/Salesforce/blip-image-captioning-base
    valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
    if model_backbone not in valid_backbones:
        raise ValueError(f"Invalid model backbone '{model_backbone}'. "
                         f"Valid options are: {', '.join(valid_backbones)}")

    if weight_dtype == torch.bfloat16:  # the BLIP checkpoints may not accept bfloat16 inputs
        weight_dtype = torch.float16

    processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
    model = BlipForConditionalGeneration.from_pretrained(
        f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)

    if conditional:
        text = "a photography of"
        inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
    else:
        inputs = processor(image, return_tensors="pt").to(device, weight_dtype)
    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption

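# Alternative captioning backends (ViT-GPT2 and CLIP Interrogator) are kept below for
# reference but are currently disabled; only the BLIP models are exposed in the UI.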
# def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str:
#     # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
#     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
#     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
#     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

#     max_length = 16
#     num_beams = 4
#     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

#     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
#     pixel_values = pixel_values.to(device)

#     output_ids = model.generate(pixel_values, **gen_kwargs)

#     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
#     caption = [pred.strip() for pred in preds]

#     return caption[0]

# def clip_image_captioning(image: PIL.Image.Image,
#                           clip_model_name: str,
#                           device: str) -> str:
#     # validate clip model name
#     models = list_clip_models()
#     if clip_model_name not in models:
#         raise ValueError(f"Could not find CLIP model {clip_model_name}! \
#                          Available models: {models}")
#     config = Config(device=device, clip_model_name=clip_model_name)
#     config.apply_low_vram_defaults()
#     ci = Interrogator(config)
#     caption = ci.interrogate(image)
#     return caption

# Define a function to process the image with the loaded model
def process_image(image_path: str,
                  controlnet_model_name_or_path: str,
                  caption_model_name: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int,
                  num_inference_steps: int,
                  mixed_precision: str,
                  pretrained_model_name_or_path: str,
                  pretrained_vae_model_name_or_path: Optional[str],
                  revision: Optional[str],
                  variant: Optional[str],
                  repo: str,
                  ckpt: str,
                  ) -> tuple[PIL.Image.Image, str]:
    # Seed
    generator = torch.manual_seed(seed)

    # Accelerator Setting
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        pretrained_model_name_or_path
        if pretrained_vae_model_name_or_path is None
        else pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
        revision=revision,
        variant=variant,
    )
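    # SDXL-Lightning: instantiate the UNet from the base SDXL config, then load the
    # distilled N-step weights (e.g. sdxl_lightning_8step_unet.safetensors) from the
    # ByteDance/SDXL-Lightning repository.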
    unet = UNet2DConditionModel.from_config(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
        variant=variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

    # Move vae and unet to device and cast to weight_dtype.
    # The base model's VAE is kept in float32 to avoid NaN losses.
    if pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    image = PIL.Image.open(image_path)

    # Prepare everything with our `accelerator`.
    pipe, image = accelerator.prepare(pipe, image)
    pipe.safety_checker = None

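    # The ControlNet conditioning image is the input converted to grayscale, replicated to
    # three channels, and resized to 512x512; the original resolution is restored at the end.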
    # Convert image into grayscale
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning
    if caption_model_name in ("blip-image-captioning-large", "blip-image-captioning-base"):
        caption = blip_image_captioning(control_image, caption_model_name,
                                        weight_dtype, accelerator.device, conditional=True)
    # elif caption_model_name in ("ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"):
    #     caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
    # elif caption_model_name == "vit-gpt2-image-captioning":
    #     caption = vit_gpt2_image_captioning(control_image, accelerator.device)
    caption = remove_unlikely_words(caption)

    # Combine positive prompt and captioning result
    prompt = [positive_prompt + ", " + caption]

    # Image colorization
    image = pipe(prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
                 image=control_image).images[0]

    # Apply color mapping
    result_image = apply_color(control_image, image)
    result_image = result_image.resize(original_size)
    return result_image, caption

# Define the image gallery based on folder path
def get_image_paths(folder_path):
    import os
    image_paths = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            image_paths.append([os.path.join(folder_path, filename)])
    return image_paths

# Create the Gradio interface
def create_interface():
    controlnet_model_dict = {
        "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
        "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
    }
    images = get_image_paths("example/legacy_images")  # Replace with your folder path

    interface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload image",
                     value="example/legacy_images/Hollywood-Sign.jpg",
                     type='filepath'),
            gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
                        value=controlnet_model_dict["sdxl-light-caption-30000"],
                        label="Select ControlNet Model"),
            gr.Dropdown(choices=["blip-image-captioning-large",
                                 "blip-image-captioning-base"],
                        value="blip-image-captioning-large",
                        label="Select Image Captioning Model"),
            gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
            gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
                       label="Negative Prompt", placeholder="Text for negative prompt"),
        ],
        outputs=[
            gr.Image(label="Colorized image",
                     value="example/UUColor_results/Hollywood-Sign.jpeg",
                     format="jpeg"),
            gr.Textbox(label="Captioning Result", show_copy_button=True)
        ],
        examples=images,
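        # These additional inputs are passed to process_image after the main inputs, in
        # order: seed, num_inference_steps, mixed_precision, pretrained_model_name_or_path,
        # pretrained_vae_model_name_or_path, revision, variant, repo, ckpt.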
        additional_inputs=[
            # gr.Radio(choices=["Original", "Square"], value="Original",
            #          label="Output resolution"),
            # gr.Slider(minimum=128, maximum=512, value=256, step=128,
            #           label="Height & Width",
            #           info='Only takes effect if "Square" output resolution is selected'),
            gr.Slider(0, 1000, 123, label="Seed"),
            gr.Radio(choices=[1, 2, 4, 8],
                     value=8,
                     label="Inference Steps",
                     info="1-step, 2-step, 4-step, or 8-step distilled models"),
            gr.Radio(choices=["no", "fp16", "bf16"],
                     value="fp16",
                     label="Mixed Precision",
                     info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."),
            gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"],
                        value="stabilityai/stable-diffusion-xl-base-1.0",
                        label="Base Model",
                        info="Path to pretrained model or model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="VAE Model",
                        info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Revision",
                        info="Revision of the pretrained model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Variant",
                        info="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16."),
            gr.Dropdown(choices=["ByteDance/SDXL-Lightning"],
                        value="ByteDance/SDXL-Lightning",
                        label="Repository",
                        info="Repository from huggingface.co"),
            gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors",
                                 "sdxl_lightning_2step_unet.safetensors",
                                 "sdxl_lightning_4step_unet.safetensors",
                                 "sdxl_lightning_8step_unet.safetensors"],
                        value="sdxl_lightning_8step_unet.safetensors",
                        label="Checkpoint",
                        info="Available checkpoints from the repository. Caution: the checkpoint's N-step variant must match the selected number of Inference Steps."),
        ],
        title="Text-Guided Image Colorization",
        description="Upload an image and select a model to colorize it."
    )
    return interface

def main():
    # Launch the Gradio interface
    interface = create_interface()
    interface.launch()

if __name__ == "__main__":
    main()
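# To try the interface locally, run this script (gradio_ui.py) with Python and open the
# local URL that Gradio prints, e.g.: python gradio_ui.py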
images/000000022935_gray.jpg
ADDED
images/000000022935_green_shirt_on_right_girl.jpeg
ADDED
images/000000022935_purple_shirt_on_right_girl.jpeg
ADDED
images/000000022935_red_shirt_on_right_girl.jpeg
ADDED
images/000000025560_color.jpg
ADDED
images/000000025560_gray.jpg
ADDED
images/000000025560_gt.jpg
ADDED
images/000000041633_black_car.jpeg
ADDED
images/000000041633_bright_red_car.jpeg
ADDED
images/000000041633_dark_blue_car.jpeg
ADDED
images/000000041633_gray.jpg
ADDED
images/000000065736_color.jpg
ADDED
images/000000065736_gray.jpg
ADDED
images/000000065736_gt.jpg
ADDED
images/000000091779_color.jpg
ADDED
images/000000091779_gray.jpg
ADDED
images/000000091779_gt.jpg
ADDED
images/000000092177_color.jpg
ADDED
images/000000092177_gray.jpg
ADDED
images/000000092177_gt.jpg
ADDED
images/000000166426_color.jpg
ADDED
images/000000166426_gray.jpg
ADDED
images/000000166426_gt.jpg
ADDED
images/000000286708_gray.jpg
ADDED
images/000000286708_orange_hat.jpeg
ADDED
images/000000286708_pink_hat.jpeg
ADDED
images/000000286708_yellow_hat.jpeg
ADDED