Add demo
- README.md +3 -3
- gradio_app.py +186 -0
- requirements.txt +10 -0
- src/__init__.py +0 -0
- src/checkpoint_handler.py +107 -0
- src/config.py +146 -0
- src/constants.py +83 -0
- src/models/__init__.py +0 -0
- src/models/net_clip_text_embedding.py +60 -0
- src/models/neti_clip_text_encoder.py +160 -0
- src/models/neti_mapper.py +90 -0
- src/models/positional_encoding.py +57 -0
- src/models/xti_attention_processor.py +57 -0
- src/prompt_manager.py +63 -0
- src/scripts/__init__.py +0 -0
- src/scripts/inference.py +170 -0
- src/sd_pipeline_call.py +146 -0
- src/utils/__init__.py +0 -0
- src/utils/types.py +20 -0
- src/utils/vis_utils.py +17 -0
- style.css +3 -0
README.md
CHANGED
@@ -1,11 +1,11 @@
 ---
 title: NeTI
 emoji: π
-colorFrom:
-colorTo:
+colorFrom: indigo
+colorTo: yellow
 sdk: gradio
 sdk_version: 3.32.0
-app_file:
+app_file: gradio_app.py
 pinned: false
 license: mit
 ---
gradio_app.py
ADDED
@@ -0,0 +1,186 @@
import sys
from pathlib import Path
from typing import List, Optional

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from transformers import CLIPTokenizer

from src import constants
from src.checkpoint_handler import CheckpointHandler
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.models.xti_attention_processor import XTIAttenProc
from src.prompt_manager import PromptManager
from src.scripts.inference import run_inference

sys.path.append(".")
sys.path.append("..")

DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization
<p style="text-align: center;">
This is a demo for our <a href="https://arxiv.org/abs/2305.15391">paper</a>: ''A Neural Space-Time Representation
for Text-to-Image Personalization''.
<br>
Project page and code is available <a href="https://neuraltextualinversion.github.io/NeTI/">here</a>.
<br>
We introduce a new text-conditioning latent space P* that is dependent on both the denoising process timestep and
the U-Net layers.
This space-time representation is learned implicitly via a small mapping network.
<br>
Here, you can generate images using one of the concepts trained in our paper. Simply select your concept and
random seed.
<br>
You can also choose different truncation values to play with the reconstruction vs. editability of the concept.
</p>
'''

CONCEPT_TO_PLACEHOLDER = {
    'barn': '<barn>',
    'cat': '<cat>',
    'clock': '<clock>',
    'colorful_teapot': '<colorful-teapot>',
    'dangling_child': '<dangling-child>',
    'dog': '<dog>',
    'elephant': '<elephant>',
    'fat_stone_bird': '<stone-bird>',
    'headless_statue': '<headless-statue>',
    'lecun': '<lecun>',
    'maeve': '<maeve-dog>',
    'metal_bird': '<metal-bird>',
    'mugs_skulls': '<mug-skulls>',
    'rainbow_cat': '<rainbow-cat>',
    'red_bowl': '<red-bowl>',
    'teddybear': '<teddybear>',
    'tortoise_plushy': '<tortoise-plushy>',
    'wooden_pot': '<wooden-pot>'
}

MODELS_PATH = Path('./trained_models')
MODELS_PATH.mkdir(parents=True, exist_ok=True)


def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline


def get_possible_concepts() -> List[str]:
    objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()]
    return [x.name for x in objects]


def load_sd_and_all_tokens():
    mappers = {}
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")
    print("Downloading all models from HF Hub...")
    snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models')
    print("Done.")
    concepts = get_possible_concepts()
    for concept in concepts:
        print(f"Loading model for concept: {concept}")
        learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin"
        mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt"
        train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path)
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=learned_embeds_path,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline


mappers, pipeline = load_sd_and_all_tokens()


def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> Image.Image:
    pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"])
    placeholder_token = mappers[concept_name]["placeholder_token"]
    placeholder_token_id = mappers[concept_name]["placeholder_token_id"]
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=torch.float16)
    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=truncation_idx if use_truncation else None)
    return [image]


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''<a href="https://huggingface.co/spaces/neural-ti/NeTI?duplicate=true"><img src="https://bit.ly/3gLdBN6"
    alt="Duplicate Space"></a>''')

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(get_possible_concepts(), multiselect=False, label="Concept",
                                  info="Choose your concept")
            prompt = gr.Textbox(label="Input prompt", info="Input prompt with placeholder for concept. "
                                                           "Please use * to specify the concept.")
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            use_truncation = gr.Checkbox(label="Use inference-time dropout",
                                         info="Whether to use our dropout technique when computing the concept "
                                              "embeddings.")
            truncation_idx = gr.Slider(8, 128, label="Truncation index",
                                       info="If using truncation, which index to truncate from. Lower numbers tend to "
                                            "result in more editable images, but at the cost of reconstruction.")
            run_button = gr.Button('Generate')

        with gr.Column():
            result = gr.Gallery(label='Result')
            inputs = [concept, prompt, random_seed, use_truncation, truncation_idx]
            outputs = [result]
            run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
        examples = [
            ["maeve", "A photo of * swimming in the ocean", 5196, True, 16],
            ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8],
            ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32],
            ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8],
            ["metal_bird", "* in a comic book", 1028, True, 24],
            ["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True,
             64],
        ]
        gr.Examples(examples=examples,
                    inputs=[concept, prompt, random_seed, use_truncation, truncation_idx],
                    outputs=[result],
                    fn=main_pipeline,
                    cache_examples=True)

demo.queue(max_size=50).launch(share=False)
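For context, the Blocks UI above is a thin wrapper around main_pipeline. A click on Generate effectively runs the following sketch (it assumes the module-level mappers and pipeline globals have been populated by load_sd_and_all_tokens, that the downloaded snapshot contains a "cat" concept folder, and that a CUDA GPU is available):

# Equivalent of pressing "Generate" with the "cat" concept selected:
images = main_pipeline(concept_name="cat",
                       prompt_input="A photo of * on a beach",
                       seed=42,
                       use_truncation=True,
                       truncation_idx=16)
# main_pipeline swaps in the concept's mapper, rebuilds the PromptManager,
# and returns a single-element list of PIL images.
images[0].save("cat_on_beach.png")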
requirements.txt
ADDED
@@ -0,0 +1,10 @@
opencv-python==4.7.0.72
matplotlib
pyrallis==0.3.1
loguru==0.7.0
torch==1.13.1
torchvision==0.14.1
diffusers==0.14.0
transformers==4.27.4
accelerate==0.18.0
gradio
src/__init__.py
ADDED
File without changes
src/checkpoint_handler.py
ADDED
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Tuple

import pyrallis
import torch
from accelerate import Accelerator
from torch import nn
from transformers import CLIPTokenizer

from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.models.neti_mapper import NeTIMapper
from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from src.config import RunConfig


class CheckpointHandler:

    def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
        self.cfg = cfg
        self.placeholder_token_string = placeholder_token_string
        self.placeholder_token_id = placeholder_token_id
        self.save_root = save_root

    def save_model(self, text_encoder: NeTICLIPTextModel,
                   accelerator: Accelerator,
                   embeds_save_name: str,
                   mapper_save_name: str):
        self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
        self.save_mapper(text_encoder, mapper_save_name)

    def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
        """
        Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
        to take the place of our placeholder token.
        """
        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds = learned_embeds.detach().cpu()
        learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
        torch.save(learned_embeds_dict, self.save_root / save_name)

    def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
        """ Save the mapper and config to be used at inference. """
        cfg_ = RunConfig(**self.cfg.__dict__.copy())
        state_dict = {
            "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
            "cfg": pyrallis.encode(cfg_),
            "encoder": text_encoder.text_model.embeddings.mapper.encoder
        }
        torch.save(state_dict, self.save_root / save_name)

    @staticmethod
    def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
        mapper_ckpt = torch.load(mapper_path, map_location="cpu")
        cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
        neti_mapper = NeTIMapper(output_dim=768,
                                 use_nested_dropout=cfg.model.use_nested_dropout,
                                 nested_dropout_prob=cfg.model.nested_dropout_prob,
                                 norm_scale=cfg.model.target_norm,
                                 use_positional_encoding=cfg.model.use_positional_encoding,
                                 num_pe_time_anchors=cfg.model.num_pe_time_anchors,
                                 pe_sigmas=cfg.model.pe_sigmas,
                                 output_bypass=cfg.model.output_bypass)
        neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
        encoder = mapper_ckpt['encoder']
        if isinstance(encoder, NeTIPositionalEncoding):
            encoder.w = nn.Parameter(mapper_ckpt['encoder'].w.cuda())
        elif isinstance(encoder, BasicEncoder):
            encoder.normalized_timesteps = mapper_ckpt['encoder'].normalized_timesteps.cuda()
            encoder.normalized_unet_layers = mapper_ckpt['encoder'].normalized_unet_layers.cuda()
        neti_mapper.encoder = encoder.cuda()
        neti_mapper.cuda()
        neti_mapper.eval()
        return cfg, neti_mapper

    @staticmethod
    def load_learned_embed_in_clip(learned_embeds_path: Path,
                                   text_encoder: NeTICLIPTextModel,
                                   tokenizer: CLIPTokenizer) -> Tuple[str, int]:
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # separate token and the embeds
        trained_tokens = list(loaded_learned_embeds.keys())
        embeds = list(loaded_learned_embeds.values())

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = [e.to(dtype) for e in embeds]

        # add the tokens to the tokenizer
        num_added_tokens = tokenizer.add_tokens(trained_tokens)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
                             f"Please pass a different `token` that is not already in the tokenizer.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]

        for idx, (token, token_id, embed) in enumerate(zip(trained_tokens, placeholder_token_ids, embeds)):
            text_encoder.get_input_embeddings().weight.data[token_id] = embed

        assert len(trained_tokens) == 1, "Only one placeholder token is supported"
        placeholder_token = trained_tokens[0]
        placeholder_token_id = placeholder_token_ids[0]
        return placeholder_token, placeholder_token_id
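For reference, a mapper checkpoint produced by save_mapper is a plain dictionary with three entries, so it can be inspected directly. A minimal sketch (it assumes the neural-ti/NeTI snapshot has been downloaded into trained_models/ and follows the cat/cat-mapper.pt naming used by the demo above):

import torch

ckpt = torch.load("trained_models/cat/cat-mapper.pt", map_location="cpu")
# "state_dict": the NeTIMapper weights
# "cfg":        the training RunConfig, flattened into plain dicts by pyrallis.encode
# "encoder":    the pickled NeTIPositionalEncoding / BasicEncoder module
print(ckpt.keys())  # dict_keys(['state_dict', 'cfg', 'encoder'])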
src/config.py
ADDED
@@ -0,0 +1,146 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Dict

from src.constants import VALIDATION_PROMPTS
from src.utils.types import PESigmas


@dataclass
class LogConfig:
    """ Parameters for logging and saving """
    # Name of experiment. This will be the name of the output folder
    exp_name: str
    # The output directory where the model predictions and checkpoints will be written
    exp_dir: Path = Path("./outputs")
    # Save interval
    save_steps: int = 250
    # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
    # `output_dir/runs/**CURRENT_DATETIME_HOSTNAME`
    logging_dir: Path = Path("logs")
    # The integration to report the results to. Supported platforms are "tensorboard"
    # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
    report_to: str = "tensorboard"
    # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None


@dataclass
class DataConfig:
    """ Parameters for data """
    # A folder containing the training data
    train_data_dir: Path
    # A token to use as a placeholder for the concept
    placeholder_token: str
    # Super category token to use for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
    dataloader_num_workers: int = 8
    # Choose between 'object' and 'style' - used for selecting the prompts for training
    learnable_property: str = "object"
    # How many times to repeat the training data
    repeats: int = 100
    # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
    resolution: int = 512
    # Whether to center crop images before resizing to resolution
    center_crop: bool = False


@dataclass
class ModelConfig:
    """ Parameters for defining all models """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Whether training should be resumed from a previous checkpoint
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        if self.pe_sigmas is not None:
            assert len(self.pe_sigmas) == 2, "Should provide exactly two sigma values: one for time and one for layers!"
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])


@dataclass
class EvalConfig:
    """ Parameters for validation """
    # A list of prompts that will be used during validation to verify that the model is learning
    validation_prompts: List[str] = field(default_factory=lambda: VALIDATION_PROMPTS)
    # Number of images that should be generated during validation with `validation_prompt`
    num_validation_images: int = 4
    # Seeds to use for generating the validation images
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        if self.validation_seeds is None:
            self.validation_seeds = list(range(self.num_validation_images))
        assert len(self.validation_seeds) == self.num_validation_images, \
            "Length of validation_seeds should equal num_validation_images"


@dataclass
class OptimConfig:
    """ Parameters for the optimization process """
    # Total number of training steps to perform
    max_train_steps: Optional[int] = 1_000
    # Learning rate
    learning_rate: float = 1e-3
    # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
    scale_lr: bool = True
    # Batch size (per device) for the training dataloader
    train_batch_size: int = 2
    # Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass
    gradient_checkpointing: bool = False
    # Number of update steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # A seed for reproducible training
    seed: Optional[int] = None
    # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup"]
    lr_scheduler: str = "constant"
    # Number of steps for the warmup in the lr scheduler
    lr_warmup_steps: int = 0
    # The beta1 parameter for the Adam optimizer
    adam_beta1: float = 0.9
    # The beta2 parameter for the Adam optimizer
    adam_beta2: float = 0.999
    # Weight decay to use
    adam_weight_decay: float = 1e-2
    # Epsilon value for the Adam optimizer
    adam_epsilon: float = 1e-08
    # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10
    # and an Nvidia Ampere GPU.
    mixed_precision: str = "no"
    # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False


@dataclass
class RunConfig:
    """ The main configuration for the coach trainer """
    log: LogConfig = field(default_factory=LogConfig)
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
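These dataclasses are consumed through pyrallis: CheckpointHandler.save_mapper stores the training configuration with pyrallis.encode, and load_mapper rebuilds the exact RunConfig with pyrallis.decode. A small round-trip sketch under those same assumptions (the experiment name, data directory, and placeholder token below are illustrative only):

import pyrallis
from pathlib import Path
from src.config import RunConfig, LogConfig, DataConfig

cfg = RunConfig(log=LogConfig(exp_name="cat_run"),
                data=DataConfig(train_data_dir=Path("data/cat"), placeholder_token="<cat>"))
encoded = pyrallis.encode(cfg)              # plain nested dict, safe to store inside a torch checkpoint
restored = pyrallis.decode(RunConfig, encoded)
print(restored.model.pretrained_model_name_or_path)  # "CompVis/stable-diffusion-v1-4"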
src/constants.py
ADDED
@@ -0,0 +1,83 @@
UNET_LAYERS = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID',
               'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']

SD_INFERENCE_TIMESTEPS = [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659,
                          639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300,
                          280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20]

PROMPTS = [
    "A photo of a {}",
    "A photo of {} in the jungle",
    "A photo of {} on a beach",
    "A photo of {} in Times Square",
    "A photo of {} in the moon",
    "A painting of {} in the style of Monet",
    "Oil painting of {}",
    "A Marc Chagall painting of {}",
    "A manga drawing of {}",
    "A watercolor painting of {}",
    "A statue of {}",
    "App icon of {}",
    "A sand sculpture of {}",
    "Colorful graffiti of {}",
    "A photograph of two {} on a table",
]

VALIDATION_PROMPTS = [
    "A photo of a {}",
    "A photo of a {} on a beach",
    "App icon of {}",
    "A painting of {} in the style of Monet",
]

IMAGENET_TEMPLATES_SMALL = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

IMAGENET_STYLE_TEMPLATES_SMALL = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]
src/models/__init__.py
ADDED
File without changes
src/models/net_clip_text_embedding.py
ADDED
@@ -0,0 +1,60 @@
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import CLIPTextConfig

from src.models.neti_mapper import NeTIMapper
from src.utils.types import NeTIBatch


class NeTICLIPTextEmbeddings(nn.Module):
    """ Modification of CLIPTextEmbedding to allow for the use of a NeTIMapper to overwrite the concept token. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size
        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def set_mapper(self, mapper: NeTIMapper):
        self.mapper = mapper

    def forward(self, input_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                batch: Optional[NeTIBatch] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

        if batch is not None:
            input_ids = batch.input_ids

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        ####################################################################
        # NeTI logic - Use mapper to overwrite the learnable token embedding
        ####################################################################
        bypass_outputs = None
        if batch is not None:
            mapper_outputs = self.mapper(timestep=batch.timesteps.float(),
                                         unet_layer=batch.unet_layers.float(),
                                         truncation_idx=batch.truncation_idx)
            mapper_outputs = mapper_outputs.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            if self.mapper.output_bypass:
                bypass_outputs = mapper_outputs[:, mapper_outputs.shape[1] // 2:]
                mapper_outputs = mapper_outputs[:, :mapper_outputs.shape[1] // 2]

            # Overwrite the index of the placeholder token with the mapper output for each entry in the batch
            learnable_idxs = (input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
            inputs_embeds[torch.arange(input_ids.shape[0]), learnable_idxs] = mapper_outputs

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings, bypass_outputs
src/models/neti_clip_text_encoder.py
ADDED
@@ -0,0 +1,160 @@
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder
from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask

from src.models.net_clip_text_embedding import NeTICLIPTextEmbeddings
from src.utils.types import NeTIBatch


class NeTICLIPTextModel(CLIPTextModel):
    """ Modification of CLIPTextModel to use our NeTI mapper for computing the embeddings of the concept. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = NeTICLIPTextTransformer(config)
        self.post_init()

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
        return self.text_model.forward(
            batch=batch,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class NeTICLIPTextTransformer(CLIPTextTransformer):
    """ Modification of CLIPTextTransformer to use our NeTI mapper for computing the embeddings of the concept. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config=config)
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = NeTICLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bypass_output = None

        if input_ids is not None:  # Regular embedding logic
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        ###########################
        # NeTI logic
        ###########################
        elif batch is not None:
            input_shape = batch.input_ids.size()
            batch.input_ids = batch.input_ids.view(-1, input_shape[-1])
            hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids)

        else:
            raise ValueError("You have to specify either batch or input_ids!")

        bsz, seq_len = input_shape
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state_with_bypass = last_hidden_state.clone()

        ###############################################
        # NeTI logic - compute the scaled bypass output
        ###############################################
        if bypass_output is not None:
            learnable_idxs = (batch.input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
            existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs]
            bypass_output = bypass_output / bypass_output.norm(dim=1, keepdim=True) \
                            * existing_state.norm(dim=1, keepdim=True)
            new_state = existing_state + 0.2 * bypass_output
            new_state = new_state.to(dtype=hidden_states.dtype)
            last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state

        last_hidden_state = self.final_layer_norm(last_hidden_state)
        last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        if input_ids is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
        elif batch is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
        else:
            raise ValueError("You have to specify either batch or input_ids!")

        if bypass_output is not None:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state_with_bypass,
                pooler_output=pooled_output_with_bypass,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        else:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), None
src/models/neti_mapper.py
ADDED
@@ -0,0 +1,90 @@
import random
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn

from src.constants import UNET_LAYERS
from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
from src.utils.types import PESigmas


class NeTIMapper(nn.Module):
    """ Main logic of our NeTI mapper. """

    def __init__(self, output_dim: int = 768,
                 unet_layers: List[str] = UNET_LAYERS,
                 use_nested_dropout: bool = True,
                 nested_dropout_prob: float = 0.5,
                 norm_scale: Optional[torch.Tensor] = None,
                 use_positional_encoding: bool = True,
                 num_pe_time_anchors: int = 10,
                 pe_sigmas: PESigmas = PESigmas(sigma_t=0.03, sigma_l=2.0),
                 output_bypass: bool = True):
        super().__init__()
        self.use_nested_dropout = use_nested_dropout
        self.nested_dropout_prob = nested_dropout_prob
        self.norm_scale = norm_scale
        self.output_bypass = output_bypass
        if self.output_bypass:
            output_dim *= 2  # Output two vectors

        self.use_positional_encoding = use_positional_encoding
        if self.use_positional_encoding:
            self.encoder = NeTIPositionalEncoding(sigma_t=pe_sigmas.sigma_t, sigma_l=pe_sigmas.sigma_l).cuda()
            self.input_dim = num_pe_time_anchors * len(unet_layers)
        else:
            self.encoder = BasicEncoder().cuda()
            self.input_dim = 2

        self.set_net(num_unet_layers=len(unet_layers),
                     num_time_anchors=num_pe_time_anchors,
                     output_dim=output_dim)

    def set_net(self, num_unet_layers: int, num_time_anchors: int, output_dim: int = 768):
        self.input_layer = self.set_input_layer(num_unet_layers, num_time_anchors)
        self.net = nn.Sequential(self.input_layer,
                                 nn.Linear(self.input_dim, 128), nn.LayerNorm(128), nn.LeakyReLU(),
                                 nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU())
        self.output_layer = nn.Sequential(nn.Linear(128, output_dim))

    def set_input_layer(self, num_unet_layers: int, num_time_anchors: int) -> nn.Module:
        if self.use_positional_encoding:
            input_layer = nn.Linear(self.encoder.num_w * 2, self.input_dim)
            input_layer.weight.data = self.encoder.init_layer(num_time_anchors, num_unet_layers)
        else:
            input_layer = nn.Identity()
        return input_layer

    def forward(self, timestep: torch.Tensor, unet_layer: torch.Tensor, truncation_idx: int = None) -> torch.Tensor:
        embedding = self.extract_hidden_representation(timestep, unet_layer)
        if self.use_nested_dropout:
            embedding = self.apply_nested_dropout(embedding, truncation_idx=truncation_idx)
        embedding = self.get_output(embedding)
        return embedding

    def get_encoded_input(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        return self.encoder.encode(timestep, unet_layer)

    def extract_hidden_representation(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        encoded_input = self.get_encoded_input(timestep, unet_layer)
        embedding = self.net(encoded_input)
        return embedding

    def apply_nested_dropout(self, embedding: torch.Tensor, truncation_idx: int = None) -> torch.Tensor:
        if self.training:
            if random.random() < self.nested_dropout_prob:
                dropout_idxs = torch.randint(low=0, high=embedding.shape[1], size=(embedding.shape[0],))
                for idx in torch.arange(embedding.shape[0]):
                    embedding[idx][dropout_idxs[idx]:] = 0
        if not self.training and truncation_idx is not None:
            for idx in torch.arange(embedding.shape[0]):
                embedding[idx][truncation_idx:] = 0
        return embedding

    def get_output(self, embedding: torch.Tensor) -> torch.Tensor:
        embedding = self.output_layer(embedding)
        if self.norm_scale is not None:
            embedding = F.normalize(embedding, dim=-1) * self.norm_scale
        return embedding
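The mapper turns a (timestep, U-Net layer) pair into a token embedding. A small sketch of how it is queried at inference, mirroring what NeTICLIPTextEmbeddings.forward does above (a freshly initialized, untrained mapper is used here only to show the interface and shapes; a CUDA device is required because the encoder is moved to the GPU on construction):

import torch
from src.models.neti_mapper import NeTIMapper

mapper = NeTIMapper(output_dim=768, output_bypass=True).cuda().eval()
timestep = torch.tensor([999.0]).cuda()    # denoising timestep t
unet_layer = torch.tensor([0.0]).cuda()    # index into UNET_LAYERS ('IN01')
with torch.no_grad():
    out = mapper(timestep=timestep, unet_layer=unet_layer, truncation_idx=16)
# With output_bypass=True the output dimension is doubled to 1536:
# the first 768 entries replace the word embedding, the last 768 are the textual bypass.
print(out.shape)  # torch.Size([1, 1536])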
src/models/positional_encoding.py
ADDED
@@ -0,0 +1,57 @@
from typing import Union

import torch
from torch import nn


class NeTIPositionalEncoding(nn.Module):

    def __init__(self, sigma_t: float, sigma_l: float, num_w: int = 1024):
        super().__init__()
        self.sigma_t = sigma_t
        self.sigma_l = sigma_l
        self.num_w = num_w
        self.w = torch.randn((num_w, 2))
        self.w[:, 0] *= sigma_t
        self.w[:, 1] *= sigma_l
        self.w = nn.Parameter(self.w).cuda()

    def encode(self, t: Union[int, torch.Tensor], l: Union[int, torch.Tensor]):
        """ Maps the given time and layer input into a 2048-dimensional vector. """
        if type(t) == int or t.ndim == 0:
            x = torch.tensor([t, l]).float()
        else:
            x = torch.stack([t, l], dim=1).T
        x = x.cuda()
        v = torch.cat([torch.sin(self.w.detach() @ x), torch.cos(self.w.detach() @ x)])
        if type(t) == int:
            v_norm = v / v.norm()
        else:
            v_norm = v / v.norm(dim=0)
            v_norm = v_norm.T
        return v_norm

    def init_layer(self, num_time_anchors: int, num_layers: int) -> torch.Tensor:
        """ Computes the weights for the positional encoding layer of size 160x2048. """
        anchor_vectors = []
        for t_anchor in range(0, 1000, 1000 // num_time_anchors):
            for l_anchor in range(0, num_layers):
                anchor_vectors.append(self.encode(t_anchor, l_anchor).float())
        A = torch.stack(anchor_vectors)
        return A


class BasicEncoder(nn.Module):
    """ Simply normalizes the given timestep and unet layer to be between -1 and 1. """

    def __init__(self, num_denoising_timesteps: int = 1000, num_unet_layers: int = 16):
        super().__init__()
        self.normalized_timesteps = (torch.arange(num_denoising_timesteps) / (num_denoising_timesteps - 1)) * 2 - 1
        self.normalized_unet_layers = (torch.arange(num_unet_layers) / (num_unet_layers - 1)) * 2 - 1
        self.normalized_timesteps = nn.Parameter(self.normalized_timesteps).cuda()
        self.normalized_unet_layers = nn.Parameter(self.normalized_unet_layers).cuda()

    def encode(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        normalized_input = torch.stack([self.normalized_timesteps[timestep.long()],
                                        self.normalized_unet_layers[unet_layer.long()]]).T
        return normalized_input
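In other words, the encoding is a random Fourier feature of the pair x = (t, l): with the rows of W drawn from N(0, diag(sigma_t^2, sigma_l^2)), encode returns the unit-normalized concatenation [sin(Wx); cos(Wx)] of length 2 * num_w = 2048, and init_layer evaluates it at 10 time anchors for each of the 16 U-Net layers to seed the 160x2048 input projection of the mapper. A quick sketch of the shapes (CUDA required, since the module moves its weights to the GPU):

import torch
from src.models.positional_encoding import NeTIPositionalEncoding

pe = NeTIPositionalEncoding(sigma_t=0.03, sigma_l=2.0)
v = pe.encode(t=999, l=0)                         # a single (timestep, layer) pair
print(v.shape)                                    # torch.Size([2048]), unit norm
A = pe.init_layer(num_time_anchors=10, num_layers=16)
print(A.shape)                                    # torch.Size([160, 2048])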
src/models/xti_attention_processor.py
ADDED
@@ -0,0 +1,57 @@
from typing import Dict, Optional

import torch
from diffusers.models.cross_attention import CrossAttention


class XTIAttenProc:

    def __call__(self, attn: CrossAttention,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
                 attention_mask: Optional[torch.Tensor] = None):

        _ehs_bypass = None
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, dict):
                this_idx = encoder_hidden_states["this_idx"]
                _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
                if f"CONTEXT_TENSOR_BYPASS_{this_idx}" in encoder_hidden_states:
                    _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS_{this_idx}"]
                encoder_hidden_states["this_idx"] += 1
                encoder_hidden_states["this_idx"] %= 16
            else:
                _ehs = encoder_hidden_states
        else:
            _ehs = None

        batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape)
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if _ehs is None:
            _ehs = hidden_states
        elif attn.cross_attention_norm:
            _ehs = attn.norm_cross(_ehs)
            _ehs_bypass = attn.norm_cross(_ehs_bypass)

        key = attn.to_k(_ehs)
        if _ehs_bypass is not None:
            value = attn.to_v(_ehs_bypass)
        else:
            value = attn.to_v(_ehs)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
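The processor expects encoder_hidden_states to be the per-layer dictionary built by PromptManager.embed_prompt below: a running counter plus one context tensor (and optionally a bypass tensor) per cross-attention layer. A sketch of the expected structure for one denoising step (the shapes assume Stable Diffusion v1.4, i.e. 77 tokens and a 768-dimensional text encoder):

# encoder_hidden_states passed to the U-Net for a single denoising step:
# {
#     "this_idx": 0,                    # incremented (mod 16) as the 16 cross-attention layers are visited
#     "CONTEXT_TENSOR_0": ...,          # conditioning for UNET_LAYERS[0] ('IN01'), shape (B, 77, 768)
#     "CONTEXT_TENSOR_BYPASS_0": ...,   # same layer, with the scaled textual-bypass vector added
#     ...
#     "CONTEXT_TENSOR_15": ...,
#     "CONTEXT_TENSOR_BYPASS_15": ...,
# }
# The plain tensor feeds the key projection; when present, the bypass tensor feeds the value projection.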
src/prompt_manager.py
ADDED
@@ -0,0 +1,63 @@
from typing import Optional, List, Dict, Any

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

from src import constants
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.utils.types import NeTIBatch


class PromptManager:
    """ Class for computing all time and space embeddings for a given prompt. """

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[List] = None,
                 placeholder_token: Optional[List] = None,
                 torch_dtype: torch.dtype = torch.float32):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
        for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.
        """
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        # Compute embeddings for each timestep and each U-Net layer
        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            _hs = {"this_idx": 0}.copy()
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                batch = NeTIBatch(input_ids=ids.to(device=self.text_encoder.device),
                                  timesteps=timestep.unsqueeze(0).to(device=self.text_encoder.device),
                                  unet_layers=torch.tensor(layer_idx, device=self.text_encoder.device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep
src/scripts/__init__.py
ADDED
File without changes
src/scripts/inference.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, List, Tuple, Union

import numpy as np
import pyrallis
import torch
from PIL import Image
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from transformers import CLIPTokenizer

sys.path.append(".")
sys.path.append("..")

from src import constants
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.models.neti_mapper import NeTIMapper
from src.prompt_manager import PromptManager
from src.sd_pipeline_call import sd_pipeline_call
from src.models.xti_attention_processor import XTIAttenProc
from src.checkpoint_handler import CheckpointHandler
from src.utils import vis_utils


@dataclass
class InferenceConfig:
    # Specifies which checkpoint iteration we want to load
    iteration: Optional[int] = None
    # The input directory containing the saved models and embeddings
    input_dir: Optional[Path] = None
    # Where to save the inference results to
    inference_dir: Optional[Path] = None
    # Specific path to the mapper you want to load, overrides `input_dir`
    mapper_checkpoint_path: Optional[Path] = None
    # Specific path to the embeddings you want to load, overrides `input_dir`
    learned_embeds_path: Optional[Path] = None
    # List of prompts to run inference on
    prompts: Optional[List[str]] = None
    # Text file containing prompts to run inference on (one prompt per line), overrides `prompts`
    prompts_file_path: Optional[Path] = None
    # List of random seeds to run on
    seeds: List[int] = field(default_factory=lambda: [42])
    # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout.
    # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices.
    truncation_idxs: Optional[Union[int, List[int]]] = None
    # Whether to run with torch.float16 or torch.float32
    torch_dtype: str = "fp16"

    def __post_init__(self):
        assert bool(self.prompts) != bool(self.prompts_file_path), \
            "You must provide either prompts or prompts_file_path, but not both!"
        self._set_prompts()
        self._set_input_paths()
        self.inference_dir.mkdir(exist_ok=True, parents=True)
        # Normalize truncation_idxs so the inference loop can always iterate over it
        if self.truncation_idxs is None:
            self.truncation_idxs = [None]
        elif isinstance(self.truncation_idxs, int):
            self.truncation_idxs = [self.truncation_idxs]
        self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32

    def _set_input_paths(self):
        if self.inference_dir is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir"
            self.inference_dir = self.input_dir / f"inference_{self.iteration}"
        if self.mapper_checkpoint_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path"
            self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
        if self.learned_embeds_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path"
            self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"

    def _set_prompts(self):
        if self.prompts_file_path is not None:
            assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!"
            self.prompts = self.prompts_file_path.read_text().splitlines()


@pyrallis.wrap()
def main(infer_cfg: InferenceConfig):
    train_cfg, mapper = CheckpointHandler.load_mapper(infer_cfg.mapper_checkpoint_path)
    pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
        pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
        mapper=mapper,
        learned_embeds_path=infer_cfg.learned_embeds_path,
        torch_dtype=infer_cfg.torch_dtype
    )
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=infer_cfg.torch_dtype)
    for prompt in infer_cfg.prompts:
        output_path = infer_cfg.inference_dir / prompt.format(placeholder_token)
        output_path.mkdir(exist_ok=True, parents=True)
        for truncation_idx in infer_cfg.truncation_idxs:
            print(f"Running with truncation index: {truncation_idx}")
            prompt_image = run_inference(prompt=prompt,
                                         pipeline=pipeline,
                                         prompt_manager=prompt_manager,
                                         seeds=infer_cfg.seeds,
                                         output_path=output_path,
                                         num_images_per_prompt=1,
                                         truncation_idx=truncation_idx)
            if truncation_idx is not None:
                save_name = f"{prompt.format(placeholder_token)}_truncation_{truncation_idx}.png"
            else:
                save_name = f"{prompt.format(placeholder_token)}.png"
            prompt_image.save(infer_cfg.inference_dir / save_name)


def run_inference(prompt: str,
                  pipeline: StableDiffusionPipeline,
                  prompt_manager: PromptManager,
                  seeds: List[int],
                  output_path: Optional[Path] = None,
                  num_images_per_prompt: int = 1,
                  truncation_idx: Optional[int] = None) -> Image.Image:
    with torch.autocast("cuda"):
        with torch.no_grad():
            prompt_embeds = prompt_manager.embed_prompt(prompt,
                                                        num_images_per_prompt=num_images_per_prompt,
                                                        truncation_idx=truncation_idx)
    joined_images = []
    for seed in seeds:
        generator = torch.Generator(device='cuda').manual_seed(seed)
        images = sd_pipeline_call(pipeline,
                                  prompt_embeds=prompt_embeds,
                                  generator=generator,
                                  num_images_per_prompt=num_images_per_prompt).images
        seed_image = Image.fromarray(np.concatenate(images, axis=1)).convert("RGB")
        if output_path is not None:
            save_name = f'{seed}_truncation_{truncation_idx}.png' if truncation_idx is not None else f'{seed}.png'
            seed_image.save(output_path / save_name)
        joined_images.append(seed_image)
    joined_image = vis_utils.get_image_grid(joined_images)
    return joined_image


def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                learned_embeds_path: Path,
                                mapper: Optional[NeTIMapper] = None,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
    )
    if mapper is not None:
        text_encoder.text_model.embeddings.set_mapper(mapper)
    placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
        learned_embeds_path=learned_embeds_path,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline, placeholder_token, placeholder_token_id


if __name__ == '__main__':
    main()
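For orientation, here is a minimal sketch of driving the functions above directly from Python instead of through the `pyrallis` CLI, assuming the imports and definitions of this file. The checkpoint paths and the prompt are placeholders for illustration, not files shipped with this commit:

from pathlib import Path

# Hypothetical checkpoint paths -- substitute the outputs of an actual training run.
train_cfg, mapper = CheckpointHandler.load_mapper(Path("outputs/mapper-steps-1000.pt"))
pipeline, token, token_id = load_stable_diffusion_model(
    pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
    learned_embeds_path=Path("outputs/learned_embeds-steps-1000.bin"),
    mapper=mapper,
)
prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                               text_encoder=pipeline.text_encoder,
                               timesteps=pipeline.scheduler.timesteps,
                               unet_layers=constants.UNET_LAYERS,
                               placeholder_token=token,
                               placeholder_token_id=token_id,
                               torch_dtype=torch.float16)
# run_inference returns a grid over the requested seeds; per-seed images are only
# written to disk when an existing output_path is passed.
grid = run_inference(prompt="A photo of a {} on the beach",
                     pipeline=pipeline,
                     prompt_manager=prompt_manager,
                     seeds=[42, 2023])
grid.save("example_grid.png")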
src/sd_pipeline_call.py
ADDED
@@ -0,0 +1,146 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline


@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: torch.FloatTensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
    """ Modification of the standard SD pipeline call to support NeTI embeddings passed via the `prompt_embeds` argument. """

    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 2. Define call parameters
    batch_size = 1
    device = pipeline._execution_device

    neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
    negative_prompt_embeds, _ = pipeline.text_encoder(
        input_ids=neg_prompt.input_ids.to(device),
        attention_mask=None,
    )
    negative_prompt_embeds = negative_prompt_embeds[0]

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 4. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = pipeline.unet.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):

            # scale the latents for the current timestep (used by both branches)
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            if do_classifier_free_guidance:
                # predict the unconditional noise residual
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            ###############################################################
            # NeTI logic: use the prompt embedding for the current timestep
            ###############################################################
            embed = prompt_embeds[i] if type(prompt_embeds) == list else prompt_embeds
            noise_pred_text = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = noise_pred_text

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    elif output_type == "pil":
        # 8. Post-processing
        image = pipeline.decode_latents(latents)
        # 9. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        # 10. Convert to PIL
        image = pipeline.numpy_to_pil(image)
    else:
        # 8. Post-processing
        image = pipeline.decode_latents(latents)
        # 9. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)

    # Offload last model to CPU
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    if negative_prompt is None:
        negative_prompt = ""
    uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
    uncond_input = pipeline.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return uncond_input
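A brief usage note, assuming the setup from src/scripts/inference.py above: the conditioning passed as `prompt_embeds` comes from `PromptManager.embed_prompt`, and when it is a list the denoising loop indexes it with the current step, i.e. one entry per timestep. Sketch only; `pipeline` and `prompt_manager` are assumed to already exist:

embeds = prompt_manager.embed_prompt("A photo of a {}", num_images_per_prompt=1)
generator = torch.Generator(device="cuda").manual_seed(42)
output = sd_pipeline_call(pipeline,
                          prompt_embeds=embeds,  # per-timestep NeTI conditioning
                          generator=generator,
                          num_inference_steps=50)
output.images[0].save("sample.png")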
src/utils/__init__.py
ADDED
File without changes
src/utils/types.py
ADDED
@@ -0,0 +1,20 @@
import enum
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class NeTIBatch:
    input_ids: torch.Tensor
    placeholder_token_id: int
    timesteps: torch.Tensor
    unet_layers: torch.Tensor
    truncation_idx: Optional[int] = None


@dataclass
class PESigmas:
    sigma_t: float
    sigma_l: float
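`NeTIBatch` is the small container that pairs a tokenized prompt with the space-time conditioning inputs (denoising timestep and U-Net layer indices), while `PESigmas` groups the two positional-encoding scales (`sigma_t`, `sigma_l`). A purely illustrative construction with dummy values; real batches are presumably assembled by the prompt manager and text encoder from an actual tokenizer output and scheduler timestep:

import torch

from src.utils.types import NeTIBatch

batch = NeTIBatch(
    input_ids=torch.randint(0, 49408, (1, 77)),  # stand-in for CLIP token ids
    placeholder_token_id=49408,                  # hypothetical id of the learned concept token
    timesteps=torch.tensor([981]),               # current diffusion timestep
    unet_layers=torch.arange(16),                # hypothetical indices of the conditioned U-Net layers
)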
src/utils/vis_utils.py
ADDED
@@ -0,0 +1,17 @@
import math
from typing import List

from PIL import Image


def get_image_grid(images: List[Image.Image]) -> Image.Image:
    num_images = len(images)
    cols = int(math.ceil(math.sqrt(num_images)))
    rows = int(math.ceil(num_images / cols))
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        x = i % cols
        y = i // cols
        grid_image.paste(img, (x * width, y * height))
    return grid_image
style.css
ADDED
@@ -0,0 +1,3 @@
h1 {
    text-align: center;
}