Create generation_sdxl.py
generation_sdxl.py
ADDED
@@ -0,0 +1,474 @@
import copy
import random

import numpy as np
import torch


# Diffusion utils
# ------------------------------------------------------------------------
def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # Only the pooled output of the final text encoder is kept
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds


def compute_embeddings(
    prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True,
    device='cuda'
):
    target_size = (1024, 1024)
    crops_coords_top_left = crop_coords

    original_sizes = torch.tensor(original_sizes, dtype=torch.long)
    crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
    )
    add_text_embeds = pooled_prompt_embeds

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
    add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

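# Usage sketch (illustrative; only the resulting callable is assumed by the samplers below):
# `sample_deterministic` and `inverse_sample_deterministic` expect a
# `compute_embeddings_fn(prompt, orig_size, crop_coords)` callable, which can be built by
# partially applying `compute_embeddings` with the two text encoders of an already loaded
# StableDiffusionXLPipeline `pipe`:
#
#     import functools
#     compute_embeddings_fn = functools.partial(
#         compute_embeddings,
#         proportion_empty_prompts=0.0,
#         text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
#         tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
#     )
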
def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scale values (one per sample) to embed
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

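# Shape check (illustrative): guidance_scale_embedding(torch.tensor([8.0]), embedding_dim=512)
# returns a (1, 512) tensor; this embedding is what gets passed to the UNet as
# `timestep_cond` for w-conditioned (guidance-distilled) students below.
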
def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
    sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
    alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)

    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)

    # Set hard boundaries to ensure equivalence with forward (direct) CD
    alphas_s[boundary_timesteps == 0] = 1.0
    sigmas_s[boundary_timesteps == 0] = 0.0

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # x0 prediction
        pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output  # Euler step to the boundary step
    elif prediction_type == "v_prediction":
        assert (boundary_timesteps == 0).all(), "v_prediction does not support multiple endpoints at the moment"
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0


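# In the "epsilon" branch above, the model output is the predicted noise, so
#     x_0 = (x_t - sigma_t * eps) / alpha_t
# and the prediction is then carried to the boundary timestep s via
#     x_s = alpha_s * x_0 + sigma_s * eps,
# which reduces to plain x_0 prediction when s == 0 (alpha_s = 1, sigma_s = 0).
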
class DDIMSolver:
    def __init__(
        self, alpha_cumprods, timesteps=1000, ddim_timesteps=50,
        num_endpoints=1, num_inverse_endpoints=1,
        max_inverse_timestep_index=49,
        endpoints=None, inverse_endpoints=None
    ):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(
            np.int64) - 1  # [19, ..., 999]
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_alpha_cumprods_next = np.asarray(
            alpha_cumprods[self.ddim_timesteps[1:]].tolist() + [0.0]
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
        self.ddim_alpha_cumprods_next = torch.from_numpy(self.ddim_alpha_cumprods_next)

        # Set endpoints for direct CTM
        if endpoints is None:
            timestep_interval = ddim_timesteps // num_endpoints + int(ddim_timesteps % num_endpoints > 0)
            endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            self.endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
        else:
            self.endpoints = torch.tensor([int(endpoint) for endpoint in endpoints.split(',')])
            assert len(self.endpoints) == num_endpoints

        # Set endpoints for inverse CTM
        if inverse_endpoints is None:
            timestep_interval = ddim_timesteps // num_inverse_endpoints + int(
                ddim_timesteps % num_inverse_endpoints > 0)
            inverse_endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])
            self.inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
        else:
            self.inverse_endpoints = torch.tensor([int(endpoint) for endpoint in inverse_endpoints.split(',')])
            assert len(self.inverse_endpoints) == num_inverse_endpoints

    def to(self, device):
        self.endpoints = self.endpoints.to(device)
        self.inverse_endpoints = self.inverse_endpoints.to(device)

        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        self.ddim_alpha_cumprods_next = self.ddim_alpha_cumprods_next.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def inverse_ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_next = extract_into_tensor(self.ddim_alpha_cumprods_next, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_next).sqrt() * pred_noise
        x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
        return x_next
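
# Worked example (follows from the construction above): with the defaults timesteps=1000
# and ddim_timesteps=50, the DDIM grid is [19, 39, ..., 999]; for
# num_endpoints=num_inverse_endpoints=4 and max_inverse_timestep_index=49 this yields
#     endpoints         = [0, 259, 519, 779]    (targets of the direct consistency map)
#     inverse_endpoints = [259, 519, 779, 999]  (targets of the inverse map)
# e.g. (assuming any DDPM-style scheduler exposing `alphas_cumprod`):
#     solver = DDIMSolver(scheduler.alphas_cumprod.cpu().numpy(), num_endpoints=4, num_inverse_endpoints=4)
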
# ------------------------------------------------------------------------

# Distillation specific
# ------------------------------------------------------------------------
def inverse_sample_deterministic(
    pipe,
    images,
    prompt,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_start_latent=False,
    guidance_scale=None,  # Used only if the student has w_embedding
    compute_embeddings_fn=None,
    is_sdxl=False,
    inverse_endpoints=None,
    seed=0,
):
    # assert isinstance(pipe, StableDiffusionImg2ImgPipeline), f"Does not support the pipeline {type(pipe)}"

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Prepare text embeddings
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # Prepare the DDIM solver
    endpoints = ','.join(['0'] + inverse_endpoints.split(',')[:-1]) if inverse_endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.cpu().numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints
    ).to(device)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        timesteps, boundary_timesteps = timesteps, timesteps
        boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
        boundary_timesteps[-1] = 999
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    # Prepare latent variables
    num_channels_latents = pipe.unet.config.in_channels
    start_latents = pipe.prepare_latents(
        images, timesteps[0], batch_size, 1, prompt_embeds.dtype, device,
        generator=torch.Generator().manual_seed(seed),
    )
    latents = start_latents.clone()

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        # predict the noise residual
        noise_pred = pipe.unet(
            latents.to(prompt_embeds.dtype),
            t,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(latents), device=device),
            torch.tensor([s] * len(latents), device=device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(prompt_embeds.dtype)

    if return_start_latent:
        return latents, start_latents
    else:
        return latents


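# Note: as the commented-out assert at the top of `inverse_sample_deterministic`
# suggests, the routine expects an img2img-style pipeline whose
# `prepare_latents(image, timestep, ...)` encodes the input images; the returned
# latents can then be passed to `sample_deterministic` below via its `latents`
# argument for reconstruction or editing.
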
def linear_schedule_old(t, guidance_scale, tau1, tau2):
    t = t / 1000
    if t <= tau1:
        gamma = 1.0
    elif t >= tau2:
        gamma = 0.0
    else:
        gamma = (tau2 - t) / (tau2 - tau1)
    return gamma * guidance_scale


@torch.no_grad()
def sample_deterministic(
    pipe,
    prompt,
    latents=None,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_latent=False,
    guidance_scale=None,  # Used only if the student has w_embedding
    compute_embeddings_fn=None,
    is_sdxl=False,
    endpoints=None,
    use_dynamic_guidance=False,
    tau1=0.7,
    tau2=0.7,
    amplify_prompt=None,
):
    # assert isinstance(pipe, StableDiffusionPipeline), f"Does not support the pipeline {type(pipe)}"
    height = pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = pipe.unet.config.sample_size * pipe.vae_scale_factor

    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Prepare text embeddings
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
            if amplify_prompt is not None:
                orig_size = [(1024, 1024)] * len(amplify_prompt)
                crop_coords = [(0, 0)] * len(amplify_prompt)
                encoded_text_old = compute_embeddings_fn(amplify_prompt, orig_size, crop_coords)
                amplify_prompt_embeds = encoded_text_old.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # Prepare the DDIM solver
    inverse_endpoints = ','.join(endpoints.split(',')[1:] + ['999']) if endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.cpu().numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints
    ).to(device)

    prompt_embeds_init = copy.deepcopy(prompt_embeds)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        timesteps, boundary_timesteps = copy.deepcopy(timesteps), copy.deepcopy(timesteps)
        timesteps.reverse()
        boundary_timesteps.reverse()
        boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
        boundary_timesteps[-1] = 0
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    # Prepare latent variables
    if latents is None:
        num_channels_latents = pipe.unet.config.in_channels
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )
        assert latents.dtype == pipe.unet.dtype
    else:
        latents = latents.to(prompt_embeds.dtype)

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        if use_dynamic_guidance:
            t_item = t if isinstance(t, int) else t.item()
            if t_item > tau1 * 1000 and amplify_prompt is not None:
                prompt_embeds = amplify_prompt_embeds
            else:
                prompt_embeds = prompt_embeds_init
            # rescale the per-sample base guidance `w` for the current timestep
            guidance_scale_tensor = linear_schedule_old(t_item, w, tau1=tau1, tau2=tau2)
            w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=512)
            w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

        # predict the noise residual
        noise_pred = pipe.unet(
            latents,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(noise_pred)).to(device),
            torch.tensor([s] * len(noise_pred)).to(device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(pipe.unet.dtype)

    pipe.vae.to(torch.float32)
    image = pipe.vae.decode(latents.to(torch.float32) / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)

    if return_latent:
        return image, latents
    else:
        return image
# ------------------------------------------------------------------------
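

# ------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only): assumes a CUDA device, the diffusers
# StableDiffusionXLPipeline API, and that the checkpoint path below is replaced with an
# SDXL checkpoint distilled for few-step deterministic sampling.
if __name__ == "__main__":
    import functools

    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "path/to/distilled-sdxl-checkpoint",  # placeholder, not a real model id
        torch_dtype=torch.float16,
    ).to("cuda")

    compute_embeddings_fn = functools.partial(
        compute_embeddings,
        proportion_empty_prompts=0.0,
        text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
        tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
    )

    images = sample_deterministic(
        pipe,
        ["a photo of a corgi riding a skateboard"],
        num_inference_steps=4,
        compute_embeddings_fn=compute_embeddings_fn,
        is_sdxl=True,
        generator=torch.Generator().manual_seed(0),
    )
    images[0].save("sample.png")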