# NeTI/src/prompt_manager.py
from typing import Any, Dict, List, Optional

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

from src import constants
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.utils.types import NeTIBatch


class PromptManager:
    """Class for computing all time and space embeddings for a given prompt."""

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[List] = None,
                 placeholder_token: Optional[List] = None,
                 torch_dtype: torch.dtype = torch.float32):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. The prompt is expected to contain `{}`
        marking where the placeholder token string is inserted. See constants.VALIDATION_PROMPTS for examples.
        """
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
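        # `ids` has shape (1, tokenizer.model_max_length); padding to the model's maximum
        # length keeps the fixed sequence length the text encoder expects.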
        # Compute embeddings for each timestep and each U-Net layer
        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            _hs = {"this_idx": 0}
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                # `self.timesteps` may be a plain list of ints (the default constants) or a tensor of
                # scheduler timesteps, so convert explicitly before building the batch.
                timestep_tensor = torch.as_tensor(timestep, device=self.text_encoder.device)
                batch = NeTIBatch(input_ids=ids.to(device=self.text_encoder.device),
                                  timesteps=timestep_tensor.unsqueeze(0),
                                  unet_layers=torch.tensor(layer_idx, device=self.text_encoder.device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                # The NeTI text encoder returns the regular hidden states and, when the mapper
                # defines one, a second "bypass" stream of hidden states.
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
print("Done.")
return hidden_states_per_timestep
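

# --- Usage sketch (illustrative; not part of the original module) ---
# Assumes a loaded CLIP tokenizer and a trained NeTICLIPTextModel; the variable
# names `tokenizer` and `text_encoder` below are hypothetical.
#
#     prompt_manager = PromptManager(tokenizer=tokenizer, text_encoder=text_encoder)
#     embeddings = prompt_manager.embed_prompt("A photo of a {}", num_images_per_prompt=4)
#
# `embeddings[i][f"CONTEXT_TENSOR_{layer_idx}"]` is then the cross-attention
# conditioning for U-Net layer `layer_idx` at the i-th inference timestep.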