Richard Neuschulz committed
Commit b883378
1 parent: cf453dc

own ipadapter copy

Files changed (2):
  1. app.py +6 -7
  2. ipown.py +468 -0
app.py CHANGED
@@ -3,7 +3,7 @@ import spaces
 from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline
 from transformers import AutoFeatureExtractor
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
+import ipown
 from huggingface_hub import hf_hub_download
 from insightface.app import FaceAnalysis
 from insightface.utils import face_align
@@ -60,12 +60,11 @@ def generate_image(images, prompt, negative_prompt, preserve_face_structure, fac
 
     total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
 
-    if(not preserve_face_structure):
-        print("Generating normal")
-        image = ip_model.generate(
-            prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
-            scale=likeness_strength, width=512, height=512, num_inference_steps=30
-        )
+    print("Generating normal")
+    image = ip_model.generate(
+        prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
+        scale=likeness_strength, width=1024, height=1024, guidance_scale=7.5, num_inference_steps=30
+    )
 
     print(image)
     return image
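
For reference, here is a minimal sketch, not part of this commit, of how the average_embedding passed to ip_model.generate in the hunk above is typically built from the uploaded images with insightface. The FaceAnalysis model name, detection size, provider, and the image_paths variable are illustrative assumptions rather than values taken from app.py:

# Hypothetical helper, for illustration only; the names flagged below are assumptions.
import cv2
import torch
from insightface.app import FaceAnalysis

face_app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])  # assumed detector config
face_app.prepare(ctx_id=0, det_size=(640, 640))

embeddings = []
for path in image_paths:  # placeholder for the uploaded image files
    faces = face_app.get(cv2.imread(path))
    # insightface returns a 512-d identity embedding per detected face
    embeddings.append(torch.from_numpy(faces[0].normed_embedding).unsqueeze(0))

# Shape (1, 512), matching the id_embeddings_dim expected by the projection models in ipown.py.
average_embedding = torch.mean(torch.stack(embeddings, dim=0), dim=0)
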
ipown.py ADDED
@@ -0,0 +1,468 @@
+import os
+from typing import List
+
+import torch
+from diffusers import StableDiffusionPipeline
+from diffusers.pipelines.controlnet import MultiControlNetModel
+from PIL import Image
+from safetensors import safe_open
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
+from .utils import is_torch2_available
+
+USE_DAFAULT_ATTN = False  # should be True for visualization_attnmap
+if is_torch2_available() and (not USE_DAFAULT_ATTN):
+    from .attention_processor_faceid import (
+        LoRAAttnProcessor2_0 as LoRAAttnProcessor,
+    )
+    from .attention_processor_faceid import (
+        LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
+    )
+else:
+    from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
+from .resampler import PerceiverAttention, FeedForward
+
+
+class FacePerceiverResampler(torch.nn.Module):
+    def __init__(
+        self,
+        *,
+        dim=768,
+        depth=4,
+        dim_head=64,
+        heads=16,
+        embedding_dim=1280,
+        output_dim=768,
+        ff_mult=4,
+    ):
+        super().__init__()
+
+        self.proj_in = torch.nn.Linear(embedding_dim, dim)
+        self.proj_out = torch.nn.Linear(dim, output_dim)
+        self.norm_out = torch.nn.LayerNorm(output_dim)
+        self.layers = torch.nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                torch.nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+    def forward(self, latents, x):
+        x = self.proj_in(x)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
+
+
+class MLPProjModel(torch.nn.Module):
+    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+            torch.nn.GELU(),
+            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+        )
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, id_embeds):
+        x = self.proj(id_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        x = self.norm(x)
+        return x
+
+
+class ProjPlusModel(torch.nn.Module):
+    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.num_tokens = num_tokens
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+            torch.nn.GELU(),
+            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+        )
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+        self.perceiver_resampler = FacePerceiverResampler(
+            dim=cross_attention_dim,
+            depth=4,
+            dim_head=64,
+            heads=cross_attention_dim // 64,
+            embedding_dim=clip_embeddings_dim,
+            output_dim=cross_attention_dim,
+            ff_mult=4,
+        )
+
+    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
+
+        x = self.proj(id_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        x = self.norm(x)
+        out = self.perceiver_resampler(x, clip_embeds)
+        if shortcut:
+            out = x + scale * out
+        return out
+
+
+class IPAdapterFaceID:
+    def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
+        self.device = device
+        self.ip_ckpt = ip_ckpt
+        self.lora_rank = lora_rank
+        self.num_tokens = num_tokens
+        self.torch_dtype = torch_dtype
+
+        self.pipe = sd_pipe.to(self.device)
+        self.set_ip_adapter()
+
+        # image proj model
+        self.image_proj_model = self.init_proj()
+
+        self.load_ip_adapter()
+
+    def init_proj(self):
+        image_proj_model = MLPProjModel(
+            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            id_embeddings_dim=512,
+            num_tokens=self.num_tokens,
+        ).to(self.device, dtype=self.torch_dtype)
+        return image_proj_model
+
+    def set_ip_adapter(self):
+        unet = self.pipe.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+            if cross_attention_dim is None:
+                attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+                ).to(self.device, dtype=self.torch_dtype)
+            else:
+                attn_procs[name] = LoRAIPAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+                ).to(self.device, dtype=self.torch_dtype)
+        unet.set_attn_processor(attn_procs)
+
+    def load_ip_adapter(self):
+        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+            state_dict = {"image_proj": {}, "ip_adapter": {}}
+            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    if key.startswith("image_proj."):
+                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                    elif key.startswith("ip_adapter."):
+                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+        else:
+            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+        self.image_proj_model.load_state_dict(state_dict["image_proj"])
+        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+        ip_layers.load_state_dict(state_dict["ip_adapter"])
+
+    @torch.inference_mode()
+    def get_image_embeds(self, faceid_embeds):
+
+        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+        image_prompt_embeds = self.image_proj_model(faceid_embeds)
+        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
+        return image_prompt_embeds, uncond_image_prompt_embeds
+
+    def set_scale(self, scale):
+        for attn_processor in self.pipe.unet.attn_processors.values():
+            if isinstance(attn_processor, LoRAIPAttnProcessor):
+                attn_processor.scale = scale
+
+    def generate(
+        self,
+        faceid_embeds=None,
+        prompt=None,
+        negative_prompt=None,
+        scale=1.0,
+        num_samples=4,
+        seed=None,
+        guidance_scale=7.5,
+        num_inference_steps=30,
+        **kwargs,
+    ):
+        self.set_scale(scale)
+
+        num_prompts = faceid_embeds.size(0)
+
+        if prompt is None:
+            prompt = "best quality, high quality"
+        if negative_prompt is None:
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+        if not isinstance(prompt, List):
+            prompt = [prompt] * num_prompts
+        if not isinstance(negative_prompt, List):
+            negative_prompt = [negative_prompt] * num_prompts
+
+        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+        bs_embed, seq_len, _ = image_prompt_embeds.shape
+        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+        with torch.inference_mode():
+            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+                prompt,
+                device=self.device,
+                num_images_per_prompt=num_samples,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
+            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+        images = self.pipe(
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            **kwargs,
+        ).images
+
+        return images
+
+
+class IPAdapterFaceIDPlus:
+    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
+        self.device = device
+        self.image_encoder_path = image_encoder_path
+        self.ip_ckpt = ip_ckpt
+        self.lora_rank = lora_rank
+        self.num_tokens = num_tokens
+        self.torch_dtype = torch_dtype
+
+        self.pipe = sd_pipe.to(self.device)
+        self.set_ip_adapter()
+
+        # load image encoder
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+            self.device, dtype=self.torch_dtype
+        )
+        self.clip_image_processor = CLIPImageProcessor()
+        # image proj model
+        self.image_proj_model = self.init_proj()
+
+        self.load_ip_adapter()
+
+    def init_proj(self):
+        image_proj_model = ProjPlusModel(
+            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            id_embeddings_dim=512,
+            clip_embeddings_dim=self.image_encoder.config.hidden_size,
+            num_tokens=self.num_tokens,
+        ).to(self.device, dtype=self.torch_dtype)
+        return image_proj_model
+
+    def set_ip_adapter(self):
+        unet = self.pipe.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+            if cross_attention_dim is None:
+                attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+                ).to(self.device, dtype=self.torch_dtype)
+            else:
+                attn_procs[name] = LoRAIPAttnProcessor(
+                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+                ).to(self.device, dtype=self.torch_dtype)
+        unet.set_attn_processor(attn_procs)
+
+    def load_ip_adapter(self):
+        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
+            state_dict = {"image_proj": {}, "ip_adapter": {}}
+            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    if key.startswith("image_proj."):
+                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+                    elif key.startswith("ip_adapter."):
+                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+        else:
+            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+        self.image_proj_model.load_state_dict(state_dict["image_proj"])
+        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+        ip_layers.load_state_dict(state_dict["ip_adapter"])
+
+    @torch.inference_mode()
+    def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
+        if isinstance(face_image, Image.Image):
+            pil_image = [face_image]
+        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
+        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
+        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+        uncond_clip_image_embeds = self.image_encoder(
+            torch.zeros_like(clip_image), output_hidden_states=True
+        ).hidden_states[-2]
+
+        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+        image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
+        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
+        return image_prompt_embeds, uncond_image_prompt_embeds
+
+    def set_scale(self, scale):
+        for attn_processor in self.pipe.unet.attn_processors.values():
+            if isinstance(attn_processor, LoRAIPAttnProcessor):
+                attn_processor.scale = scale
+
+    def generate(
+        self,
+        face_image=None,
+        faceid_embeds=None,
+        prompt=None,
+        negative_prompt=None,
+        scale=1.0,
+        num_samples=4,
+        seed=None,
+        guidance_scale=7.5,
+        num_inference_steps=30,
+        s_scale=1.0,
+        shortcut=False,
+        **kwargs,
+    ):
+        self.set_scale(scale)
+
+        num_prompts = faceid_embeds.size(0)
+
+        if prompt is None:
+            prompt = "best quality, high quality"
+        if negative_prompt is None:
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+        if not isinstance(prompt, List):
+            prompt = [prompt] * num_prompts
+        if not isinstance(negative_prompt, List):
+            negative_prompt = [negative_prompt] * num_prompts
+
+        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
+
+        bs_embed, seq_len, _ = image_prompt_embeds.shape
+        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+        with torch.inference_mode():
+            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
+                prompt,
+                device=self.device,
+                num_images_per_prompt=num_samples,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
+            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
+            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
+
+        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+        images = self.pipe(
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            **kwargs,
+        ).images
+
+        return images
+
+
+class IPAdapterFaceIDXL(IPAdapterFaceID):
+    """SDXL"""
+
+    def generate(
+        self,
+        faceid_embeds=None,
+        prompt=None,
+        negative_prompt=None,
+        scale=1.0,
+        num_samples=4,
+        seed=None,
+        num_inference_steps=30,
+        **kwargs,
+    ):
+        self.set_scale(scale)
+
+        num_prompts = faceid_embeds.size(0)
+
+        if prompt is None:
+            prompt = "best quality, high quality"
+        if negative_prompt is None:
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+
+        if not isinstance(prompt, List):
+            prompt = [prompt] * num_prompts
+        if not isinstance(negative_prompt, List):
+            negative_prompt = [negative_prompt] * num_prompts
+
+        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
+
+        bs_embed, seq_len, _ = image_prompt_embeds.shape
+        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+
+        with torch.inference_mode():
+            (
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = self.pipe.encode_prompt(
+                prompt,
+                num_images_per_prompt=num_samples,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
+            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
+
+        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+        images = self.pipe(
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            **kwargs,
+        ).images
+
+        return images
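
For orientation, a hedged usage sketch, not part of the commit, of how ipown.IPAdapterFaceIDXL can be wired to an SDXL pipeline and called the way the app.py hunk above does. The base-model ID and checkpoint filename are placeholders, and the sketch assumes ipown.py's relative imports (attention_processor_faceid, utils, resampler) resolve in the running environment:

# Illustrative only; the model ID and checkpoint path are assumptions.
import torch
from diffusers import StableDiffusionXLPipeline
import ipown

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base model
    torch_dtype=torch.float16,
)
ip_model = ipown.IPAdapterFaceIDXL(pipe, "ip-adapter-faceid_sdxl.bin", device="cuda")  # assumed checkpoint file

outputs = ip_model.generate(
    prompt="portrait photo of a person",
    negative_prompt="monochrome, lowres, bad anatomy",
    faceid_embeds=average_embedding,  # (1, 512) identity embedding, e.g. averaged insightface embeddings
    scale=1.0,                        # corresponds to likeness_strength in the app
    width=1024, height=1024,
    guidance_scale=7.5, num_inference_steps=30,
)
outputs[0].save("output.png")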