from typing import Optional, List, Dict, Any

import torch
from tqdm import tqdm
from transformers import CLIPTokenizer

from src import constants
from src.models.neti_clip_text_encoder import NeTICLIPTextModel
from src.utils.types import NeTIBatch


class PromptManager:
    """ Class for computing all time and space embeddings for a given prompt. """
    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[int] = None,
                 placeholder_token: Optional[str] = None,
                 torch_dtype: torch.dtype = torch.float32):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
        for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.
        """
        # Insert the placeholder token string into the prompt template and tokenize the full prompt
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        # Compute embeddings for each timestep and each U-Net layer
        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            # "this_idx" lets the denoising pipeline track which layer's context tensor to consume next
            _hs: Dict[str, Any] = {"this_idx": 0}
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                # Timesteps may be plain ints (e.g. constants.SD_INFERENCE_TIMESTEPS) or scheduler
                # tensors, so convert explicitly before adding the batch dimension
                batch = NeTIBatch(input_ids=ids.to(device=self.text_encoder.device),
                                  timesteps=torch.as_tensor(timestep, device=self.text_encoder.device).unsqueeze(0),
                                  unet_layers=torch.tensor(layer_idx, device=self.text_encoder.device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                # A bypass hidden state is returned only when the textual bypass is enabled
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep
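

# ---------------------------------------------------------------------------------------------------
# Minimal, illustrative usage sketch. The base model name, placeholder token, and token id below are
# placeholders, and loading the learned NeTI mapper weights into the text encoder is omitted here; in
# practice these would come from the project's training configuration and checkpoint handling.
# ---------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    pretrained = "runwayml/stable-diffusion-v1-5"  # assumed base model
    tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

    prompt_manager = PromptManager(tokenizer=tokenizer,
                                   text_encoder=text_encoder,
                                   placeholder_token="<concept>",  # hypothetical learned token string
                                   placeholder_token_id=49408)     # hypothetical token id

    # One dictionary of context tensors per inference timestep, keyed by CONTEXT_TENSOR_{layer_idx}
    # (and CONTEXT_TENSOR_BYPASS_{layer_idx} when the textual bypass is active).
    embeddings = prompt_manager.embed_prompt("A photo of a {} on the beach", num_images_per_prompt=2)
    print(f"Computed context tensors for {len(embeddings)} timesteps.")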