Spaces:
Running
on
A10G
Running
on
A10G
Linoy Tsaban
committed on
Commit
•
4065064
1
Parent(s):
45e73ca
Update pipeline_semantic_stable_diffusion_img2img_solver.py
Browse files
pipeline_semantic_stable_diffusion_img2img_solver.py
CHANGED
@@ -500,6 +500,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
500 |
use_cross_attn_mask: bool = False,
|
501 |
# Attention store (just for visualization purposes)
|
502 |
attention_store = None,
|
|
|
503 |
attn_store_steps: Optional[List[int]] = [],
|
504 |
store_averaged_over_steps: bool = True,
|
505 |
use_intersect_mask: bool = False,
|
@@ -755,10 +756,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
755 |
# For classifier free guidance, we need to do two forward passes.
|
756 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
757 |
# to avoid doing two forward passes
|
758 |
-
|
759 |
if enable_edit_guidance:
|
760 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
|
761 |
-
|
762 |
([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
|
763 |
else:
|
764 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
@@ -920,11 +921,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
920 |
if use_cross_attn_mask:
|
921 |
out = attention_store.aggregate_attention(
|
922 |
attention_maps=attention_store.step_store,
|
923 |
-
prompts=
|
924 |
res=16,
|
925 |
from_where=["up", "down"],
|
926 |
is_cross=True,
|
927 |
-
select=
|
928 |
)
|
929 |
attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
|
930 |
|
@@ -1105,7 +1106,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
1105 |
if not return_dict:
|
1106 |
return (image, has_nsfw_concept), attention_store
|
1107 |
|
1108 |
-
return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
|
1109 |
|
1110 |
def encode_text(self, prompts):
|
1111 |
text_inputs = self.tokenizer(
|
|
|
500 |
use_cross_attn_mask: bool = False,
|
501 |
# Attention store (just for visualization purposes)
|
502 |
attention_store = None,
|
503 |
+
text_cross_attention_maps = None,
|
504 |
attn_store_steps: Optional[List[int]] = [],
|
505 |
store_averaged_over_steps: bool = True,
|
506 |
use_intersect_mask: bool = False,
|
|
|
756 |
# For classifier free guidance, we need to do two forward passes.
|
757 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
758 |
# to avoid doing two forward passes
|
759 |
+
text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
|
760 |
if enable_edit_guidance:
|
761 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
|
762 |
+
text_cross_attention_maps += \
|
763 |
([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
|
764 |
else:
|
765 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
|
|
921 |
if use_cross_attn_mask:
|
922 |
out = attention_store.aggregate_attention(
|
923 |
attention_maps=attention_store.step_store,
|
924 |
+
prompts=text_cross_attention_maps,
|
925 |
res=16,
|
926 |
from_where=["up", "down"],
|
927 |
is_cross=True,
|
928 |
+
select=text_cross_attention_maps.index(editing_prompt[c]),
|
929 |
)
|
930 |
attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
|
931 |
|
|
|
1106 |
if not return_dict:
|
1107 |
return (image, has_nsfw_concept), attention_store
|
1108 |
|
1109 |
+
return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store, text_cross_attention_maps
|
1110 |
|
1111 |
def encode_text(self, prompts):
|
1112 |
text_inputs = self.tokenizer(
|