|
from typing import Optional, List, Dict, Any |
|
|
|
import torch |
|
from tqdm import tqdm |
|
from transformers import CLIPTokenizer |
|
|
|
from src import constants |
|
from src.models.neti_clip_text_encoder import NeTICLIPTextModel |
|
from src.utils.types import NeTIBatch |
|
|
|
|
|
class PromptManager:
    """ Class for computing all time and space embeddings for a given prompt.

    For every timestep in ``timesteps`` and every layer in ``unet_layers`` the text
    encoder is run once on the tokenized prompt; the resulting hidden states are
    collected into one dictionary per timestep.
    """

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[List] = None,
                 placeholder_token: Optional[List] = None,
                 torch_dtype: torch.dtype = torch.float32):
        """
        Args:
            tokenizer: Tokenizer applied to the formatted prompt to produce input ids.
            text_encoder: NeTI text encoder, called once per (timestep, layer) pair.
            timesteps: Timesteps to embed the prompt for. Entries may be Python ints
                or 0-d tensors (e.g. the elements of an iterated 1-D tensor).
            unet_layers: U-Net layers to embed for. Only their count and order are used
                here; each layer is passed to the encoder as its integer index.
            placeholder_token_id: Forwarded unchanged to the encoder via ``NeTIBatch``.
            placeholder_token: Value substituted into the prompt's ``{}`` slot via
                ``str.format``.
            torch_dtype: Dtype the returned hidden states are cast to.
        """
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
        for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.

        Args:
            text: Prompt template; ``{}`` is replaced by ``self.placeholder_token``.
            truncation_idx: Forwarded unchanged to the encoder via ``NeTIBatch``.
            num_images_per_prompt: Number of times each hidden-state tensor is repeated
                along dim 0.

        Returns:
            One dict per entry of ``self.timesteps``, in the same order. Each dict holds
            ``"this_idx": 0`` plus ``"CONTEXT_TENSOR_{i}"`` for every layer index ``i``.
            When the encoder returns a non-None bypass output, it also holds
            ``"CONTEXT_TENSOR_BYPASS_{i}"``. All tensors are cast to ``self.dtype``.
        """
        text = text.format(self.placeholder_token)
        device = self.text_encoder.device

        # The token ids depend on neither the timestep nor the layer: tokenize and move them once.
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(device=device)

        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            # torch.as_tensor accepts plain ints (as the List[int] hint / default constant suggest)
            # as well as 0-d tensors; calling .unsqueeze directly on an int raised AttributeError.
            timestep_tensor = torch.as_tensor(timestep, device=device).unsqueeze(0)
            # NOTE(review): "this_idx" is not read in this file; presumably consumed downstream — confirm.
            _hs = {"this_idx": 0}
            for layer_idx, _ in enumerate(self.unet_layers):
                batch = NeTIBatch(input_ids=ids,
                                  timesteps=timestep_tensor,
                                  unet_layers=torch.tensor([layer_idx], device=device),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                # The .repeat(n, 1, 1) below implies layer_hs[0] is 3-D — presumably (1, seq, dim); TODO confirm.
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep
|
|