import spaces
import PIL.Image
import torch
import subprocess
import gradio as gr
import os
from typing import Optional
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
)
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from clip_interrogator import Interrogator, Config, list_clip_models
# Download colorization models
os.makedirs("sdxl_light_caption_output", exist_ok=True)
os.makedirs("sdxl_light_custom_caption_output", exist_ok=True)
snapshot_download(
    repo_id='nickpai/sdxl_light_caption_output',
    local_dir='sdxl_light_caption_output',
)
snapshot_download(
    repo_id='nickpai/sdxl_light_custom_caption_output',
    local_dir='sdxl_light_custom_caption_output',
)
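# Each snapshot contains the fine-tuned ControlNet weights under
# checkpoint-30000/controlnet; these paths are exposed as the selectable
# ControlNet models in the Gradio UI below.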
def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
    # Convert input images to LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')

    # Split LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()

    # Merge the original L (lightness) channel with the a/b (color) channels of the color map
    merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))

    # Convert merged LAB image back to RGB color space
    result_rgb = merged_lab.convert('RGB')
    return result_rgb
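# apply_color keeps the luminance of the (grayscale) input and takes only the
# chroma (a/b) channels from the pipeline's colorized output. Usage sketch
# (hypothetical file names):
#   colorized = apply_color(PIL.Image.open('gray.jpg'), PIL.Image.open('sdxl_out.jpg'))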
def remove_unlikely_words(prompt: str) -> str:
    """
    Removes unlikely words from a prompt.

    Args:
        prompt: The text prompt to be cleaned.

    Returns:
        The cleaned prompt with unlikely words removed.
    """
    unlikely_words = []

    # Year mentions such as "1950s", "1950", "year 1950", "circa 1950"
    a1_list = [f'{i}s' for i in range(1900, 2000)]
    a2_list = [f'{i}' for i in range(1900, 2000)]
    a3_list = [f'year {i}' for i in range(1900, 2000)]
    a4_list = [f'circa {i}' for i in range(1900, 2000)]
    # The same years with the digits separated by spaces, e.g. "1 9 5 0 s"
    b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
    b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
    b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]

    # Phrases describing monochrome, grainy, or historical photographs
    words_list = [
        "black and white,", "black and white", "black & white,", "black & white", "circa",
        "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
        "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
        "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
        "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
        "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
        "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
        "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
        "black-and-white photo,", "black-and-white photo", "black - and - white photography",
        "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
        "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
        "black - and - white photograph,", "black - and - white photograph", "black on white,",
        "black on white", "black-and-white", "historical image,", "historical picture,",
        "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
        "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
        "historical photo", "historical setting,",
        "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
        "taken in", "shot on leica", "shot on leica sl2", "sl2",
        "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
        "overcast day", "overcast weather", "slight overcast", "overcast",
        "picture taken in", "photo taken in",
        ", photo", ", photo", ", photo", ", photo", ", photograph",
        ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
    ]

    unlikely_words.extend(a1_list)
    unlikely_words.extend(a2_list)
    unlikely_words.extend(a3_list)
    unlikely_words.extend(a4_list)
    unlikely_words.extend(b1_list)
    unlikely_words.extend(b2_list)
    unlikely_words.extend(b3_list)
    unlikely_words.extend(b4_list)
    unlikely_words.extend(words_list)

    for word in unlikely_words:
        prompt = prompt.replace(word, "")
    return prompt
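# Removal is plain substring replacement, so leftover punctuation/whitespace is
# not collapsed, e.g.:
#   remove_unlikely_words("black and white, grainy photo, a city street")
#   -> "  a city street"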
def blip_image_captioning(image: PIL.Image.Image,
                          model_backbone: str,
                          weight_dtype: torch.dtype,
                          device: str,
                          conditional: bool) -> str:
    # https://huggingface.co/Salesforce/blip-image-captioning-large
    # https://huggingface.co/Salesforce/blip-image-captioning-base
    valid_backbones = ["blip-image-captioning-large", "blip-image-captioning-base"]
    if model_backbone not in valid_backbones:
        raise ValueError(f"Invalid model backbone '{model_backbone}'. "
                         f"Valid options are: {', '.join(valid_backbones)}")

    if weight_dtype == torch.bfloat16:  # the model might not accept bfloat16; fall back to float16
        weight_dtype = torch.float16

    processor = BlipProcessor.from_pretrained(f"Salesforce/{model_backbone}")
    model = BlipForConditionalGeneration.from_pretrained(
        f"Salesforce/{model_backbone}", torch_dtype=weight_dtype).to(device)

    if conditional:
        text = "a photography of"
        inputs = processor(image, text, return_tensors="pt").to(device, weight_dtype)
    else:
        inputs = processor(image, return_tensors="pt").to(device, weight_dtype)

    out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)
    return caption
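# Example call (mirrors how process_image uses it; the device/dtype values here
# are assumptions for illustration):
#   caption = blip_image_captioning(control_image, "blip-image-captioning-large",
#                                   torch.float16, "cuda", conditional=True)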
# def vit_gpt2_image_captioning(image: PIL.Image.Image, device: str) -> str:
#     # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
#     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
#     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
#     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
#     max_length = 16
#     num_beams = 4
#     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
#     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
#     pixel_values = pixel_values.to(device)
#     output_ids = model.generate(pixel_values, **gen_kwargs)
#     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
#     caption = [pred.strip() for pred in preds]
#     return caption[0]

# def clip_image_captioning(image: PIL.Image.Image,
#                           clip_model_name: str,
#                           device: str) -> str:
#     # validate clip model name
#     models = list_clip_models()
#     if clip_model_name not in models:
#         raise ValueError(f"Could not find CLIP model {clip_model_name}! "
#                          f"Available models: {models}")
#     config = Config(device=device, clip_model_name=clip_model_name)
#     config.apply_low_vram_defaults()
#     ci = Interrogator(config)
#     caption = ci.interrogate(image)
#     return caption
# Define a function to process the image with the loaded model
@spaces.GPU  # assumed ZeroGPU setup: the Space runs on Zero and imports `spaces`
def process_image(image_path: str,
                  controlnet_model_name_or_path: str,
                  caption_model_name: str,
                  positive_prompt: Optional[str],
                  negative_prompt: Optional[str],
                  seed: int,
                  num_inference_steps: int,
                  mixed_precision: str,
                  pretrained_model_name_or_path: str,
                  pretrained_vae_model_name_or_path: Optional[str],
                  revision: Optional[str],
                  variant: Optional[str],
                  repo: str,
                  ckpt: str) -> tuple[PIL.Image.Image, str]:
    # Seed
    generator = torch.manual_seed(seed)

    # Accelerator setting
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
    )
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        pretrained_model_name_or_path
        if pretrained_vae_model_name_or_path is None
        else pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if pretrained_vae_model_name_or_path is None else None,
        revision=revision,
        variant=variant,
    )
    unet = UNet2DConditionModel.from_config(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
        variant=variant,
    )
    # Load the SDXL-Lightning distilled UNet weights
    unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    # The VAE is in float32 to avoid NaN losses.
    if pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path, torch_dtype=weight_dtype)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    image = PIL.Image.open(image_path)

    # Prepare everything with our `accelerator`.
    pipe, image = accelerator.prepare(pipe, image)
    pipe.safety_checker = None

    # Convert image into grayscale
    original_size = image.size
    control_image = image.convert("L").convert("RGB").resize((512, 512))

    # Image captioning
    if caption_model_name in ("blip-image-captioning-large", "blip-image-captioning-base"):
        caption = blip_image_captioning(control_image, caption_model_name,
                                        weight_dtype, accelerator.device, conditional=True)
    # elif caption_model_name in ("ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"):
    #     caption = clip_image_captioning(control_image, caption_model_name, accelerator.device)
    # elif caption_model_name == "vit-gpt2-image-captioning":
    #     caption = vit_gpt2_image_captioning(control_image, accelerator.device)
    caption = remove_unlikely_words(caption)

    # Combine positive prompt and captioning result
    prompt = [positive_prompt + ", " + caption] if positive_prompt else [caption]

    # Image colorization
    image = pipe(prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_inference_steps,
                 generator=generator,
                 image=control_image).images[0]

    # Apply color mapping: keep the grayscale luminance, take chroma from the colorized output
    result_image = apply_color(control_image, image)
    result_image = result_image.resize(original_size)
    return result_image, caption
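# Standalone usage sketch (outside Gradio). The arguments are illustrative and
# taken from the UI defaults below:
#   result_image, caption = process_image(
#       "example/legacy_images/Hollywood-Sign.jpg",
#       "sdxl_light_caption_output/checkpoint-30000/controlnet",
#       "blip-image-captioning-large",
#       positive_prompt="", negative_prompt="low quality, bad quality",
#       seed=123, num_inference_steps=8, mixed_precision="fp16",
#       pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
#       pretrained_vae_model_name_or_path=None, revision=None, variant=None,
#       repo="ByteDance/SDXL-Lightning", ckpt="sdxl_lightning_8step_unet.safetensors")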
# Define the image gallery based on folder path
def get_image_paths(folder_path):
    image_paths = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            image_paths.append([os.path.join(folder_path, filename)])
    return image_paths
# Create the Gradio interface
def create_interface():
    controlnet_model_dict = {
        "sdxl-light-caption-30000": "sdxl_light_caption_output/checkpoint-30000/controlnet",
        "sdxl-light-custom-caption-30000": "sdxl_light_custom_caption_output/checkpoint-30000/controlnet",
    }
    images = get_image_paths("example/legacy_images")  # Replace with your folder path

    interface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Upload image",
                     value="example/legacy_images/Hollywood-Sign.jpg",
                     type='filepath'),
            gr.Dropdown(choices=[controlnet_model_dict[key] for key in controlnet_model_dict],
                        value=controlnet_model_dict["sdxl-light-caption-30000"],
                        label="Select ControlNet Model"),
            gr.Dropdown(choices=["blip-image-captioning-large",
                                 "blip-image-captioning-base"],
                        value="blip-image-captioning-large",
                        label="Select Image Captioning Model"),
            gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt"),
            gr.Textbox(value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
                       label="Negative Prompt", placeholder="Text for negative prompt"),
        ],
        outputs=[
            gr.Image(label="Colorized image",
                     value="example/UUColor_results/Hollywood-Sign.jpeg",
                     format="jpeg"),
            gr.Textbox(label="Captioning Result", show_copy_button=True)
        ],
        examples=images,
        additional_inputs=[
            # gr.Radio(choices=["Original", "Square"], value="Original",
            #          label="Output resolution"),
            # gr.Slider(minimum=128, maximum=512, value=256, step=128,
            #           label="Height & Width",
            #           info='Only takes effect if the "Square" output resolution is selected'),
            gr.Slider(0, 1000, 123, label="Seed"),
            gr.Radio(choices=[1, 2, 4, 8],
                     value=8,
                     label="Inference Steps",
                     info="1-step, 2-step, 4-step, or 8-step distilled models"),
            gr.Radio(choices=["no", "fp16", "bf16"],
                     value="fp16",
                     label="Mixed Precision",
                     info="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."),
            gr.Dropdown(choices=["stabilityai/stable-diffusion-xl-base-1.0"],
                        value="stabilityai/stable-diffusion-xl-base-1.0",
                        label="Base Model",
                        info="Path to pretrained model or model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="VAE Model",
                        info="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038."),
            # The next two inputs are ordered to match the `revision` and `variant`
            # parameters of process_image.
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Revision",
                        info="Revision of the pretrained model identifier from huggingface.co/models."),
            gr.Dropdown(choices=["None"],
                        value=None,
                        label="Variant",
                        info="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16."),
            gr.Dropdown(choices=["ByteDance/SDXL-Lightning"],
                        value="ByteDance/SDXL-Lightning",
                        label="Repository",
                        info="Repository from huggingface.co"),
            gr.Dropdown(choices=["sdxl_lightning_1step_unet.safetensors",
                                 "sdxl_lightning_2step_unet.safetensors",
                                 "sdxl_lightning_4step_unet.safetensors",
                                 "sdxl_lightning_8step_unet.safetensors"],
                        value="sdxl_lightning_8step_unet.safetensors",
                        label="Checkpoint",
                        info="Available checkpoints from the repository. Caution: the checkpoint's N-step variant must match the number of inference steps."),
        ],
        title="Text-Guided Image Colorization",
        description="Upload an image and select a model to colorize it.",
        cache_examples=False
    )
    return interface
def main():
    # Launch the Gradio interface
    interface = create_interface()
    interface.launch(ssr_mode=False)


if __name__ == "__main__":
    main()