import sys
from pathlib import Path
from typing import List, Optional

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from transformers import CLIPTokenizer

sys.path.append(".")
sys.path.append("..")

from src import constants
from src.checkpoint_handler import CheckpointHandler
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.models.xti_attention_processor import XTIAttenProc
from src.prompt_manager import PromptManager
from src.scripts.inference import run_inference

DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization
<p style="text-align: center;">
This is a demo for our <a href="https://arxiv.org/abs/2305.15391">paper</a>: ''A Neural Space-Time Representation
for Text-to-Image Personalization''.
<br>
The project page and code are available <a href="https://neuraltextualinversion.github.io/NeTI/">here</a>.
<br>
We introduce a new text-conditioning latent space P* that depends on both the denoising process timestep and
the U-Net layers.
This space-time representation is learned implicitly via a small mapping network.
<br>
Here, you can generate images using one of the concepts trained in our paper. Simply select your concept and a
random seed.
<br>
You can also choose different truncation values to explore the trade-off between reconstruction and editability
of the concept.
</p>
'''

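# Maps each concept's folder name under ./trained_models to the placeholder token used in its
# prompts; the "*" in the user's prompt is replaced with this token before inference.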
CONCEPT_TO_PLACEHOLDER = {
    'barn': '<barn>',
    'cat': '<cat>',
    'clock': '<clock>',
    'colorful_teapot': '<colorful-teapot>',
    'dangling_child': '<dangling-child>',
    'dog': '<dog>',
    'elephant': '<elephant>',
    'fat_stone_bird': '<stone-bird>',
    'headless_statue': '<headless-statue>',
    'lecun': '<lecun>',
    'maeve': '<maeve-dog>',
    'metal_bird': '<metal-bird>',
    'mugs_skulls': '<mug-skulls>',
    'rainbow_cat': '<rainbow-cat>',
    'red_bowl': '<red-bowl>',
    'teddybear': '<teddybear>',
    'tortoise_plushy': '<tortoise-plushy>',
    'wooden_pot': '<wooden-pot>'
}

MODELS_PATH = Path('./trained_models')
MODELS_PATH.mkdir(parents=True, exist_ok=True)

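# Builds the Stable Diffusion pipeline used by the demo: the standard CLIP text encoder is
# replaced with NeTICLIPTextModel and the U-Net attention processors are swapped for
# XTIAttenProc, so the concept embedding can vary across denoising timesteps and U-Net layers.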
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline

def get_possible_concepts() -> List[str]:
    objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()]
    return [x.name for x in objects]

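# Downloads every trained concept from the HF Hub and loads its NeTI mapper and learned
# placeholder embedding alongside a single shared diffusion pipeline.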
def load_sd_and_all_tokens():
    mappers = {}
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")
    print("Downloading all models from HF Hub...")
    snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models')
    print("Done.")
    concepts = get_possible_concepts()
    for concept in concepts:
        print(f"Loading model for concept: {concept}")
        learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin"
        mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt"
        train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path)
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=learned_embeds_path,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline

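# Load everything eagerly at startup so the first generation request does not pay the
# download and model-loading cost.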
mappers, pipeline = load_sd_and_all_tokens()

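# Handles a single generation request: installs the selected concept's mapper in the text
# encoder, builds the per-timestep / per-layer prompt embeddings via PromptManager, and runs
# inference with the requested seed (optionally applying inference-time dropout / truncation).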
def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
    pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"])
    placeholder_token = mappers[concept_name]["placeholder_token"]
    placeholder_token_id = mappers[concept_name]["placeholder_token_id"]
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=torch.float16)
    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=truncation_idx if use_truncation else None)
    return [image]

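# For reference, main_pipeline can also be invoked directly (a minimal sketch, assuming a CUDA
# machine and that the concept checkpoints above have finished downloading); the values mirror
# the first Gradio example defined below, and the output filename is just illustrative:
#
#     images = main_pipeline(concept_name="maeve",
#                            prompt_input="A photo of * swimming in the ocean",
#                            seed=5196,
#                            use_truncation=True,
#                            truncation_idx=16)
#     images[0].save("maeve_swimming.png")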
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''<a href="https://huggingface.co/spaces/neural-ti/NeTI?duplicate=true"><img src="https://bit.ly/3gLdBN6"
    alt="Duplicate Space"></a>''')

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(get_possible_concepts(), multiselect=False, label="Concept",
                                  info="Choose your concept")
            prompt = gr.Textbox(label="Input prompt", info="Input prompt with a placeholder for the concept. "
                                                           "Please use * to specify the concept.")
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            use_truncation = gr.Checkbox(label="Use inference-time dropout",
                                         info="Whether to use our dropout technique when computing the concept "
                                              "embeddings.")
            truncation_idx = gr.Slider(8, 128, label="Truncation index",
                                       info="If using truncation, the index to truncate from. Lower indices tend to "
                                            "give more editable images, at the cost of reconstruction quality.")
            run_button = gr.Button('Generate')

        with gr.Column():
            result = gr.Gallery(label='Result')
            inputs = [concept, prompt, random_seed, use_truncation, truncation_idx]
            outputs = [result]
            run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
        examples = [
            ["maeve", "A photo of * swimming in the ocean", 5196, True, 16],
            ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8],
            ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32],
            ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8],
            ["metal_bird", "* in a comic book", 1028, True, 24],
["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True, |
|
64], |
|
] |
|
        gr.Examples(examples=examples,
                    inputs=[concept, prompt, random_seed, use_truncation, truncation_idx],
                    outputs=[result],
                    fn=main_pipeline,
                    cache_examples=True)

demo.queue(max_size=50).launch(share=False)