import sys
from pathlib import Path
from typing import List, Optional

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from transformers import CLIPTokenizer

# Make the repository root importable before pulling in the src modules.
sys.path.append(".")
sys.path.append("..")

from src import constants
from src.checkpoint_handler import CheckpointHandler
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.models.xti_attention_processor import XTIAttenProc
from src.prompt_manager import PromptManager
from src.scripts.inference import run_inference
DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization
This is a demo for our paper: ''A Neural Space-Time Representation
for Text-to-Image Personalization''.
Project page and code are available here.
We introduce a new text-conditioning latent space P* that is dependent on both the denoising process timestep and
the U-Net layers.
This space-time representation is learned implicitly via a small mapping network.
Here, you can generate images using one of the concepts trained in our paper. Simply select your concept and
random seed.
You can also adjust the truncation value to trade off reconstruction quality against editability of the concept.
'''
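# Maps each concept name to the placeholder token substituted for "*" in the prompt.
# The angle-bracket token names below are an assumption based on the naming of the
# released NeTI checkpoints; adjust them if your checkpoints use different tokens.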
CONCEPT_TO_PLACEHOLDER = {
    'barn': '<barn>',
    'cat': '<cat>',
    'clock': '<clock>',
    'colorful_teapot': '<colorful_teapot>',
    'dangling_child': '<dangling_child>',
    'dog': '<dog>',
    'elephant': '<elephant>',
    'fat_stone_bird': '<fat_stone_bird>',
    'headless_statue': '<headless_statue>',
    'lecun': '<lecun>',
    'maeve': '<maeve>',
    'metal_bird': '<metal_bird>',
    'mugs_skulls': '<mugs_skulls>',
    'rainbow_cat': '<rainbow_cat>',
    'red_bowl': '<red_bowl>',
    'teddybear': '<teddybear>',
    'tortoise_plushy': '<tortoise_plushy>',
    'wooden_pot': '<wooden_pot>'
}
MODELS_PATH = Path('./trained_models')
MODELS_PATH.mkdir(parents=True, exist_ok=True)
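# Local directory holding the per-concept NeTI checkpoints downloaded from the
# Hub in load_sd_and_all_tokens() below.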
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer")
    # NeTI replaces the standard CLIP text encoder so the per-timestep,
    # per-layer mapper can be queried during denoising.
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    # Route cross-attention through the XTI processor so each U-Net layer
    # receives its own conditioning vector.
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline
def get_possible_concepts() -> List[str]:
    objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()]
    return [x.name for x in objects]
def load_sd_and_all_tokens():
    mappers = {}
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")
    print("Downloading all models from HF Hub...")
    snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models')
    print("Done.")
    concepts = get_possible_concepts()
    for concept in concepts:
        print(f"Loading model for concept: {concept}")
        learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin"
        mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt"
        train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path)
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=learned_embeds_path,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline
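# Load the base pipeline and all per-concept mappers once at startup so the
# Gradio callback below only has to swap in the active mapper.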
mappers, pipeline = load_sd_and_all_tokens()
def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
    pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"])
    placeholder_token = mappers[concept_name]["placeholder_token"]
    placeholder_token_id = mappers[concept_name]["placeholder_token_id"]
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=torch.float16)
    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=truncation_idx if use_truncation else None)
    return [image]
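# Gradio UI: concept selector, prompt with "*" placeholder, seed, and the
# optional inference-time dropout (truncation) controls.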
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(get_possible_concepts(), multiselect=False, label="Concept",
                                  info="Choose your concept")
            prompt = gr.Textbox(label="Input prompt", info="Input prompt with placeholder for concept. "
                                                           "Please use * to specify the concept.")
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            use_truncation = gr.Checkbox(label="Use inference-time dropout",
                                         info="Whether to use our dropout technique when computing the concept "
                                              "embeddings.")
            truncation_idx = gr.Slider(8, 128, label="Truncation index",
                                       info="If using truncation, which index to truncate from. Lower numbers tend to "
                                            "result in more editable images, but at the cost of reconstruction.")
            run_button = gr.Button('Generate')
        with gr.Column():
            result = gr.Gallery(label='Result')

    inputs = [concept, prompt, random_seed, use_truncation, truncation_idx]
    outputs = [result]
    run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
    with gr.Row():
        examples = [
            ["maeve", "A photo of * swimming in the ocean", 5196, True, 16],
            ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8],
            ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32],
            ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8],
            ["metal_bird", "* in a comic book", 1028, True, 24],
["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True,
64],
        ]
        gr.Examples(examples=examples,
                    inputs=[concept, prompt, random_seed, use_truncation, truncation_idx],
                    outputs=[result],
                    fn=main_pipeline,
                    cache_examples=True)

demo.queue(max_size=50).launch(share=False)