neural-ti committed
Commit 1f53cbd
1 Parent(s): 891bcde

Delete src
src/__init__.py DELETED
File without changes
src/checkpoint_handler.py DELETED
@@ -1,107 +0,0 @@
- from pathlib import Path
- from typing import Tuple
-
- import pyrallis
- import torch
- from accelerate import Accelerator
- from torch import nn
- from transformers import CLIPTokenizer
-
- from src.models.neti_clip_text_encoder import NeTICLIPTextModel
- from src.models.neti_mapper import NeTIMapper
- from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
- from src.config import RunConfig
-
-
- class CheckpointHandler:
-
-     def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
-         self.cfg = cfg
-         self.placeholder_token_string = placeholder_token_string
-         self.placeholder_token_id = placeholder_token_id
-         self.save_root = save_root
-
-     def save_model(self, text_encoder: NeTICLIPTextModel,
-                    accelerator: Accelerator,
-                    embeds_save_name: str,
-                    mapper_save_name: str):
-         self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
-         self.save_mapper(text_encoder, mapper_save_name)
-
-     def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
-         """
-         Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
-         to take the place of our placeholder token.
-         """
-         learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
-         learned_embeds = learned_embeds.detach().cpu()
-         learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
-         torch.save(learned_embeds_dict, self.save_root / save_name)
-
-     def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
-         """ Save the mapper and config to be used at inference. """
-         cfg_ = RunConfig(**self.cfg.__dict__.copy())
-         state_dict = {
-             "state_dict": text_encoder.text_model.embeddings.mapper.state_dict(),
-             "cfg": pyrallis.encode(cfg_),
-             "encoder": text_encoder.text_model.embeddings.mapper.encoder
-         }
-         torch.save(state_dict, self.save_root / save_name)
-
-     @staticmethod
-     def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
-         mapper_ckpt = torch.load(mapper_path, map_location="cpu")
-         cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
-         neti_mapper = NeTIMapper(output_dim=768,
-                                  use_nested_dropout=cfg.model.use_nested_dropout,
-                                  nested_dropout_prob=cfg.model.nested_dropout_prob,
-                                  norm_scale=cfg.model.target_norm,
-                                  use_positional_encoding=cfg.model.use_positional_encoding,
-                                  num_pe_time_anchors=cfg.model.num_pe_time_anchors,
-                                  pe_sigmas=cfg.model.pe_sigmas,
-                                  output_bypass=cfg.model.output_bypass)
-         neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)
-         encoder = mapper_ckpt['encoder']
-         if isinstance(encoder, NeTIPositionalEncoding):
-             encoder.w = nn.Parameter(mapper_ckpt['encoder'].w.cuda())
-         elif isinstance(encoder, BasicEncoder):
-             encoder.normalized_timesteps = mapper_ckpt['encoder'].normalized_timesteps.cuda()
-             encoder.normalized_unet_layers = mapper_ckpt['encoder'].normalized_unet_layers.cuda()
-         neti_mapper.encoder = encoder.cuda()
-         neti_mapper.cuda()
-         neti_mapper.eval()
-         return cfg, neti_mapper
-
-     @staticmethod
-     def load_learned_embed_in_clip(learned_embeds_path: Path,
-                                    text_encoder: NeTICLIPTextModel,
-                                    tokenizer: CLIPTokenizer) -> Tuple[str, int]:
-         loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
-
-         # separate the tokens and the embeds
-         trained_tokens = list(loaded_learned_embeds.keys())
-         embeds = list(loaded_learned_embeds.values())
-
-         # cast to dtype of text_encoder
-         dtype = text_encoder.get_input_embeddings().weight.dtype
-         embeds = [e.to(dtype) for e in embeds]
-
-         # add the tokens to the tokenizer
-         num_added_tokens = tokenizer.add_tokens(trained_tokens)
-         if num_added_tokens == 0:
-             raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
-                              f"Please pass a different `token` that is not already in the tokenizer.")
-
-         # resize the token embeddings
-         text_encoder.resize_token_embeddings(len(tokenizer))
-
-         # get the id for the token and assign the embeds
-         placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]
-
-         for idx, (token, token_id, embed) in enumerate(zip(trained_tokens, placeholder_token_ids, embeds)):
-             text_encoder.get_input_embeddings().weight.data[token_id] = embed
-
-         assert len(trained_tokens) == 1, "Only one placeholder token is supported"
-         placeholder_token = trained_tokens[0]
-         placeholder_token_id = placeholder_token_ids[0]
-         return placeholder_token, placeholder_token_id
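
For context, a minimal sketch of how these deleted helpers were wired together at inference time, mirroring src/scripts/inference.py below. The checkpoint paths are hypothetical (training saves files named mapper-steps-{N}.pt and learned_embeds-steps-{N}.bin), and a CUDA device is assumed since load_mapper moves the mapper to GPU:

    from pathlib import Path
    from transformers import CLIPTokenizer

    from src.checkpoint_handler import CheckpointHandler
    from src.models.neti_clip_text_encoder import NeTICLIPTextModel

    # Hypothetical paths for illustration only
    cfg, mapper = CheckpointHandler.load_mapper(Path("outputs/my_run/mapper-steps-1000.pt"))
    tokenizer = CLIPTokenizer.from_pretrained(cfg.model.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(cfg.model.pretrained_model_name_or_path, subfolder="text_encoder")
    # Attach the mapper, then register the placeholder token and its embedding in CLIP
    text_encoder.text_model.embeddings.set_mapper(mapper)
    token, token_id = CheckpointHandler.load_learned_embed_in_clip(
        learned_embeds_path=Path("outputs/my_run/learned_embeds-steps-1000.bin"),
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )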
src/config.py DELETED
@@ -1,146 +0,0 @@
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import List, Optional, Dict
-
- from src.constants import VALIDATION_PROMPTS
- from src.utils.types import PESigmas
-
-
- @dataclass
- class LogConfig:
-     """ Parameters for logging and saving """
-     # Name of experiment. This will be the name of the output folder
-     exp_name: str
-     # The output directory where the model predictions and checkpoints will be written
-     exp_dir: Path = Path("./outputs")
-     # Save interval
-     save_steps: int = 250
-     # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
-     # `output_dir/runs/**CURRENT_DATETIME_HOSTNAME`
-     logging_dir: Path = Path("logs")
-     # The integration to report the results to. Supported platforms are "tensorboard"
-     # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
-     report_to: str = "tensorboard"
-     # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
-     checkpoints_total_limit: Optional[int] = None
-
-
- @dataclass
- class DataConfig:
-     """ Parameters for data """
-     # A folder containing the training data
-     train_data_dir: Path
-     # A token to use as a placeholder for the concept
-     placeholder_token: str
-     # Super category token to use for normalizing the mapper output
-     super_category_token: Optional[str] = "object"
-     # Number of subprocesses to use for data loading. 0 means the data will be loaded in the main process
-     dataloader_num_workers: int = 8
-     # Choose between 'object' and 'style' - used for selecting the prompts for training
-     learnable_property: str = "object"
-     # How many times to repeat the training data
-     repeats: int = 100
-     # The resolution for input images; all images in the train/validation dataset will be resized to this resolution
-     resolution: int = 512
-     # Whether to center crop images before resizing to resolution
-     center_crop: bool = False
-
-
- @dataclass
- class ModelConfig:
-     """ Parameters for defining all models """
-     # Path to pretrained model or model identifier from huggingface.co/models
-     pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
-     # Whether to use our Nested Dropout technique
-     use_nested_dropout: bool = True
-     # Probability to apply nested dropout during training
-     nested_dropout_prob: float = 0.5
-     # Whether to normalize the norm of the mapper's output vector
-     normalize_mapper_output: bool = True
-     # Target norm for the mapper's output vector
-     target_norm: Optional[float] = None
-     # Whether to use positional encoding over the input to the mapper
-     use_positional_encoding: bool = True
-     # Sigmas used for computing positional encoding
-     pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
-     # Number of time anchors for computing our positional encodings
-     num_pe_time_anchors: int = 10
-     # Whether to output the textual bypass vector
-     output_bypass: bool = True
-     # Revision of pretrained model identifier from huggingface.co/models
-     revision: Optional[str] = None
-     # Whether training should be resumed from a previous checkpoint
-     mapper_checkpoint_path: Optional[Path] = None
-
-     def __post_init__(self):
-         if self.pe_sigmas is not None:
-             assert len(self.pe_sigmas) == 2, "Should provide exactly two sigma values: one for time and one for layers!"
-             self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
-
-
- @dataclass
- class EvalConfig:
-     """ Parameters for validation """
-     # A list of prompts that will be used during validation to verify that the model is learning
-     validation_prompts: List[str] = field(default_factory=lambda: VALIDATION_PROMPTS)
-     # Number of images that should be generated during validation with `validation_prompt`
-     num_validation_images: int = 4
-     # Seeds to use for generating the validation images
-     validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
-     # Run validation every X steps
-     validation_steps: int = 100
-     # Number of denoising steps
-     num_denoising_steps: int = 50
-
-     def __post_init__(self):
-         if self.validation_seeds is None:
-             self.validation_seeds = list(range(self.num_validation_images))
-         assert len(self.validation_seeds) == self.num_validation_images, \
-             "Length of validation_seeds should equal num_validation_images"
-
-
- @dataclass
- class OptimConfig:
-     """ Parameters for the optimization process """
-     # Total number of training steps to perform
-     max_train_steps: Optional[int] = 1_000
-     # Learning rate
-     learning_rate: float = 1e-3
-     # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
-     scale_lr: bool = True
-     # Batch size (per device) for the training dataloader
-     train_batch_size: int = 2
-     # Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass
-     gradient_checkpointing: bool = False
-     # Number of update steps to accumulate before performing a backward/update pass
-     gradient_accumulation_steps: int = 4
-     # A seed for reproducible training
-     seed: Optional[int] = None
-     # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
-     # "constant", "constant_with_warmup"]
-     lr_scheduler: str = "constant"
-     # Number of steps for the warmup in the lr scheduler
-     lr_warmup_steps: int = 0
-     # The beta1 parameter for the Adam optimizer
-     adam_beta1: float = 0.9
-     # The beta2 parameter for the Adam optimizer
-     adam_beta2: float = 0.999
-     # Weight decay to use
-     adam_weight_decay: float = 1e-2
-     # Epsilon value for the Adam optimizer
-     adam_epsilon: float = 1e-08
-     # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >= 1.10
-     # and an Nvidia Ampere GPU.
-     mixed_precision: str = "no"
-     # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
-     # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-     allow_tf32: bool = False
-
-
- @dataclass
- class RunConfig:
-     """ The main configuration for the coach trainer """
-     log: LogConfig = field(default_factory=LogConfig)
-     data: DataConfig = field(default_factory=DataConfig)
-     model: ModelConfig = field(default_factory=ModelConfig)
-     eval: EvalConfig = field(default_factory=EvalConfig)
-     optim: OptimConfig = field(default_factory=OptimConfig)
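
Since the scripts in this repo consume their configs through pyrallis (as src/scripts/inference.py below does), here is a minimal sketch of how this RunConfig would typically be parsed from the command line; the script name and flag values are illustrative only:

    import pyrallis

    from src.config import RunConfig

    @pyrallis.wrap()
    def main(cfg: RunConfig):
        # Nested dataclass fields map to dotted CLI flags, e.g.
        #   python train.py --log.exp_name cat_toy --data.train_data_dir ./data --data.placeholder_token "<cat>"
        print(cfg.log.exp_name, cfg.optim.learning_rate)

    if __name__ == '__main__':
        main()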
src/constants.py DELETED
@@ -1,83 +0,0 @@
- UNET_LAYERS = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID',
-                'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
-
- SD_INFERENCE_TIMESTEPS = [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659,
-                           639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300,
-                           280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20]
-
- PROMPTS = [
-     "A photo of a {}",
-     "A photo of {} in the jungle",
-     "A photo of {} on a beach",
-     "A photo of {} in Times Square",
-     "A photo of {} in the moon",
-     "A painting of {} in the style of Monet",
-     "Oil painting of {}",
-     "A Marc Chagall painting of {}",
-     "A manga drawing of {}",
-     "A watercolor painting of {}",
-     "A statue of {}",
-     "App icon of {}",
-     "A sand sculpture of {}",
-     "Colorful graffiti of {}",
-     "A photograph of two {} on a table",
- ]
-
- VALIDATION_PROMPTS = [
-     "A photo of a {}",
-     "A photo of a {} on a beach",
-     "App icon of {}",
-     "A painting of {} in the style of Monet",
- ]
-
- IMAGENET_TEMPLATES_SMALL = [
-     "a photo of a {}",
-     "a rendering of a {}",
-     "a cropped photo of the {}",
-     "the photo of a {}",
-     "a photo of a clean {}",
-     "a photo of a dirty {}",
-     "a dark photo of the {}",
-     "a photo of my {}",
-     "a photo of the cool {}",
-     "a close-up photo of a {}",
-     "a bright photo of the {}",
-     "a cropped photo of a {}",
-     "a photo of the {}",
-     "a good photo of the {}",
-     "a photo of one {}",
-     "a close-up photo of the {}",
-     "a rendition of the {}",
-     "a photo of the clean {}",
-     "a rendition of a {}",
-     "a photo of a nice {}",
-     "a good photo of a {}",
-     "a photo of the nice {}",
-     "a photo of the small {}",
-     "a photo of the weird {}",
-     "a photo of the large {}",
-     "a photo of a cool {}",
-     "a photo of a small {}",
- ]
-
- IMAGENET_STYLE_TEMPLATES_SMALL = [
-     "a painting in the style of {}",
-     "a rendering in the style of {}",
-     "a cropped painting in the style of {}",
-     "the painting in the style of {}",
-     "a clean painting in the style of {}",
-     "a dirty painting in the style of {}",
-     "a dark painting in the style of {}",
-     "a picture in the style of {}",
-     "a cool painting in the style of {}",
-     "a close-up painting in the style of {}",
-     "a bright painting in the style of {}",
-     "a cropped painting in the style of {}",
-     "a good painting in the style of {}",
-     "a close-up painting in the style of {}",
-     "a rendition in the style of {}",
-     "a nice painting in the style of {}",
-     "a small painting in the style of {}",
-     "a weird painting in the style of {}",
-     "a large painting in the style of {}",
- ]
src/models/__init__.py DELETED
File without changes
src/models/net_clip_text_embedding.py DELETED
@@ -1,60 +0,0 @@
- from typing import Optional, Tuple
-
- import torch
- from torch import nn
- from transformers import CLIPTextConfig
-
- from src.models.neti_mapper import NeTIMapper
- from src.utils.types import NeTIBatch
-
-
- class NeTICLIPTextEmbeddings(nn.Module):
-     """ Modification of CLIPTextEmbeddings to allow for the use of a NeTIMapper to overwrite the concept token. """
-
-     def __init__(self, config: CLIPTextConfig):
-         super().__init__()
-         embed_dim = config.hidden_size
-         self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
-         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
-         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-
-     def set_mapper(self, mapper: NeTIMapper):
-         self.mapper = mapper
-
-     def forward(self, input_ids: Optional[torch.LongTensor] = None,
-                 position_ids: Optional[torch.LongTensor] = None,
-                 inputs_embeds: Optional[torch.FloatTensor] = None,
-                 batch: Optional[NeTIBatch] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
-         if batch is not None:
-             input_ids = batch.input_ids
-
-         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-         if position_ids is None:
-             position_ids = self.position_ids[:, :seq_length]
-
-         if inputs_embeds is None:
-             inputs_embeds = self.token_embedding(input_ids)
-
-         ####################################################################
-         # NeTI logic - Use mapper to overwrite the learnable token embedding
-         ####################################################################
-         bypass_outputs = None
-         if batch is not None:
-             mapper_outputs = self.mapper(timestep=batch.timesteps.float(),
-                                          unet_layer=batch.unet_layers.float(),
-                                          truncation_idx=batch.truncation_idx)
-             mapper_outputs = mapper_outputs.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-             if self.mapper.output_bypass:
-                 # The mapper emits two concatenated vectors: the word embedding and the bypass vector
-                 bypass_outputs = mapper_outputs[:, mapper_outputs.shape[1] // 2:]
-                 mapper_outputs = mapper_outputs[:, :mapper_outputs.shape[1] // 2]
-
-             # Overwrite the index of the placeholder token with the mapper output for each entry in the batch
-             learnable_idxs = (input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
-             inputs_embeds[torch.arange(input_ids.shape[0]), learnable_idxs] = mapper_outputs
-
-         position_embeddings = self.position_embedding(position_ids)
-         embeddings = inputs_embeds + position_embeddings
-
-         return embeddings, bypass_outputs
src/models/neti_clip_text_encoder.py DELETED
@@ -1,160 +0,0 @@
- from typing import Optional, Tuple, Union
-
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from transformers.modeling_outputs import BaseModelOutputWithPooling
- from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder
- from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask
-
- from src.models.net_clip_text_embedding import NeTICLIPTextEmbeddings
- from src.utils.types import NeTIBatch
-
-
- class NeTICLIPTextModel(CLIPTextModel):
-     """ Modification of CLIPTextModel to use our NeTI mapper for computing the embeddings of the concept. """
-
-     def __init__(self, config: CLIPTextConfig):
-         super().__init__(config)
-         self.text_model = NeTICLIPTextTransformer(config)
-         self.post_init()
-
-     def forward(self, input_ids: Optional[torch.Tensor] = None,
-                 attention_mask: Optional[torch.Tensor] = None,
-                 position_ids: Optional[torch.Tensor] = None,
-                 output_attentions: Optional[bool] = None,
-                 output_hidden_states: Optional[bool] = None,
-                 return_dict: Optional[bool] = None,
-                 batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
-         return self.text_model.forward(
-             batch=batch,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-
- class NeTICLIPTextTransformer(CLIPTextTransformer):
-     """ Modification of CLIPTextTransformer to use our NeTI mapper for computing the embeddings of the concept. """
-
-     def __init__(self, config: CLIPTextConfig):
-         super().__init__(config=config)
-         self.config = config
-         embed_dim = config.hidden_size
-         self.embeddings = NeTICLIPTextEmbeddings(config)
-         self.encoder = CLIPEncoder(config)
-         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-
-     def forward(self, input_ids: Optional[torch.Tensor] = None,
-                 attention_mask: Optional[torch.Tensor] = None,
-                 position_ids: Optional[torch.Tensor] = None,
-                 output_attentions: Optional[bool] = None,
-                 output_hidden_states: Optional[bool] = None,
-                 return_dict: Optional[bool] = None,
-                 batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
-
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         bypass_output = None
-
-         if input_ids is not None:  # Regular embedding logic
-             input_shape = input_ids.size()
-             input_ids = input_ids.view(-1, input_shape[-1])
-             hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids)
-
-         ###########################
-         # NeTI logic
-         ###########################
-         elif batch is not None:
-             input_shape = batch.input_ids.size()
-             batch.input_ids = batch.input_ids.view(-1, input_shape[-1])
-             hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids)
-
-         else:
-             raise ValueError("You have to specify either batch or input_ids!")
-
-         bsz, seq_len = input_shape
-         # CLIP's text model uses causal mask, prepare it here.
-         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-         causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
-             hidden_states.device
-         )
-
-         # expand attention_mask
-         if attention_mask is not None:
-             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-             attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
-
-         encoder_outputs = self.encoder(
-             inputs_embeds=hidden_states,
-             attention_mask=attention_mask,
-             causal_attention_mask=causal_attention_mask,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         last_hidden_state = encoder_outputs[0]
-         last_hidden_state_with_bypass = last_hidden_state.clone()
-
-         ###############################################
-         # NeTI logic - compute the scaled bypass output
-         ###############################################
-         if bypass_output is not None:
-             learnable_idxs = (batch.input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
-             existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs]
-             bypass_output = bypass_output / bypass_output.norm(dim=1, keepdim=True) \
-                             * existing_state.norm(dim=1, keepdim=True)
-             new_state = existing_state + 0.2 * bypass_output
-             new_state = new_state.to(dtype=hidden_states.dtype)
-             last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state
-
-         last_hidden_state = self.final_layer_norm(last_hidden_state)
-         last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass)
-
-         # text_embeds.shape = [batch_size, sequence_length, transformer.width]
-         # take features from the eot embedding (eot_token is the highest number in each sequence)
-         # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-         if input_ids is not None:
-             pooled_output = last_hidden_state[
-                 torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
-             ]
-             pooled_output_with_bypass = last_hidden_state_with_bypass[
-                 torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
-             ]
-         elif batch is not None:
-             pooled_output = last_hidden_state[
-                 torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
-             ]
-             pooled_output_with_bypass = last_hidden_state_with_bypass[
-                 torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
-             ]
-         else:
-             raise ValueError("You have to specify either batch or input_ids!")
-
-         if bypass_output is not None:
-             return BaseModelOutputWithPooling(
-                 last_hidden_state=last_hidden_state,
-                 pooler_output=pooled_output,
-                 hidden_states=encoder_outputs.hidden_states,
-                 attentions=encoder_outputs.attentions,
-             ), BaseModelOutputWithPooling(
-                 last_hidden_state=last_hidden_state_with_bypass,
-                 pooler_output=pooled_output_with_bypass,
-                 hidden_states=encoder_outputs.hidden_states,
-                 attentions=encoder_outputs.attentions,
-             )
-         else:
-             return BaseModelOutputWithPooling(
-                 last_hidden_state=last_hidden_state,
-                 pooler_output=pooled_output,
-                 hidden_states=encoder_outputs.hidden_states,
-                 attentions=encoder_outputs.attentions,
-             ), None
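
Restated as a formula: for the placeholder token's hidden state h and the bypass vector b computed above, the code applies

    h_new = h + 0.2 * ||h|| * (b / ||b||)

i.e. the bypass output is rescaled to the norm of the existing hidden state and mixed in with a fixed 0.2 coefficient before the final layer norm, so it adjusts the token's representation without dominating it.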
src/models/neti_mapper.py DELETED
@@ -1,90 +0,0 @@
- import random
- from typing import Optional, List
-
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- from src.constants import UNET_LAYERS
- from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
- from src.utils.types import PESigmas
-
-
- class NeTIMapper(nn.Module):
-     """ Main logic of our NeTI mapper. """
-
-     def __init__(self, output_dim: int = 768,
-                  unet_layers: List[str] = UNET_LAYERS,
-                  use_nested_dropout: bool = True,
-                  nested_dropout_prob: float = 0.5,
-                  norm_scale: Optional[torch.Tensor] = None,
-                  use_positional_encoding: bool = True,
-                  num_pe_time_anchors: int = 10,
-                  pe_sigmas: PESigmas = PESigmas(sigma_t=0.03, sigma_l=2.0),
-                  output_bypass: bool = True):
-         super().__init__()
-         self.use_nested_dropout = use_nested_dropout
-         self.nested_dropout_prob = nested_dropout_prob
-         self.norm_scale = norm_scale
-         self.output_bypass = output_bypass
-         if self.output_bypass:
-             output_dim *= 2  # Output two vectors
-
-         self.use_positional_encoding = use_positional_encoding
-         if self.use_positional_encoding:
-             self.encoder = NeTIPositionalEncoding(sigma_t=pe_sigmas.sigma_t, sigma_l=pe_sigmas.sigma_l).cuda()
-             self.input_dim = num_pe_time_anchors * len(unet_layers)
-         else:
-             self.encoder = BasicEncoder().cuda()
-             self.input_dim = 2
-
-         self.set_net(num_unet_layers=len(unet_layers),
-                      num_time_anchors=num_pe_time_anchors,
-                      output_dim=output_dim)
-
-     def set_net(self, num_unet_layers: int, num_time_anchors: int, output_dim: int = 768):
-         self.input_layer = self.set_input_layer(num_unet_layers, num_time_anchors)
-         self.net = nn.Sequential(self.input_layer,
-                                  nn.Linear(self.input_dim, 128), nn.LayerNorm(128), nn.LeakyReLU(),
-                                  nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU())
-         self.output_layer = nn.Sequential(nn.Linear(128, output_dim))
-
-     def set_input_layer(self, num_unet_layers: int, num_time_anchors: int) -> nn.Module:
-         if self.use_positional_encoding:
-             input_layer = nn.Linear(self.encoder.num_w * 2, self.input_dim)
-             input_layer.weight.data = self.encoder.init_layer(num_time_anchors, num_unet_layers)
-         else:
-             input_layer = nn.Identity()
-         return input_layer
-
-     def forward(self, timestep: torch.Tensor, unet_layer: torch.Tensor,
-                 truncation_idx: Optional[int] = None) -> torch.Tensor:
-         embedding = self.extract_hidden_representation(timestep, unet_layer)
-         if self.use_nested_dropout:
-             embedding = self.apply_nested_dropout(embedding, truncation_idx=truncation_idx)
-         embedding = self.get_output(embedding)
-         return embedding
-
-     def get_encoded_input(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
-         return self.encoder.encode(timestep, unet_layer)
-
-     def extract_hidden_representation(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
-         encoded_input = self.get_encoded_input(timestep, unet_layer)
-         embedding = self.net(encoded_input)
-         return embedding
-
-     def apply_nested_dropout(self, embedding: torch.Tensor, truncation_idx: Optional[int] = None) -> torch.Tensor:
-         if self.training:
-             if random.random() < self.nested_dropout_prob:
-                 dropout_idxs = torch.randint(low=0, high=embedding.shape[1], size=(embedding.shape[0],))
-                 for idx in torch.arange(embedding.shape[0]):
-                     embedding[idx][dropout_idxs[idx]:] = 0
-         if not self.training and truncation_idx is not None:
-             for idx in torch.arange(embedding.shape[0]):
-                 embedding[idx][truncation_idx:] = 0
-         return embedding
-
-     def get_output(self, embedding: torch.Tensor) -> torch.Tensor:
-         embedding = self.output_layer(embedding)
-         if self.norm_scale is not None:
-             embedding = F.normalize(embedding, dim=-1) * self.norm_scale
-         return embedding
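
A minimal sketch of querying the mapper directly (a CUDA device is assumed, since the encoders above are created with .cuda()); the timestep and layer values are arbitrary examples:

    import torch

    from src.models.neti_mapper import NeTIMapper

    mapper = NeTIMapper(output_dim=768).cuda().eval()
    timesteps = torch.tensor([999.0, 500.0])  # denoising timesteps
    unet_layers = torch.tensor([0.0, 7.0])    # indices into constants.UNET_LAYERS
    with torch.no_grad():
        out = mapper(timestep=timesteps, unet_layer=unet_layers)
    # (2, 1536): a 768-d word embedding plus a 768-d bypass vector per entry, since output_bypass=True
    print(out.shape)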
src/models/positional_encoding.py DELETED
@@ -1,57 +0,0 @@
- from typing import Union
-
- import torch
- from torch import nn
-
-
- class NeTIPositionalEncoding(nn.Module):
-
-     def __init__(self, sigma_t: float, sigma_l: float, num_w: int = 1024):
-         super().__init__()
-         self.sigma_t = sigma_t
-         self.sigma_l = sigma_l
-         self.num_w = num_w
-         self.w = torch.randn((num_w, 2))
-         self.w[:, 0] *= sigma_t
-         self.w[:, 1] *= sigma_l
-         self.w = nn.Parameter(self.w).cuda()
-
-     def encode(self, t: Union[int, torch.Tensor], l: Union[int, torch.Tensor]):
-         """ Maps the given time and layer input into a 2048-dimensional vector. """
-         if type(t) == int or t.ndim == 0:
-             x = torch.tensor([t, l]).float()
-         else:
-             x = torch.stack([t, l], dim=1).T
-         x = x.cuda()
-         v = torch.cat([torch.sin(self.w.detach() @ x), torch.cos(self.w.detach() @ x)])
-         if type(t) == int:
-             v_norm = v / v.norm()
-         else:
-             v_norm = v / v.norm(dim=0)
-             v_norm = v_norm.T
-         return v_norm
-
-     def init_layer(self, num_time_anchors: int, num_layers: int) -> torch.Tensor:
-         """ Computes the weights for the positional encoding layer of size 160x2048. """
-         anchor_vectors = []
-         for t_anchor in range(0, 1000, 1000 // num_time_anchors):
-             for l_anchor in range(0, num_layers):
-                 anchor_vectors.append(self.encode(t_anchor, l_anchor).float())
-         A = torch.stack(anchor_vectors)
-         return A
-
-
- class BasicEncoder(nn.Module):
-     """ Simply normalizes the given timestep and unet layer to be between -1 and 1. """
-
-     def __init__(self, num_denoising_timesteps: int = 1000, num_unet_layers: int = 16):
-         super().__init__()
-         self.normalized_timesteps = (torch.arange(num_denoising_timesteps) / (num_denoising_timesteps - 1)) * 2 - 1
-         self.normalized_unet_layers = (torch.arange(num_unet_layers) / (num_unet_layers - 1)) * 2 - 1
-         self.normalized_timesteps = nn.Parameter(self.normalized_timesteps).cuda()
-         self.normalized_unet_layers = nn.Parameter(self.normalized_unet_layers).cuda()
-
-     def encode(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
-         normalized_input = torch.stack([self.normalized_timesteps[timestep.long()],
-                                         self.normalized_unet_layers[unet_layer.long()]]).T
-         return normalized_input
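
In other words, NeTIPositionalEncoding is a random Fourier-features map: with W in R^(1024 x 2), whose two columns are Gaussian with scales sigma_t and sigma_l, an input x = (t, l) is encoded as

    v(t, l) = [sin(Wx); cos(Wx)] / ||[sin(Wx); cos(Wx)]||

a unit-norm vector of dimension 2 * num_w = 2048. init_layer then evaluates this map on a grid of 10 timestep anchors x 16 U-Net layers to produce the 160x2048 matrix used to initialize the mapper's input projection.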
src/models/xti_attention_processor.py DELETED
@@ -1,57 +0,0 @@
- from typing import Dict, Optional
-
- import torch
- from diffusers.models.cross_attention import CrossAttention
-
-
- class XTIAttenProc:
-
-     def __call__(self, attn: CrossAttention,
-                  hidden_states: torch.Tensor,
-                  encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
-                  attention_mask: Optional[torch.Tensor] = None):
-
-         _ehs_bypass = None
-         if encoder_hidden_states is not None:
-             if isinstance(encoder_hidden_states, dict):
-                 # Each cross-attention call reads its own per-layer context tensor,
-                 # cycling through the 16 layers defined in constants.UNET_LAYERS
-                 this_idx = encoder_hidden_states["this_idx"]
-                 _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
-                 if f"CONTEXT_TENSOR_BYPASS_{this_idx}" in encoder_hidden_states:
-                     _ehs_bypass = encoder_hidden_states[f"CONTEXT_TENSOR_BYPASS_{this_idx}"]
-                 encoder_hidden_states["this_idx"] += 1
-                 encoder_hidden_states["this_idx"] %= 16
-             else:
-                 _ehs = encoder_hidden_states
-         else:
-             _ehs = None
-
-         batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape)
-         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-         query = attn.to_q(hidden_states)
-
-         if _ehs is None:
-             _ehs = hidden_states
-         elif attn.cross_attention_norm:
-             _ehs = attn.norm_cross(_ehs)
-             _ehs_bypass = attn.norm_cross(_ehs_bypass)
-
-         # Keys come from the regular context; values come from the bypass context when one is provided
-         key = attn.to_k(_ehs)
-         if _ehs_bypass is not None:
-             value = attn.to_v(_ehs_bypass)
-         else:
-             value = attn.to_v(_ehs)
-
-         query = attn.head_to_batch_dim(query)
-         key = attn.head_to_batch_dim(key)
-         value = attn.head_to_batch_dim(value)
-
-         attention_probs = attn.get_attention_scores(query, key, attention_mask)
-         hidden_states = torch.bmm(attention_probs, value)
-         hidden_states = attn.batch_to_head_dim(hidden_states)
-
-         # linear proj
-         hidden_states = attn.to_out[0](hidden_states)
-         # dropout
-         hidden_states = attn.to_out[1](hidden_states)
-
-         return hidden_states
src/prompt_manager.py DELETED
@@ -1,63 +0,0 @@
- from typing import Optional, List, Dict, Any
-
- import torch
- from tqdm import tqdm
- from transformers import CLIPTokenizer
-
- from src import constants
- from src.models.neti_clip_text_encoder import NeTICLIPTextModel
- from src.utils.types import NeTIBatch
-
-
- class PromptManager:
-     """ Class for computing all time and space embeddings for a given prompt. """
-
-     def __init__(self, tokenizer: CLIPTokenizer,
-                  text_encoder: NeTICLIPTextModel,
-                  timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
-                  unet_layers: List[str] = constants.UNET_LAYERS,
-                  placeholder_token_id: Optional[List] = None,
-                  placeholder_token: Optional[List] = None,
-                  torch_dtype: torch.dtype = torch.float32):
-         self.tokenizer = tokenizer
-         self.text_encoder = text_encoder
-         self.timesteps = timesteps
-         self.unet_layers = unet_layers
-         self.placeholder_token = placeholder_token
-         self.placeholder_token_id = placeholder_token_id
-         self.dtype = torch_dtype
-
-     def embed_prompt(self, text: str,
-                      truncation_idx: Optional[int] = None,
-                      num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
-         """
-         Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
-         for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.
-         """
-         text = text.format(self.placeholder_token)
-         ids = self.tokenizer(
-             text,
-             padding="max_length",
-             max_length=self.tokenizer.model_max_length,
-             return_tensors="pt",
-         ).input_ids
-
-         # Compute embeddings for each timestep and each U-Net layer
-         print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
-         hidden_states_per_timestep = []
-         for timestep in tqdm(self.timesteps):
-             _hs = {"this_idx": 0}
-             for layer_idx, unet_layer in enumerate(self.unet_layers):
-                 batch = NeTIBatch(input_ids=ids.to(device=self.text_encoder.device),
-                                   timesteps=timestep.unsqueeze(0).to(device=self.text_encoder.device),
-                                   unet_layers=torch.tensor(layer_idx, device=self.text_encoder.device).unsqueeze(0),
-                                   placeholder_token_id=self.placeholder_token_id,
-                                   truncation_idx=truncation_idx)
-                 layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
-                 layer_hs = layer_hs[0].to(dtype=self.dtype)
-                 _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
-                 if layer_hs_bypass is not None:
-                     layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
-                     _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
-             hidden_states_per_timestep.append(_hs)
-         print("Done.")
-         return hidden_states_per_timestep
src/scripts/__init__.py DELETED
File without changes
src/scripts/inference.py DELETED
@@ -1,170 +0,0 @@
- import sys
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Optional, List, Tuple, Union
-
- import numpy as np
- import pyrallis
- import torch
- from PIL import Image
- from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
- from transformers import CLIPTokenizer
-
- sys.path.append(".")
- sys.path.append("..")
-
- from src import constants
- from src.models.neti_clip_text_encoder import NeTICLIPTextModel
- from src.models.neti_mapper import NeTIMapper
- from src.prompt_manager import PromptManager
- from src.sd_pipeline_call import sd_pipeline_call
- from src.models.xti_attention_processor import XTIAttenProc
- from src.checkpoint_handler import CheckpointHandler
- from src.utils import vis_utils
-
-
- @dataclass
- class InferenceConfig:
-     # Specifies which checkpoint iteration we want to load
-     iteration: Optional[int] = None
-     # The input directory containing the saved models and embeddings
-     input_dir: Optional[Path] = None
-     # Where to save the inference results
-     inference_dir: Optional[Path] = None
-     # Specific path to the mapper you want to load, overrides `input_dir`
-     mapper_checkpoint_path: Optional[Path] = None
-     # Specific path to the embeddings you want to load, overrides `input_dir`
-     learned_embeds_path: Optional[Path] = None
-     # List of prompts to run inference on
-     prompts: Optional[List[str]] = None
-     # Text file containing prompts to run inference on (one prompt per line), overrides `prompts`
-     prompts_file_path: Optional[Path] = None
-     # List of random seeds to run on
-     seeds: List[int] = field(default_factory=lambda: [42])
-     # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout.
-     # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices.
-     truncation_idxs: Optional[Union[int, List[int]]] = None
-     # Whether to run with torch.float16 or torch.float32
-     torch_dtype: str = "fp16"
-
-     def __post_init__(self):
-         assert bool(self.prompts) != bool(self.prompts_file_path), \
-             "You must provide either prompts or prompts_file_path, but not both!"
-         self._set_prompts()
-         self._set_input_paths()
-         self.inference_dir.mkdir(exist_ok=True, parents=True)
-         if type(self.truncation_idxs) == int:
-             self.truncation_idxs = [self.truncation_idxs]
-         if self.truncation_idxs is None:
-             self.truncation_idxs = [None]
-         self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32
-
-     def _set_input_paths(self):
-         if self.inference_dir is None:
-             assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir"
-             self.inference_dir = self.input_dir / f"inference_{self.iteration}"
-         if self.mapper_checkpoint_path is None:
-             assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path"
-             self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
-         if self.learned_embeds_path is None:
-             assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path"
-             self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"
-
-     def _set_prompts(self):
-         if self.prompts_file_path is not None:
-             assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!"
-             self.prompts = self.prompts_file_path.read_text().splitlines()
-
-
- @pyrallis.wrap()
- def main(infer_cfg: InferenceConfig):
-     train_cfg, mapper = CheckpointHandler.load_mapper(infer_cfg.mapper_checkpoint_path)
-     pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
-         pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
-         mapper=mapper,
-         learned_embeds_path=infer_cfg.learned_embeds_path,
-         torch_dtype=infer_cfg.torch_dtype
-     )
-     prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
-                                    text_encoder=pipeline.text_encoder,
-                                    timesteps=pipeline.scheduler.timesteps,
-                                    unet_layers=constants.UNET_LAYERS,
-                                    placeholder_token=placeholder_token,
-                                    placeholder_token_id=placeholder_token_id,
-                                    torch_dtype=infer_cfg.torch_dtype)
-     for prompt in infer_cfg.prompts:
-         output_path = infer_cfg.inference_dir / prompt.format(placeholder_token)
-         output_path.mkdir(exist_ok=True, parents=True)
-         for truncation_idx in infer_cfg.truncation_idxs:
-             print(f"Running with truncation index: {truncation_idx}")
-             prompt_image = run_inference(prompt=prompt,
-                                          pipeline=pipeline,
-                                          prompt_manager=prompt_manager,
-                                          seeds=infer_cfg.seeds,
-                                          output_path=output_path,
-                                          num_images_per_prompt=1,
-                                          truncation_idx=truncation_idx)
-             if truncation_idx is not None:
-                 save_name = f"{prompt.format(placeholder_token)}_truncation_{truncation_idx}.png"
-             else:
-                 save_name = f"{prompt.format(placeholder_token)}.png"
-             prompt_image.save(infer_cfg.inference_dir / save_name)
-
-
- def run_inference(prompt: str,
-                   pipeline: StableDiffusionPipeline,
-                   prompt_manager: PromptManager,
-                   seeds: List[int],
-                   output_path: Optional[Path] = None,
-                   num_images_per_prompt: int = 1,
-                   truncation_idx: Optional[int] = None) -> Image.Image:
-     with torch.autocast("cuda"):
-         with torch.no_grad():
-             prompt_embeds = prompt_manager.embed_prompt(prompt,
-                                                         num_images_per_prompt=num_images_per_prompt,
-                                                         truncation_idx=truncation_idx)
-     joined_images = []
-     for seed in seeds:
-         generator = torch.Generator(device='cuda').manual_seed(seed)
-         images = sd_pipeline_call(pipeline,
-                                   prompt_embeds=prompt_embeds,
-                                   generator=generator,
-                                   num_images_per_prompt=num_images_per_prompt).images
-         seed_image = Image.fromarray(np.concatenate(images, axis=1)).convert("RGB")
-         if output_path is not None:
-             save_name = f'{seed}_truncation_{truncation_idx}.png' if truncation_idx is not None else f'{seed}.png'
-             seed_image.save(output_path / save_name)
-         joined_images.append(seed_image)
-     joined_image = vis_utils.get_image_grid(joined_images)
-     return joined_image
-
-
- def load_stable_diffusion_model(pretrained_model_name_or_path: str,
-                                 learned_embeds_path: Path,
-                                 mapper: Optional[NeTIMapper] = None,
-                                 num_denoising_steps: int = 50,
-                                 torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
-     tokenizer = CLIPTokenizer.from_pretrained(
-         pretrained_model_name_or_path, subfolder="tokenizer")
-     text_encoder = NeTICLIPTextModel.from_pretrained(
-         pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype,
-     )
-     if mapper is not None:
-         text_encoder.text_model.embeddings.set_mapper(mapper)
-     placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
-         learned_embeds_path=learned_embeds_path,
-         text_encoder=text_encoder,
-         tokenizer=tokenizer
-     )
-     pipeline = StableDiffusionPipeline.from_pretrained(
-         pretrained_model_name_or_path,
-         torch_dtype=torch_dtype,
-         text_encoder=text_encoder,
-         tokenizer=tokenizer
-     ).to("cuda")
-     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-     pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
-     pipeline.unet.set_attn_processor(XTIAttenProc())
-     return pipeline, placeholder_token, placeholder_token_id
-
-
- if __name__ == '__main__':
-     main()
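
An illustrative invocation of this script (directory and file names hypothetical; pyrallis maps the config fields to CLI flags):

    python src/scripts/inference.py \
        --input_dir outputs/my_run \
        --iteration 1000 \
        --prompts_file_path prompts.txt \
        --torch_dtype fp16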
src/sd_pipeline_call.py DELETED
@@ -1,146 +0,0 @@
- from typing import Any, Callable, Dict, List, Optional, Union
-
- import torch
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline
-
-
- @torch.no_grad()
- def sd_pipeline_call(
-         pipeline: StableDiffusionPipeline,
-         prompt_embeds: torch.FloatTensor,
-         height: Optional[int] = None,
-         width: Optional[int] = None,
-         num_inference_steps: int = 50,
-         guidance_scale: float = 7.5,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         num_images_per_prompt: Optional[int] = 1,
-         eta: float = 0.0,
-         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-         latents: Optional[torch.FloatTensor] = None,
-         output_type: Optional[str] = "pil",
-         return_dict: bool = True,
-         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-         callback_steps: int = 1,
-         cross_attention_kwargs: Optional[Dict[str, Any]] = None):
-     """ Modification of the standard SD pipeline call to support NeTI embeddings passed with the prompt_embeds argument. """
-
-     # 0. Default height and width to unet
-     height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
-     width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
-
-     # 2. Define call parameters
-     batch_size = 1
-     device = pipeline._execution_device
-
-     neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
-     negative_prompt_embeds, _ = pipeline.text_encoder(
-         input_ids=neg_prompt.input_ids.to(device),
-         attention_mask=None,
-     )
-     negative_prompt_embeds = negative_prompt_embeds[0]
-
-     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-     # corresponds to doing no classifier free guidance.
-     do_classifier_free_guidance = guidance_scale > 1.0
-
-     # 4. Prepare timesteps
-     pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
-     timesteps = pipeline.scheduler.timesteps
-
-     # 5. Prepare latent variables
-     num_channels_latents = pipeline.unet.in_channels
-     latents = pipeline.prepare_latents(
-         batch_size * num_images_per_prompt,
-         num_channels_latents,
-         height,
-         width,
-         pipeline.text_encoder.dtype,
-         device,
-         generator,
-         latents,
-     )
-
-     # 6. Prepare extra step kwargs.
-     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
-
-     # 7. Denoising loop
-     num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
-     with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
-         for i, t in enumerate(timesteps):
-             latent_model_input = pipeline.scheduler.scale_model_input(latents, t)
-
-             if do_classifier_free_guidance:
-                 # predict the unconditional noise residual
-                 noise_pred_uncond = pipeline.unet(
-                     latent_model_input,
-                     t,
-                     encoder_hidden_states=negative_prompt_embeds.repeat(num_images_per_prompt, 1, 1),
-                     cross_attention_kwargs=cross_attention_kwargs,
-                 ).sample
-
-             ###############################################################
-             # NeTI logic: use the prompt embedding for the current timestep
-             ###############################################################
-             embed = prompt_embeds[i] if type(prompt_embeds) == list else prompt_embeds
-             noise_pred_text = pipeline.unet(
-                 latent_model_input,
-                 t,
-                 encoder_hidden_states=embed,
-                 cross_attention_kwargs=cross_attention_kwargs,
-             ).sample
-
-             # perform guidance
-             if do_classifier_free_guidance:
-                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-             else:
-                 noise_pred = noise_pred_text
-
-             # compute the previous noisy sample x_t -> x_t-1
-             latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-             # call the callback, if provided
-             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
-                 progress_bar.update()
-                 if callback is not None and i % callback_steps == 0:
-                     callback(i, t, latents)
-
-     if output_type == "latent":
-         image = latents
-         has_nsfw_concept = None
-     elif output_type == "pil":
-         # 8. Post-processing
-         image = pipeline.decode_latents(latents)
-         # 9. Run safety checker
-         image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
-         # 10. Convert to PIL
-         image = pipeline.numpy_to_pil(image)
-     else:
-         # 8. Post-processing
-         image = pipeline.decode_latents(latents)
-         # 9. Run safety checker
-         image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
-
-     # Offload last model to CPU
-     if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
-         pipeline.final_offload_hook.offload()
-
-     if not return_dict:
-         return image, has_nsfw_concept
-
-     return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
-
-
- def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
-                              negative_prompt: Optional[Union[str, List[str]]] = None):
-     if negative_prompt is None:
-         negative_prompt = ""
-     uncond_tokens = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-     uncond_input = pipeline.tokenizer(
-         uncond_tokens,
-         padding="max_length",
-         max_length=pipeline.tokenizer.model_max_length,
-         truncation=True,
-         return_tensors="pt",
-     )
-     return uncond_input
src/utils/__init__.py DELETED
File without changes
src/utils/types.py DELETED
@@ -1,20 +0,0 @@
- import enum
- from dataclasses import dataclass
- from typing import Optional
-
- import torch
-
-
- @dataclass
- class NeTIBatch:
-     input_ids: torch.Tensor
-     placeholder_token_id: int
-     timesteps: torch.Tensor
-     unet_layers: torch.Tensor
-     truncation_idx: Optional[int] = None
-
-
- @dataclass
- class PESigmas:
-     sigma_t: float
-     sigma_l: float
src/utils/vis_utils.py DELETED
@@ -1,17 +0,0 @@
- import math
- from typing import List
-
- from PIL import Image
-
-
- def get_image_grid(images: List[Image.Image]) -> Image.Image:
-     num_images = len(images)
-     cols = int(math.ceil(math.sqrt(num_images)))
-     rows = int(math.ceil(num_images / cols))
-     width, height = images[0].size
-     grid_image = Image.new('RGB', (cols * width, rows * height))
-     for i, img in enumerate(images):
-         x = i % cols
-         y = i // cols
-         grid_image.paste(img, (x * width, y * height))
-     return grid_image