quickjkee committed on
Commit
6ec756e
1 Parent(s): 587df09

Create generation_sdxl.py

Files changed (1)
  1. generation_sdxl.py +474 -0
generation_sdxl.py ADDED
@@ -0,0 +1,474 @@
import torch
import copy
import random
import numpy as np


# Diffusion utils
# ------------------------------------------------------------------------
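# encode_prompt: tokenizes the caption batch with each tokenizer/text-encoder pair,
# concatenates the penultimate hidden states along the feature dimension and keeps the
# pooled output of the final text encoder.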
def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
    prompt_embeds_list = []

    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
            )

            # We are only interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
    return prompt_embeds, pooled_prompt_embeds


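# compute_embeddings: wraps encode_prompt and builds the SDXL micro-conditioning
# ("time ids" from the original size, crop coordinates and the fixed 1024x1024 target size),
# returning the prompt embeddings together with the UNet added-condition kwargs.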
def compute_embeddings(
    prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True,
    device='cuda'
):
    target_size = (1024, 1024)
    original_sizes = original_sizes  # list(map(list, zip(*original_sizes)))
    crops_coords_top_left = crop_coords  # list(map(list, zip(*crop_coords)))

    original_sizes = torch.tensor(original_sizes, dtype=torch.long)
    crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)

    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
    )
    add_text_embeds = pooled_prompt_embeds

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(target_size)
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
    add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
    add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}

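# Illustrative only (not part of this file): the samplers below expect a
# `compute_embeddings_fn(prompt, original_sizes, crop_coords)`; one way to bind it to a
# loaded diffusers StableDiffusionXLPipeline `pipe` is
#
#     import functools
#     compute_embeddings_fn = functools.partial(
#         compute_embeddings,
#         proportion_empty_prompts=0.0,
#         text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
#         tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
#     )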
def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scale values at which to generate the embedding vectors
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb

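# predicted_origin: converts the model output at timestep t into an x0 estimate and takes
# an exact step to the boundary timestep s (alpha_s * x0 + sigma_s * eps); a boundary of 0
# reduces to the plain x0 prediction of standard consistency distillation.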
def predicted_origin(model_output, timesteps, boundary_timesteps, sample, prediction_type, alphas, sigmas):
    sigmas_s = extract_into_tensor(sigmas, boundary_timesteps, sample.shape)
    alphas_s = extract_into_tensor(alphas, boundary_timesteps, sample.shape)

    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)

    # Set hard boundaries to ensure equivalence with forward (direct) CD
    alphas_s[boundary_timesteps == 0] = 1.0
    sigmas_s[boundary_timesteps == 0] = 0.0

    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # x0 prediction
        pred_x_0 = alphas_s * pred_x_0 + sigmas_s * model_output  # Euler step to the boundary timestep
    elif prediction_type == "v_prediction":
        assert (boundary_timesteps == 0).all(), "v_prediction does not support multiple endpoints at the moment"
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0


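# DDIMSolver: precomputes the discretized DDIM timestep grid with its alpha-cumprod tables
# and the boundary endpoints used for multi-boundary consistency sampling, in both the
# generation direction (`endpoints`) and the inversion direction (`inverse_endpoints`).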
class DDIMSolver:
    def __init__(
        self, alpha_cumprods, timesteps=1000, ddim_timesteps=50,
        num_endpoints=1, num_inverse_endpoints=1,
        max_inverse_timestep_index=49,
        endpoints=None, inverse_endpoints=None
    ):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(
            np.int64) - 1  # [19, ..., 999]
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_alpha_cumprods_next = np.asarray(
            alpha_cumprods[self.ddim_timesteps[1:]].tolist() + [0.0]
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
        self.ddim_alpha_cumprods_next = torch.from_numpy(self.ddim_alpha_cumprods_next)

        # Set endpoints for direct CTM
        if endpoints is None:
            timestep_interval = ddim_timesteps // num_endpoints + int(ddim_timesteps % num_endpoints > 0)
            endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            self.endpoints = torch.tensor([0] + self.ddim_timesteps[endpoint_idxs].tolist())
        else:
            self.endpoints = torch.tensor([int(endpoint) for endpoint in endpoints.split(',')])
            assert len(self.endpoints) == num_endpoints

        # Set endpoints for inverse CTM
        if inverse_endpoints is None:
            timestep_interval = ddim_timesteps // num_inverse_endpoints + int(
                ddim_timesteps % num_inverse_endpoints > 0)
            inverse_endpoint_idxs = torch.arange(timestep_interval, ddim_timesteps, timestep_interval) - 1
            inverse_endpoint_idxs = torch.tensor(inverse_endpoint_idxs.tolist() + [max_inverse_timestep_index])
            self.inverse_endpoints = self.ddim_timesteps[inverse_endpoint_idxs]
        else:
            self.inverse_endpoints = torch.tensor([int(endpoint) for endpoint in inverse_endpoints.split(',')])
            assert len(self.inverse_endpoints) == num_inverse_endpoints

    def to(self, device):
        self.endpoints = self.endpoints.to(device)
        self.inverse_endpoints = self.inverse_endpoints.to(device)

        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        self.ddim_alpha_cumprods_next = self.ddim_alpha_cumprods_next.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def inverse_ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_next = extract_into_tensor(self.ddim_alpha_cumprods_next, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_next).sqrt() * pred_noise
        x_next = alpha_cumprod_next.sqrt() * pred_x0 + dir_xt
        return x_next
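
# Illustrative only (not part of this file): a 4-endpoint solver over a 50-point DDIM grid,
# assuming `scheduler` is a diffusers noise scheduler exposing `alphas_cumprod`:
#
#     solver = DDIMSolver(
#         scheduler.alphas_cumprod.cpu().numpy(),
#         timesteps=scheduler.config.num_train_timesteps,
#         ddim_timesteps=50,
#         num_endpoints=4,
#         num_inverse_endpoints=4,
#     ).to("cuda")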
# ------------------------------------------------------------------------


# Distillation specific
# ------------------------------------------------------------------------
def inverse_sample_deterministic(
    pipe,
    images,
    prompt,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_start_latent=False,
    guidance_scale=None,  # Used only if the student has a w_embedding
    compute_embeddings_fn=None,
    is_sdxl=False,
    inverse_endpoints=None,
    seed=0,
):
    # Deterministic inversion: encodes `images` into high-noise latents by iterating
    # the consistency model over the inverse endpoints.
    # assert isinstance(pipe, StableDiffusionImg2ImgPipeline), f"Does not support the pipeline {type(pipe)}"

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Prepare text embeddings
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # Prepare the DDIM solver
    endpoints = ','.join(['0'] + inverse_endpoints.split(',')[:-1]) if inverse_endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.cpu().numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints
    ).to(device)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        timesteps, boundary_timesteps = timesteps, timesteps
        boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
        boundary_timesteps[-1] = 999
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    # 5. Prepare latent variables
    num_channels_latents = pipe.unet.config.in_channels
    start_latents = pipe.prepare_latents(
        images, timesteps[0], batch_size, 1, prompt_embeds.dtype, device,
        generator=torch.Generator().manual_seed(seed),
    )
    latents = start_latents.clone()

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        # predict the noise residual
        noise_pred = pipe.unet(
            latents.to(prompt_embeds.dtype),
            t,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(latents), device=device),
            torch.tensor([s] * len(latents), device=device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(prompt_embeds.dtype)

    if return_start_latent:
        return latents, start_latents
    else:
        return latents


def linear_schedule_old(t, guidance_scale, tau1, tau2):
    # Linearly anneals the guidance scale: full strength for t/1000 <= tau1, zero for t/1000 >= tau2.
    t = t / 1000
    if t <= tau1:
        gamma = 1.0
    elif t >= tau2:
        gamma = 0.0
    else:
        gamma = (tau2 - t) / (tau2 - tau1)
    return gamma * guidance_scale


@torch.no_grad()
def sample_deterministic(
    pipe,
    prompt,
    latents=None,
    generator=None,
    num_scales=50,
    num_inference_steps=1,
    timesteps=None,
    start_timestep=19,
    max_inverse_timestep_index=49,
    return_latent=False,
    guidance_scale=None,  # Used only if the student has a w_embedding
    compute_embeddings_fn=None,
    is_sdxl=False,
    endpoints=None,
    use_dynamic_guidance=False,
    tau1=0.7,
    tau2=0.7,
    amplify_prompt=None,
):
    # Few-step deterministic generation: at each (t, s) pair the model predicts x0 and is
    # re-noised to the boundary timestep s, ending at s=0 before VAE decoding.
    # assert isinstance(pipe, StableDiffusionPipeline), f"Does not support the pipeline {type(pipe)}"
    height = pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = pipe.unet.config.sample_size * pipe.vae_scale_factor

    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)

    device = pipe._execution_device

    # Prepare text embeddings
    if compute_embeddings_fn is not None:
        if is_sdxl:
            orig_size = [(1024, 1024)] * len(prompt)
            crop_coords = [(0, 0)] * len(prompt)
            encoded_text = compute_embeddings_fn(prompt, orig_size, crop_coords)
            prompt_embeds = encoded_text.pop("prompt_embeds")
            if amplify_prompt is not None:
                orig_size = [(1024, 1024)] * len(amplify_prompt)
                crop_coords = [(0, 0)] * len(amplify_prompt)
                encoded_text_old = compute_embeddings_fn(amplify_prompt, orig_size, crop_coords)
                amplify_prompt_embeds = encoded_text_old.pop("prompt_embeds")
        else:
            prompt_embeds = compute_embeddings_fn(prompt)["prompt_embeds"]
            encoded_text = {}
        prompt_embeds = prompt_embeds.to(pipe.unet.dtype)
    else:
        prompt_embeds = pipe.encode_prompt(prompt, device, 1, False)[0]
        encoded_text = {}
    assert prompt_embeds.dtype == pipe.unet.dtype

    # Prepare the DDIM solver
    inverse_endpoints = ','.join(endpoints.split(',')[1:] + ['999']) if endpoints is not None else None
    solver = DDIMSolver(
        pipe.scheduler.alphas_cumprod.numpy(),
        timesteps=pipe.scheduler.num_train_timesteps,
        ddim_timesteps=num_scales,
        num_endpoints=num_inference_steps,
        num_inverse_endpoints=num_inference_steps,
        max_inverse_timestep_index=max_inverse_timestep_index,
        endpoints=endpoints,
        inverse_endpoints=inverse_endpoints
    ).to(device)

    prompt_embeds_init = copy.deepcopy(prompt_embeds)

    if timesteps is None:
        timesteps = solver.inverse_endpoints.flip(0)
        boundary_timesteps = solver.endpoints.flip(0)
    else:
        timesteps, boundary_timesteps = copy.deepcopy(timesteps), copy.deepcopy(timesteps)
        timesteps.reverse()
        boundary_timesteps.reverse()
        boundary_timesteps = boundary_timesteps[1:] + [boundary_timesteps[0]]
        boundary_timesteps[-1] = 0
        timesteps, boundary_timesteps = torch.tensor(timesteps), torch.tensor(boundary_timesteps)

    alpha_schedule = torch.sqrt(pipe.scheduler.alphas_cumprod).to(device)
    sigma_schedule = torch.sqrt(1 - pipe.scheduler.alphas_cumprod).to(device)

    # 5. Prepare latent variables
    if latents is None:
        num_channels_latents = pipe.unet.config.in_channels
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )
        assert latents.dtype == pipe.unet.dtype
    else:
        latents = latents.to(prompt_embeds.dtype)

    if guidance_scale is not None:
        w = torch.ones(batch_size) * guidance_scale
        w_embedding = guidance_scale_embedding(w, embedding_dim=512)
        w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
    else:
        w_embedding = None

    for i, (t, s) in enumerate(zip(timesteps, boundary_timesteps)):
        if use_dynamic_guidance:
            if not isinstance(t, int):
                t_item = t.item()
            if t_item > tau1 * 1000 and amplify_prompt is not None:
                prompt_embeds = amplify_prompt_embeds
            else:
                prompt_embeds = prompt_embeds_init
            guidance_scale = linear_schedule_old(t_item, w, tau1=tau1, tau2=tau2)
            guidance_scale_tensor = torch.tensor([guidance_scale] * len(latents))
            w_embedding = guidance_scale_embedding(guidance_scale_tensor, embedding_dim=512)
            w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

        # predict the noise residual
        noise_pred = pipe.unet(
            latents,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
            timestep_cond=w_embedding,
            added_cond_kwargs=encoded_text,
        )[0]

        latents = predicted_origin(
            noise_pred,
            torch.tensor([t] * len(noise_pred)).to(device),
            torch.tensor([s] * len(noise_pred)).to(device),
            latents,
            pipe.scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        ).to(pipe.unet.dtype)

    pipe.vae.to(torch.float32)
    image = pipe.vae.decode(latents.to(torch.float32) / pipe.vae.config.scaling_factor, return_dict=False)[0]
    do_denormalize = [True] * image.shape[0]
    image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)

    if return_latent:
        return image, latents
    else:
        return image
# ------------------------------------------------------------------------
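
A minimal end-to-end usage sketch (not part of the commit), assuming a diffusers StableDiffusionXLPipeline; the checkpoint id below is the standard SDXL base model, and in practice pipe.unet would first be replaced by the consistency-distilled student. guidance_scale is left as None because the w-embedding path only applies to students trained with one.

import functools
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# ... load the distilled student UNet into pipe.unet here ...

compute_embeddings_fn = functools.partial(
    compute_embeddings,
    proportion_empty_prompts=0.0,
    text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
    tokenizers=[pipe.tokenizer, pipe.tokenizer_2],
)

images = sample_deterministic(
    pipe,
    ["a photo of a corgi wearing sunglasses"],
    num_scales=50,
    num_inference_steps=4,   # four boundary endpoints -> four UNet evaluations
    guidance_scale=None,     # set a scale only if the student has a w-embedding
    compute_embeddings_fn=compute_embeddings_fn,
    is_sdxl=True,
)
images[0].save("sample.png")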