Spaces:
Runtime error
Runtime error
Upload 30 files
Browse files- .gitattributes +2 -0
- README.md +2 -9
- app_ead_instuct.py +620 -0
- images/214000000000.jpg +0 -0
- images/311000000002.jpg +0 -0
- images/Doom_Slayer.jpg +0 -0
- images/Elon_Musk.webp +0 -0
- images/InfEdit.jpg +3 -0
- images/angry.jpg +0 -0
- images/bear.jpg +0 -0
- images/computer.png +0 -0
- images/corgi.jpg +0 -0
- images/dragon.jpg +0 -0
- images/droplet.png +0 -0
- images/frieren.jpg +0 -0
- images/genshin.png +0 -0
- images/groundhog.png +0 -0
- images/james.jpg +0 -0
- images/miku.png +0 -0
- images/moyu.png +0 -0
- images/muffin.png +0 -0
- images/osu.jfif +0 -0
- images/sam.png +3 -0
- images/summer.jpg +0 -0
- nsfw.png +0 -0
- pipeline_ead.py +707 -0
- ptp_utils.py +180 -0
- requirements.txt +8 -0
- seq_aligner.py +314 -0
- utils.py +6 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/InfEdit.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
images/sam.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,6 @@
|
|
1 |
---
|
2 |
title: InfEdit
|
3 |
-
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: cc-by-nc-sa-4.0
|
11 |
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: InfEdit
|
3 |
+
app_file: app_ead_instuct.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 4.7.1
|
|
|
|
|
|
|
6 |
---
|
|
|
|
app_ead_instuct.py
ADDED
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import LCMScheduler
|
2 |
+
from pipeline_ead import EditPipeline
|
3 |
+
import os
|
4 |
+
import gradio as gr
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
import torch.nn.functional as nnf
|
8 |
+
from typing import Optional, Union, Tuple, List, Callable, Dict
|
9 |
+
import abc
|
10 |
+
import ptp_utils
|
11 |
+
import utils
|
12 |
+
import numpy as np
|
13 |
+
import seq_aligner
|
14 |
+
import math
|
15 |
+
|
16 |
+
LOW_RESOURCE = False
|
17 |
+
MAX_NUM_WORDS = 77
|
18 |
+
|
19 |
+
is_colab = utils.is_google_colab()
|
20 |
+
colab_instruction = "" if is_colab else """
|
21 |
+
Colab Instuction"""
|
22 |
+
|
23 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
24 |
+
model_id_or_path = "SimianLuo/LCM_Dreamshaper_v7"
|
25 |
+
device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
26 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
27 |
+
|
28 |
+
if is_colab:
|
29 |
+
scheduler = LCMScheduler.from_config(model_id_or_path, subfolder="scheduler")
|
30 |
+
pipe = EditPipeline.from_pretrained(model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype)
|
31 |
+
else:
|
32 |
+
# import streamlit as st
|
33 |
+
# scheduler = DDIMScheduler.from_config(model_id_or_path, use_auth_token=st.secrets["USER_TOKEN"], subfolder="scheduler")
|
34 |
+
# pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, use_auth_token=st.secrets["USER_TOKEN"], scheduler=scheduler, torch_dtype=torch_dtype)
|
35 |
+
scheduler = LCMScheduler.from_config(model_id_or_path, use_auth_token=os.environ.get("USER_TOKEN"), subfolder="scheduler")
|
36 |
+
pipe = EditPipeline.from_pretrained(model_id_or_path, use_auth_token=os.environ.get("USER_TOKEN"), scheduler=scheduler, torch_dtype=torch_dtype)
|
37 |
+
|
38 |
+
tokenizer = pipe.tokenizer
|
39 |
+
encoder = pipe.text_encoder
|
40 |
+
|
41 |
+
if torch.cuda.is_available():
|
42 |
+
pipe = pipe.to("cuda")
|
43 |
+
|
44 |
+
|
45 |
+
class LocalBlend:
|
46 |
+
|
47 |
+
def get_mask(self,x_t,maps,word_idx, thresh, i):
|
48 |
+
# print(word_idx)
|
49 |
+
# print(maps.shape)
|
50 |
+
# for i in range(0,self.len):
|
51 |
+
# self.save_image(maps[:,:,:,:,i].mean(0,keepdim=True),i,"map")
|
52 |
+
maps = maps * word_idx.reshape(1,1,1,1,-1)
|
53 |
+
maps = (maps[:,:,:,:,1:self.len-1]).mean(0,keepdim=True)
|
54 |
+
# maps = maps.mean(0,keepdim=True)
|
55 |
+
maps = (maps).max(-1)[0]
|
56 |
+
# self.save_image(maps,i,"map")
|
57 |
+
maps = nnf.interpolate(maps, size=(x_t.shape[2:]))
|
58 |
+
# maps = maps.mean(1,keepdim=True)\
|
59 |
+
maps = maps / maps.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
|
60 |
+
mask = maps > thresh
|
61 |
+
return mask
|
62 |
+
|
63 |
+
|
64 |
+
def save_image(self,mask,i, caption):
|
65 |
+
image = mask[0, 0, :, :]
|
66 |
+
image = 255 * image / image.max()
|
67 |
+
# print(image.shape)
|
68 |
+
image = image.unsqueeze(-1).expand(*image.shape, 3)
|
69 |
+
# print(image.shape)
|
70 |
+
image = image.cpu().numpy().astype(np.uint8)
|
71 |
+
image = np.array(Image.fromarray(image).resize((256, 256)))
|
72 |
+
if not os.path.exists(f"inter/{caption}"):
|
73 |
+
os.mkdir(f"inter/{caption}")
|
74 |
+
ptp_utils.save_images(image, f"inter/{caption}/{i}.jpg")
|
75 |
+
|
76 |
+
|
77 |
+
def __call__(self, i, x_s, x_t, x_m, attention_store, alpha_prod, temperature=0.15, use_xm=False):
|
78 |
+
maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
|
79 |
+
h,w = x_t.shape[2],x_t.shape[3]
|
80 |
+
h , w = ((h+1)//2+1)//2, ((w+1)//2+1)//2
|
81 |
+
# print(h,w)
|
82 |
+
# print(maps[0].shape)
|
83 |
+
maps = [item.reshape(2, -1, 1, h // int((h*w/item.shape[-2])**0.5), w // int((h*w/item.shape[-2])**0.5), MAX_NUM_WORDS) for item in maps]
|
84 |
+
maps = torch.cat(maps, dim=1)
|
85 |
+
maps_s = maps[0,:]
|
86 |
+
maps_m = maps[1,:]
|
87 |
+
thresh_e = temperature / alpha_prod ** (0.5)
|
88 |
+
if thresh_e < self.thresh_e:
|
89 |
+
thresh_e = self.thresh_e
|
90 |
+
thresh_m = self.thresh_m
|
91 |
+
mask_e = self.get_mask(x_t, maps_m, self.alpha_e, thresh_e, i)
|
92 |
+
mask_m = self.get_mask(x_t, maps_s, (self.alpha_m-self.alpha_me), thresh_m, i)
|
93 |
+
mask_me = self.get_mask(x_t, maps_m, self.alpha_me, self.thresh_e, i)
|
94 |
+
if self.save_inter:
|
95 |
+
self.save_image(mask_e,i,"mask_e")
|
96 |
+
self.save_image(mask_m,i,"mask_m")
|
97 |
+
self.save_image(mask_me,i,"mask_me")
|
98 |
+
|
99 |
+
if self.alpha_e.sum() == 0:
|
100 |
+
x_t_out = x_t
|
101 |
+
else:
|
102 |
+
x_t_out = torch.where(mask_e, x_t, x_m)
|
103 |
+
x_t_out = torch.where(mask_m, x_s, x_t_out)
|
104 |
+
if use_xm:
|
105 |
+
x_t_out = torch.where(mask_me, x_m, x_t_out)
|
106 |
+
|
107 |
+
return x_m, x_t_out
|
108 |
+
|
109 |
+
def __init__(self,thresh_e=0.3, thresh_m=0.3, save_inter = False):
|
110 |
+
self.thresh_e = thresh_e
|
111 |
+
self.thresh_m = thresh_m
|
112 |
+
self.save_inter = save_inter
|
113 |
+
|
114 |
+
def set_map(self, ms, alpha, alpha_e, alpha_m,len):
|
115 |
+
self.m = ms
|
116 |
+
self.alpha = alpha
|
117 |
+
self.alpha_e = alpha_e
|
118 |
+
self.alpha_m = alpha_m
|
119 |
+
alpha_me = alpha_e.to(torch.bool) & alpha_m.to(torch.bool)
|
120 |
+
self.alpha_me = alpha_me.to(torch.float)
|
121 |
+
self.len = len
|
122 |
+
|
123 |
+
|
124 |
+
class AttentionControl(abc.ABC):
|
125 |
+
|
126 |
+
def step_callback(self, x_t):
|
127 |
+
return x_t
|
128 |
+
|
129 |
+
def between_steps(self):
|
130 |
+
return
|
131 |
+
|
132 |
+
@property
|
133 |
+
def num_uncond_att_layers(self):
|
134 |
+
return self.num_att_layers if LOW_RESOURCE else 0
|
135 |
+
|
136 |
+
@abc.abstractmethod
|
137 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
138 |
+
raise NotImplementedError
|
139 |
+
|
140 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
141 |
+
if self.cur_att_layer >= self.num_uncond_att_layers:
|
142 |
+
if LOW_RESOURCE:
|
143 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
144 |
+
else:
|
145 |
+
h = attn.shape[0]
|
146 |
+
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
|
147 |
+
self.cur_att_layer += 1
|
148 |
+
if self.cur_att_layer == self.num_att_layers // 2 + self.num_uncond_att_layers:
|
149 |
+
self.cur_att_layer = 0
|
150 |
+
self.cur_step += 1
|
151 |
+
self.between_steps()
|
152 |
+
return attn
|
153 |
+
|
154 |
+
def reset(self):
|
155 |
+
self.cur_step = 0
|
156 |
+
self.cur_att_layer = 0
|
157 |
+
|
158 |
+
def __init__(self):
|
159 |
+
self.cur_step = 0
|
160 |
+
self.num_att_layers = -1
|
161 |
+
self.cur_att_layer = 0
|
162 |
+
|
163 |
+
|
164 |
+
class EmptyControl(AttentionControl):
|
165 |
+
|
166 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
167 |
+
return attn
|
168 |
+
def self_attn_forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
169 |
+
b = q.shape[0] // num_heads
|
170 |
+
out = torch.einsum("h i j, h j d -> h i d", attn, v)
|
171 |
+
return out
|
172 |
+
|
173 |
+
|
174 |
+
class AttentionStore(AttentionControl):
|
175 |
+
|
176 |
+
@staticmethod
|
177 |
+
def get_empty_store():
|
178 |
+
return {"down_cross": [], "mid_cross": [], "up_cross": [],
|
179 |
+
"down_self": [], "mid_self": [], "up_self": []}
|
180 |
+
|
181 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
182 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
183 |
+
if attn.shape[1] <= 32 ** 2: # avoid memory overhead
|
184 |
+
self.step_store[key].append(attn)
|
185 |
+
return attn
|
186 |
+
|
187 |
+
def between_steps(self):
|
188 |
+
if len(self.attention_store) == 0:
|
189 |
+
self.attention_store = self.step_store
|
190 |
+
else:
|
191 |
+
for key in self.attention_store:
|
192 |
+
for i in range(len(self.attention_store[key])):
|
193 |
+
self.attention_store[key][i] += self.step_store[key][i]
|
194 |
+
self.step_store = self.get_empty_store()
|
195 |
+
|
196 |
+
def get_average_attention(self):
|
197 |
+
average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
|
198 |
+
return average_attention
|
199 |
+
|
200 |
+
def reset(self):
|
201 |
+
super(AttentionStore, self).reset()
|
202 |
+
self.step_store = self.get_empty_store()
|
203 |
+
self.attention_store = {}
|
204 |
+
|
205 |
+
def __init__(self):
|
206 |
+
super(AttentionStore, self).__init__()
|
207 |
+
self.step_store = self.get_empty_store()
|
208 |
+
self.attention_store = {}
|
209 |
+
|
210 |
+
|
211 |
+
class AttentionControlEdit(AttentionStore, abc.ABC):
|
212 |
+
|
213 |
+
def step_callback(self,i, t, x_s, x_t, x_m, alpha_prod):
|
214 |
+
if (self.local_blend is not None) and (i>0):
|
215 |
+
use_xm = (self.cur_step+self.start_steps+1 == self.num_steps)
|
216 |
+
x_m, x_t = self.local_blend(i, x_s, x_t, x_m, self.attention_store, alpha_prod, use_xm=use_xm)
|
217 |
+
return x_m, x_t
|
218 |
+
|
219 |
+
def replace_self_attention(self, attn_base, att_replace):
|
220 |
+
if att_replace.shape[2] <= 16 ** 2:
|
221 |
+
return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
|
222 |
+
else:
|
223 |
+
return att_replace
|
224 |
+
|
225 |
+
@abc.abstractmethod
|
226 |
+
def replace_cross_attention(self, attn_base, att_replace):
|
227 |
+
raise NotImplementedError
|
228 |
+
|
229 |
+
def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
|
230 |
+
b = q.shape[0] // num_heads
|
231 |
+
|
232 |
+
sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
|
233 |
+
attn = sim.softmax(-1)
|
234 |
+
out = torch.einsum("h i j, h j d -> h i d", attn, v)
|
235 |
+
return out
|
236 |
+
|
237 |
+
def self_attn_forward(self, q, k, v, num_heads):
|
238 |
+
if q.shape[0]//num_heads == 3:
|
239 |
+
if (self.self_replace_steps <= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
|
240 |
+
q=torch.cat([q[:num_heads*2],q[num_heads:num_heads*2]])
|
241 |
+
k=torch.cat([k[:num_heads*2],k[:num_heads]])
|
242 |
+
v=torch.cat([v[:num_heads*2],v[:num_heads]])
|
243 |
+
else:
|
244 |
+
q=torch.cat([q[:num_heads],q[:num_heads],q[:num_heads]])
|
245 |
+
k=torch.cat([k[:num_heads],k[:num_heads],k[:num_heads]])
|
246 |
+
v=torch.cat([v[:num_heads*2],v[:num_heads]])
|
247 |
+
return q,k,v
|
248 |
+
else:
|
249 |
+
qu, qc = q.chunk(2)
|
250 |
+
ku, kc = k.chunk(2)
|
251 |
+
vu, vc = v.chunk(2)
|
252 |
+
if (self.self_replace_steps <= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
|
253 |
+
qu=torch.cat([qu[:num_heads*2],qu[num_heads:num_heads*2]])
|
254 |
+
qc=torch.cat([qc[:num_heads*2],qc[num_heads:num_heads*2]])
|
255 |
+
ku=torch.cat([ku[:num_heads*2],ku[:num_heads]])
|
256 |
+
kc=torch.cat([kc[:num_heads*2],kc[:num_heads]])
|
257 |
+
vu=torch.cat([vu[:num_heads*2],vu[:num_heads]])
|
258 |
+
vc=torch.cat([vc[:num_heads*2],vc[:num_heads]])
|
259 |
+
else:
|
260 |
+
qu=torch.cat([qu[:num_heads],qu[:num_heads],qu[:num_heads]])
|
261 |
+
qc=torch.cat([qc[:num_heads],qc[:num_heads],qc[:num_heads]])
|
262 |
+
ku=torch.cat([ku[:num_heads],ku[:num_heads],ku[:num_heads]])
|
263 |
+
kc=torch.cat([kc[:num_heads],kc[:num_heads],kc[:num_heads]])
|
264 |
+
vu=torch.cat([vu[:num_heads*2],vu[:num_heads]])
|
265 |
+
vc=torch.cat([vc[:num_heads*2],vc[:num_heads]])
|
266 |
+
|
267 |
+
return torch.cat([qu, qc], dim=0) ,torch.cat([ku, kc], dim=0), torch.cat([vu, vc], dim=0)
|
268 |
+
|
269 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
270 |
+
if is_cross :
|
271 |
+
h = attn.shape[0] // self.batch_size
|
272 |
+
attn = attn.reshape(self.batch_size,h, *attn.shape[1:])
|
273 |
+
attn_base, attn_repalce,attn_masa = attn[0], attn[1], attn[2]
|
274 |
+
attn_replace_new = self.replace_cross_attention(attn_masa, attn_repalce)
|
275 |
+
attn_base_store = self.replace_cross_attention(attn_base, attn_repalce)
|
276 |
+
if (self.cross_replace_steps >= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
|
277 |
+
attn[1] = attn_base_store
|
278 |
+
attn_store=torch.cat([attn_base_store,attn_replace_new])
|
279 |
+
attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
|
280 |
+
attn_store = attn_store.reshape(2 *h, *attn_store.shape[2:])
|
281 |
+
super(AttentionControlEdit, self).forward(attn_store, is_cross, place_in_unet)
|
282 |
+
return attn
|
283 |
+
|
284 |
+
def __init__(self, prompts, num_steps: int,start_steps: int,
|
285 |
+
cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
|
286 |
+
self_replace_steps: Union[float, Tuple[float, float]],
|
287 |
+
local_blend: Optional[LocalBlend]):
|
288 |
+
super(AttentionControlEdit, self).__init__()
|
289 |
+
self.batch_size = len(prompts)+1
|
290 |
+
self.self_replace_steps = self_replace_steps
|
291 |
+
self.cross_replace_steps = cross_replace_steps
|
292 |
+
self.num_steps=num_steps
|
293 |
+
self.start_steps=start_steps
|
294 |
+
self.local_blend = local_blend
|
295 |
+
|
296 |
+
|
297 |
+
class AttentionReplace(AttentionControlEdit):
|
298 |
+
|
299 |
+
def replace_cross_attention(self, attn_base, att_replace):
|
300 |
+
return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
|
301 |
+
|
302 |
+
def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
|
303 |
+
local_blend: Optional[LocalBlend] = None):
|
304 |
+
super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
|
305 |
+
self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device).to(torch_dtype)
|
306 |
+
|
307 |
+
|
308 |
+
class AttentionRefine(AttentionControlEdit):
|
309 |
+
|
310 |
+
def replace_cross_attention(self, attn_masa, att_replace):
|
311 |
+
attn_masa_replace = attn_masa[:, :, self.mapper].squeeze()
|
312 |
+
attn_replace = attn_masa_replace * self.alphas + \
|
313 |
+
att_replace * (1 - self.alphas)
|
314 |
+
return attn_replace
|
315 |
+
|
316 |
+
def __init__(self, prompts, prompt_specifiers, num_steps: int,start_steps: int, cross_replace_steps: float, self_replace_steps: float,
|
317 |
+
local_blend: Optional[LocalBlend] = None):
|
318 |
+
super(AttentionRefine, self).__init__(prompts, num_steps,start_steps, cross_replace_steps, self_replace_steps, local_blend)
|
319 |
+
self.mapper, alphas, ms, alpha_e, alpha_m = seq_aligner.get_refinement_mapper(prompts, prompt_specifiers, tokenizer, encoder, device)
|
320 |
+
self.mapper, alphas, ms = self.mapper.to(device), alphas.to(device).to(torch_dtype), ms.to(device).to(torch_dtype)
|
321 |
+
self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
|
322 |
+
self.ms = ms.reshape(ms.shape[0], 1, 1, ms.shape[1])
|
323 |
+
ms = ms.to(device)
|
324 |
+
alpha_e = alpha_e.to(device)
|
325 |
+
alpha_m = alpha_m.to(device)
|
326 |
+
t_len = len(tokenizer(prompts[1])["input_ids"])
|
327 |
+
self.local_blend.set_map(ms,alphas,alpha_e,alpha_m,t_len)
|
328 |
+
|
329 |
+
|
330 |
+
def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]]):
|
331 |
+
if type(word_select) is int or type(word_select) is str:
|
332 |
+
word_select = (word_select,)
|
333 |
+
equalizer = torch.ones(len(values), 77)
|
334 |
+
values = torch.tensor(values, dtype=torch_dtype)
|
335 |
+
for word in word_select:
|
336 |
+
inds = ptp_utils.get_word_inds(text, word, tokenizer)
|
337 |
+
equalizer[:, inds] = values
|
338 |
+
return equalizer
|
339 |
+
|
340 |
+
|
341 |
+
def inference(img, source_prompt, target_prompt,
|
342 |
+
local, mutual,
|
343 |
+
positive_prompt, negative_prompt,
|
344 |
+
guidance_s, guidance_t,
|
345 |
+
num_inference_steps,
|
346 |
+
width, height, seed, strength,
|
347 |
+
cross_replace_steps, self_replace_steps,
|
348 |
+
thresh_e, thresh_m, denoise, user_instruct="", api_key=""):
|
349 |
+
print(img)
|
350 |
+
if user_instruct != "" and api_key != "":
|
351 |
+
source_prompt, target_prompt, local, mutual, replace_steps, num_inference_steps = get_params(api_key, user_instruct)
|
352 |
+
cross_replace_steps = replace_steps
|
353 |
+
self_replace_steps = replace_steps
|
354 |
+
|
355 |
+
torch.manual_seed(seed)
|
356 |
+
ratio = min(height / img.height, width / img.width)
|
357 |
+
img = img.resize((int(img.width * ratio), int(img.height * ratio)))
|
358 |
+
if denoise is False:
|
359 |
+
strength = 1
|
360 |
+
num_denoise_num = math.trunc(num_inference_steps*strength)
|
361 |
+
num_start = num_inference_steps-num_denoise_num
|
362 |
+
# create the CAC controller.
|
363 |
+
local_blend = LocalBlend(thresh_e=thresh_e, thresh_m=thresh_m, save_inter=False)
|
364 |
+
controller = AttentionRefine([source_prompt, target_prompt],[[local, mutual]],
|
365 |
+
num_inference_steps,
|
366 |
+
num_start,
|
367 |
+
cross_replace_steps=cross_replace_steps,
|
368 |
+
self_replace_steps=self_replace_steps,
|
369 |
+
local_blend=local_blend
|
370 |
+
)
|
371 |
+
ptp_utils.register_attention_control(pipe, controller)
|
372 |
+
|
373 |
+
results = pipe(prompt=target_prompt,
|
374 |
+
source_prompt=source_prompt,
|
375 |
+
positive_prompt=positive_prompt,
|
376 |
+
negative_prompt=negative_prompt,
|
377 |
+
image=img,
|
378 |
+
num_inference_steps=num_inference_steps,
|
379 |
+
eta=1,
|
380 |
+
strength=strength,
|
381 |
+
guidance_scale=guidance_t,
|
382 |
+
source_guidance_scale=guidance_s,
|
383 |
+
denoise_model=denoise,
|
384 |
+
callback = controller.step_callback
|
385 |
+
)
|
386 |
+
|
387 |
+
return replace_nsfw_images(results)
|
388 |
+
|
389 |
+
|
390 |
+
def replace_nsfw_images(results):
|
391 |
+
for i in range(len(results.images)):
|
392 |
+
if results.nsfw_content_detected[i]:
|
393 |
+
results.images[i] = Image.open("nsfw.png")
|
394 |
+
return results.images[0]
|
395 |
+
|
396 |
+
|
397 |
+
css = """.cycle-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.cycle-diffusion-div div h1{font-weight:900;margin-bottom:7px}.cycle-diffusion-div p{margin-bottom:10px;font-size:94%}.cycle-diffusion-div p a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
|
398 |
+
"""
|
399 |
+
intro = """
|
400 |
+
<div style="display: flex;align-items: center;justify-content: center">
|
401 |
+
<img src="https://sled-group.github.io/InfEdit/image_assets/InfEdit.png" width="80" style="display: inline-block">
|
402 |
+
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">InfEdit</h1>
|
403 |
+
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Inversion-Free Image Editing
|
404 |
+
with Natural Language</h3>
|
405 |
+
</div>
|
406 |
+
"""
|
407 |
+
|
408 |
+
param_bot_prompt = """
|
409 |
+
You are a helpful assistant named InfEdit that provides input parameters to the image editing model based on user instructions. You should respond in valid json format.
|
410 |
+
|
411 |
+
User:
|
412 |
+
```
|
413 |
+
{image descrption and editing commands | example: 'The image shows an apple on the table and I want to change the apple to a banana.'}
|
414 |
+
```
|
415 |
+
|
416 |
+
After receiving this, you will need to generate the appropriate params as input to the image editing models.
|
417 |
+
|
418 |
+
Assistant:
|
419 |
+
```
|
420 |
+
{
|
421 |
+
“source_prompt”: “{a string describes the input image, it needs to includes the thing user want to change | example: 'an apple on the table'}”,
|
422 |
+
“target_prompt”: “{a string that matches the source prompt, but it needs to includes the thing user want to change | example: 'a banana on the table'}”,
|
423 |
+
“target_sub”: “{a special substring from the target prompt}”,
|
424 |
+
“mutual_sub”: “{a special mutual substring from source/target prompt}”
|
425 |
+
“attention_control”: {a number between 0 and 1}
|
426 |
+
“steps”: {a number between 8 and 50}
|
427 |
+
}
|
428 |
+
```
|
429 |
+
|
430 |
+
You need to fill in the "target_sub" and "mutual_sub" by the guideline below.
|
431 |
+
|
432 |
+
If the editing instruction is not about changing style or background:
|
433 |
+
- The "target_sub" should be a special substring from the target prompt that highlights what you want to edit, it should be as short as possible and should only be noun ("banana" instead of "a banana").
|
434 |
+
- The "mutual_sub" should be kept as an empty string.
|
435 |
+
P.S. When you want to remove something, it's always better to use "empty", "nothing" or some appropriate words to replace it. Like remove an apple on the table, you can use "an apple on the table" and "nothing on the table" as your prompts, and use "nothing" as your target_sub.
|
436 |
+
P.S. You should think carefully about what you want to modify, like "short hair" to "long hair", your target_sub should be "hair" instead of "long".
|
437 |
+
P.S. When you are adding something, the target_sub should be the thing you want to add.
|
438 |
+
|
439 |
+
If it's about style editing:
|
440 |
+
- The "target_sub" should be kept as an empty string.
|
441 |
+
- The "mutual_sub" should be kept as an empty string.
|
442 |
+
|
443 |
+
If it's about background editing:
|
444 |
+
- The "target_sub" should be kept as an empty string.
|
445 |
+
- The "mutual_sub" should be a common substring from source/target prompt, and is the main object/character (noun) in the image. It should be as short as possible and only be noun ("banana" instead of "a banana", "man" instead of "running man").
|
446 |
+
|
447 |
+
A specific case, if it's about change an object's abstract information, like pose, view or shape and want to keep the semantic feature same, like a dog to a running dog,
|
448 |
+
- The "target_sub" should be a special substring from the target prompt that highlights what you want to edit, it should be as short as possible and should only be noun ("dog" instead of "a running dog").
|
449 |
+
- The "mutual_sub" should be as same as target_sub because we want to "edit the dog but also keep the dog as same".
|
450 |
+
|
451 |
+
|
452 |
+
You need to choose a specific value of “attention_control” by the guideline below.
|
453 |
+
A larger value of “attention_control” means more consistency between the source image and the output.
|
454 |
+
|
455 |
+
- the editing is on the feature level, like color, material and so on, and want to ensure the characteristics of the original object as much as possible, you should choose a large value. (Example: for color editing, you can choose 1, and for material you can choose 0.9)
|
456 |
+
- the editing is on the object level, like edit a "cat" to a "dog", or a "horse" to a "zebra", and want to make them to be similar, you need to choose a relatively large value, we say 0.7 for example.
|
457 |
+
- the editing is changing the style but want to keep the spatial features, you need to choose a relatively large value, we say 0.7 for example.
|
458 |
+
- the editing need to change something's shape, like edit an "apple" to a "banana", a "flower" to a "knife", "short" hair to "long" hair, "round" to "square", which have very different shapes, you need to choose a relatively small value, we say 0.3 for example.
|
459 |
+
- the editing is tring to change the spatial information, like change the pose and so on, you need to choose a relatively small value, we say 0.3 for example.
|
460 |
+
- the editing should not consider the consistency with the input image, like add something new, remove something, or change the background, you can directly use 0.
|
461 |
+
|
462 |
+
|
463 |
+
You need to choose a specific value of “steps” by the guideline below.
|
464 |
+
More steps mean that the edit effect is more pronounced.
|
465 |
+
- If the editing is super easy, like changing something to something with very similar features, you can choose 8 steps.
|
466 |
+
- In most cases, you can choose 15 steps.
|
467 |
+
- For style editing and remove tasks, you can choose a larger value, like 25 steps.
|
468 |
+
- If you feel the task is extremely difficult (like some kinds of styles or removing very tiny stuffs), you can directly use 50 steps.
|
469 |
+
"""
|
470 |
+
def get_params(api_key, user_instruct):
|
471 |
+
from openai import OpenAI
|
472 |
+
client = OpenAI(api_key=api_key)
|
473 |
+
print("user_instruct", user_instruct)
|
474 |
+
response = client.chat.completions.create(
|
475 |
+
model="gpt-4-1106-preview",
|
476 |
+
messages=[
|
477 |
+
{"role": "system", "content": param_bot_prompt},
|
478 |
+
{"role": "user", "content": user_instruct}
|
479 |
+
],
|
480 |
+
response_format={ "type": "json_object" },
|
481 |
+
)
|
482 |
+
param_dict = response.choices[0].message.content
|
483 |
+
print("param_dict", param_dict)
|
484 |
+
import json
|
485 |
+
param_dict = json.loads(param_dict)
|
486 |
+
return param_dict['source_prompt'], param_dict['target_prompt'], param_dict['target_sub'], param_dict['mutual_sub'], param_dict['attention_control'], param_dict['steps']
|
487 |
+
with gr.Blocks(css=css) as demo:
|
488 |
+
gr.HTML(intro)
|
489 |
+
with gr.Accordion("README", open=False):
|
490 |
+
gr.HTML(
|
491 |
+
"""
|
492 |
+
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
|
493 |
+
<a href="https://sled-group.github.io/InfEdit/" target="_blank">project page</a> | <a href="https://arxiv.org" target="_blank">paper</a>| <a href="https://github.com/sled-group/InfEdit/tree/website" target="_blank">handbook</a>
|
494 |
+
</p>
|
495 |
+
|
496 |
+
We are now hosting on a A4000 GPU with 16 GiB memory.
|
497 |
+
"""
|
498 |
+
)
|
499 |
+
with gr.Row():
|
500 |
+
|
501 |
+
with gr.Column(scale=55):
|
502 |
+
with gr.Group():
|
503 |
+
|
504 |
+
img = gr.Image(label="Input image", height=512, type="pil")
|
505 |
+
|
506 |
+
image_out = gr.Image(label="Output image", height=512)
|
507 |
+
# gallery = gr.Gallery(
|
508 |
+
# label="Generated images", show_label=False, elem_id="gallery"
|
509 |
+
# ).style(grid=[1], height="auto")
|
510 |
+
|
511 |
+
with gr.Column(scale=45):
|
512 |
+
|
513 |
+
with gr.Tab("UAC options"):
|
514 |
+
with gr.Group():
|
515 |
+
with gr.Row():
|
516 |
+
source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
|
517 |
+
with gr.Row():
|
518 |
+
guidance_s = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
|
519 |
+
positive_prompt = gr.Textbox(label="Positive prompt", placeholder="")
|
520 |
+
with gr.Row():
|
521 |
+
target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
|
522 |
+
with gr.Row():
|
523 |
+
guidance_t = gr.Slider(label="Target guidance scale", value=2, minimum=1, maximum=10)
|
524 |
+
negative_prompt = gr.Textbox(label="Negative prompt", placeholder="")
|
525 |
+
with gr.Row():
|
526 |
+
local = gr.Textbox(label="Target blend", placeholder="")
|
527 |
+
thresh_e = gr.Slider(label="Target blend thresh", value=0.6, minimum=0, maximum=1)
|
528 |
+
with gr.Row():
|
529 |
+
mutual = gr.Textbox(label="Source blend", placeholder="")
|
530 |
+
thresh_m = gr.Slider(label="Source blend thresh", value=0.6, minimum=0, maximum=1)
|
531 |
+
with gr.Row():
|
532 |
+
cross_replace_steps = gr.Slider(label="Cross attn control schedule", value=0.7, minimum=0.0, maximum=1, step=0.01)
|
533 |
+
self_replace_steps = gr.Slider(label="Self attn control schedule", value=0.3, minimum=0.0, maximum=1, step=0.01)
|
534 |
+
with gr.Row():
|
535 |
+
denoise = gr.Checkbox(label='Denoising Mode', value=False)
|
536 |
+
strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.01, visible=False)
|
537 |
+
denoise.change(fn=lambda value: gr.update(visible=value), inputs=denoise, outputs=strength)
|
538 |
+
with gr.Row():
|
539 |
+
generate1 = gr.Button(value="Run")
|
540 |
+
|
541 |
+
with gr.Tab("Advanced options"):
|
542 |
+
with gr.Group():
|
543 |
+
with gr.Row():
|
544 |
+
num_inference_steps = gr.Slider(label="Inference steps", value=15, minimum=1, maximum=50, step=1)
|
545 |
+
width = gr.Slider(label="Width", value=512, minimum=512, maximum=1024, step=8)
|
546 |
+
height = gr.Slider(label="Height", value=512, minimum=512, maximum=1024, step=8)
|
547 |
+
with gr.Row():
|
548 |
+
seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
|
549 |
+
with gr.Row():
|
550 |
+
generate3 = gr.Button(value="Run")
|
551 |
+
|
552 |
+
with gr.Tab("Instruction following (+GPT4)"):
|
553 |
+
guide_str = """Describe the image you uploaded and tell me how you want to edit it."""
|
554 |
+
with gr.Group():
|
555 |
+
api_key = gr.Textbox(label="YOUR OPENAI API KEY", placeholder="sk-xxx", lines = 1, type="password")
|
556 |
+
user_instruct = gr.Textbox(label=guide_str, placeholder="The image shows an apple on the table and I want to change the apple to a banana.", lines = 3)
|
557 |
+
# source_prompt, target_prompt, local, mutual = get_params(api_key, user_instruct)
|
558 |
+
with gr.Row():
|
559 |
+
generate4 = gr.Button(value="Run")
|
560 |
+
|
561 |
+
inputs1 = [img, source_prompt, target_prompt,
|
562 |
+
local, mutual,
|
563 |
+
positive_prompt, negative_prompt,
|
564 |
+
guidance_s, guidance_t,
|
565 |
+
num_inference_steps,
|
566 |
+
width, height, seed, strength,
|
567 |
+
cross_replace_steps, self_replace_steps,
|
568 |
+
thresh_e, thresh_m, denoise]
|
569 |
+
inputs4 =[img, source_prompt, target_prompt,
|
570 |
+
local, mutual,
|
571 |
+
positive_prompt, negative_prompt,
|
572 |
+
guidance_s, guidance_t,
|
573 |
+
num_inference_steps,
|
574 |
+
width, height, seed, strength,
|
575 |
+
cross_replace_steps, self_replace_steps,
|
576 |
+
thresh_e, thresh_m, denoise, user_instruct, api_key]
|
577 |
+
generate1.click(inference, inputs=inputs1, outputs=image_out)
|
578 |
+
generate3.click(inference, inputs=inputs1, outputs=image_out)
|
579 |
+
generate4.click(inference, inputs=inputs4, outputs=image_out)
|
580 |
+
|
581 |
+
ex = gr.Examples(
|
582 |
+
[
|
583 |
+
["images/corgi.jpg","corgi","cat","cat","","","",1,2,15,512,512,0,1,0.7,0.7,0.6,0.6,False],
|
584 |
+
["images/muffin.png","muffin","chihuahua","chihuahua","","","",1,2,15,512,512,0,1,0.65,0.6,0.4,0.7,False],
|
585 |
+
["images/InfEdit.jpg","an anime girl holding a pad","an anime girl holding a book","book","girl ","","",1,2,15,512,512,0,1,0.8,0.8,0.6,0.6,False],
|
586 |
+
["images/summer.jpg","a photo of summer scene","A photo of winter scene","","","","",1,2,15,512,512,0,1,1,1,0.6,0.7,False],
|
587 |
+
["images/bear.jpg","A bear sitting on the ground","A bear standing on the ground","bear","","","",1,1.5,15,512,512,0,1,0.3,0.3,0.5,0.7,False],
|
588 |
+
["images/james.jpg","a man playing basketball","a man playing soccer","soccer","man ","","",1,2,15,512,512,0,1,0,0,0.5,0.4,False],
|
589 |
+
["images/osu.jfif","A football with OSU logo","A football with Umich logo","logo","","","",1,2,15,512,512,0,1,0.5,0,0.6,0.7,False],
|
590 |
+
["images/groundhog.png","A anime groundhog head","A anime ferret head","head","","","",1,2,15,512,512,0,1,0.5,0.5,0.6,0.7,False],
|
591 |
+
["images/miku.png","A anime girl with green hair and green eyes and shirt","A anime girl with red hair and red eyes and shirt","red hair and red eyes","shirt","","",1,2,15,512,512,0,1,1,1,0.2,0.8,False],
|
592 |
+
["images/droplet.png","a blue droplet emoji with a smiling face with yellow dot","a red fire emoji with an angry face with yellow dot","","yellow dot","","",1,2,15,512,512,0,1,0.7,0.7,0.6,0.7,False],
|
593 |
+
["images/moyu.png","an emoji holding a sign and a fish","an emoji holding a sign and a shark","shark","sign","","",1,2,15,512,512,0,1,0.7,0.7,0.5,0.7,False],
|
594 |
+
["images/214000000000.jpg","a painting of a waterfall in the mountains","a painting of a waterfall and angels in the mountains","angels","","","",1,2,15,512,512,0,1,0,0,0.5,0.5,False],
|
595 |
+
["images/311000000002.jpg","a lion in a suit sitting at a table with a laptop","a lion in a suit sitting at a table with nothing","nothing","","","",1,2,15,512,512,0,1,0,0,0.5,0.5,False],
|
596 |
+
["images/genshin.png","anime girl, with blue logo","anime boy with golden hair named Link, from The Legend of Zelda, with legend of zelda logo","anime boy","","","",1,2,50,512,512,0,1,0.65,0.65,0.5,0.5,False],
|
597 |
+
["images/angry.jpg","a man with bounding boxes at the door","a man with angry birds at the door","angry birds","a man","","",1,2,15,512,512,0,1,0.3,0.1,0.45,0.4,False],
|
598 |
+
["images/Doom_Slayer.jpg","doom slayer from game doom","master chief from game halo","","","","",1,2,15,512,512,0,1,0.6,0.8,0.7,0.7,False],
|
599 |
+
["images/Elon_Musk.webp","Elon Musk in front of a car","Mark Iv iron man suit in front of a car","Mark Iv iron man suit","car","","",1,2,15,512,512,0,1,0.5,0.3,0.6,0.7,False],
|
600 |
+
["images/dragon.jpg","a mascot dragon","pixel art, a mascot dragon","","","","",1,2,25,512,512,0,1,0.7,0.7,0.6,0.6,False],
|
601 |
+
["images/frieren.jpg","a anime girl with long white hair holding a bottle","a anime girl with long white hair holding a smartphone","smartphone","","","",1,2,15,512,512,0,1,0.7,0.7,0.7,0.7,False],
|
602 |
+
["images/sam.png","a man with an openai logo","a man with a twitter logo","a twitter logo","a man","","",1,2,15,512,512,0,0.8,0,0,0.3,0.6,True],
|
603 |
+
|
604 |
+
|
605 |
+
],
|
606 |
+
[img, source_prompt, target_prompt,
|
607 |
+
local, mutual,
|
608 |
+
positive_prompt, negative_prompt,
|
609 |
+
guidance_s, guidance_t,
|
610 |
+
num_inference_steps,
|
611 |
+
width, height, seed, strength,
|
612 |
+
cross_replace_steps, self_replace_steps,
|
613 |
+
thresh_e, thresh_m, denoise],
|
614 |
+
image_out, inference, cache_examples=True,examples_per_page=20)
|
615 |
+
# if not is_colab:
|
616 |
+
# demo.queue(concurrency_count=1)
|
617 |
+
|
618 |
+
# demo.launch(debug=False, share=False,server_name="0.0.0.0",server_port = 80)
|
619 |
+
demo.launch(debug=False, share=False)
|
620 |
+
|
images/214000000000.jpg
ADDED
images/311000000002.jpg
ADDED
images/Doom_Slayer.jpg
ADDED
images/Elon_Musk.webp
ADDED
images/InfEdit.jpg
ADDED
Git LFS Details
|
images/angry.jpg
ADDED
images/bear.jpg
ADDED
images/computer.png
ADDED
images/corgi.jpg
ADDED
images/dragon.jpg
ADDED
images/droplet.png
ADDED
images/frieren.jpg
ADDED
images/genshin.png
ADDED
images/groundhog.png
ADDED
images/james.jpg
ADDED
images/miku.png
ADDED
images/moyu.png
ADDED
images/muffin.png
ADDED
images/osu.jfif
ADDED
Binary file (4.38 kB). View file
|
|
images/sam.png
ADDED
Git LFS Details
|
images/summer.jpg
ADDED
nsfw.png
ADDED
pipeline_ead.py
ADDED
@@ -0,0 +1,707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import torch
|
7 |
+
from packaging import version
|
8 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import FrozenDict
|
11 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
12 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
13 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
14 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
15 |
+
from diffusers.schedulers import LCMScheduler
|
16 |
+
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
18 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
19 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
20 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
24 |
+
|
25 |
+
|
26 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
|
27 |
+
def preprocess(image):
|
28 |
+
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
|
29 |
+
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
|
30 |
+
if isinstance(image, torch.Tensor):
|
31 |
+
return image
|
32 |
+
elif isinstance(image, PIL.Image.Image):
|
33 |
+
image = [image]
|
34 |
+
|
35 |
+
if isinstance(image[0], PIL.Image.Image):
|
36 |
+
w, h = image[0].size
|
37 |
+
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
38 |
+
|
39 |
+
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
|
40 |
+
image = np.concatenate(image, axis=0)
|
41 |
+
image = np.array(image).astype(np.float32) / 255.0
|
42 |
+
image = image.transpose(0, 3, 1, 2)
|
43 |
+
image = 2.0 * image - 1.0
|
44 |
+
image = torch.from_numpy(image)
|
45 |
+
elif isinstance(image[0], torch.Tensor):
|
46 |
+
image = torch.cat(image, dim=0)
|
47 |
+
return image
|
48 |
+
|
49 |
+
|
50 |
+
def ddcm_sampler(scheduler, x_s, x_t, timestep, e_s, e_t, x_0, noise, eta, to_next=True):
|
51 |
+
if scheduler.num_inference_steps is None:
|
52 |
+
raise ValueError(
|
53 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
54 |
+
)
|
55 |
+
|
56 |
+
if scheduler.step_index is None:
|
57 |
+
scheduler._init_step_index(timestep)
|
58 |
+
|
59 |
+
prev_step_index = scheduler.step_index + 1
|
60 |
+
if prev_step_index < len(scheduler.timesteps):
|
61 |
+
prev_timestep = scheduler.timesteps[prev_step_index]
|
62 |
+
else:
|
63 |
+
prev_timestep = timestep
|
64 |
+
|
65 |
+
alpha_prod_t = scheduler.alphas_cumprod[timestep]
|
66 |
+
alpha_prod_t_prev = (
|
67 |
+
scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
|
68 |
+
)
|
69 |
+
beta_prod_t = 1 - alpha_prod_t
|
70 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
71 |
+
variance = beta_prod_t_prev
|
72 |
+
std_dev_t = eta * variance
|
73 |
+
noise = std_dev_t ** (0.5) * noise
|
74 |
+
|
75 |
+
e_c = (x_s - alpha_prod_t ** (0.5) * x_0) / (1 - alpha_prod_t) ** (0.5)
|
76 |
+
|
77 |
+
pred_x0 = x_0 + ((x_t - x_s) - beta_prod_t ** (0.5) * (e_t - e_s)) / alpha_prod_t ** (0.5)
|
78 |
+
eps = (e_t - e_s) + e_c
|
79 |
+
dir_xt = (beta_prod_t_prev - std_dev_t) ** (0.5) * eps
|
80 |
+
|
81 |
+
# Noise is not used for one-step sampling.
|
82 |
+
if len(scheduler.timesteps) > 1:
|
83 |
+
prev_xt = alpha_prod_t_prev ** (0.5) * pred_x0 + dir_xt + noise
|
84 |
+
prev_xs = alpha_prod_t_prev ** (0.5) * x_0 + dir_xt + noise
|
85 |
+
else:
|
86 |
+
prev_xt = pred_x0
|
87 |
+
prev_xs = x_0
|
88 |
+
|
89 |
+
if to_next:
|
90 |
+
scheduler._step_index += 1
|
91 |
+
return prev_xs, prev_xt, pred_x0
|
92 |
+
|
93 |
+
|
94 |
+
class EditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
95 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
96 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
vae: AutoencoderKL,
|
101 |
+
text_encoder: CLIPTextModel,
|
102 |
+
tokenizer: CLIPTokenizer,
|
103 |
+
unet: UNet2DConditionModel,
|
104 |
+
scheduler: LCMScheduler,
|
105 |
+
safety_checker: StableDiffusionSafetyChecker,
|
106 |
+
feature_extractor: CLIPImageProcessor,
|
107 |
+
requires_safety_checker: bool = True,
|
108 |
+
):
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
112 |
+
deprecation_message = (
|
113 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
114 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
115 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
116 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
117 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
118 |
+
" file"
|
119 |
+
)
|
120 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
121 |
+
new_config = dict(scheduler.config)
|
122 |
+
new_config["steps_offset"] = 1
|
123 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
124 |
+
|
125 |
+
if safety_checker is None and requires_safety_checker:
|
126 |
+
logger.warning(
|
127 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
128 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
129 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
130 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
131 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
132 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
133 |
+
)
|
134 |
+
|
135 |
+
if safety_checker is not None and feature_extractor is None:
|
136 |
+
raise ValueError(
|
137 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
138 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
139 |
+
)
|
140 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
141 |
+
version.parse(unet.config._diffusers_version).base_version
|
142 |
+
) < version.parse("0.9.0.dev0")
|
143 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
144 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
145 |
+
deprecation_message = (
|
146 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
147 |
+
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
148 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
149 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
150 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
151 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
152 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
153 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
154 |
+
" the `unet/config.json` file"
|
155 |
+
)
|
156 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
157 |
+
new_config = dict(unet.config)
|
158 |
+
new_config["sample_size"] = 64
|
159 |
+
unet._internal_dict = FrozenDict(new_config)
|
160 |
+
|
161 |
+
self.register_modules(
|
162 |
+
vae=vae,
|
163 |
+
text_encoder=text_encoder,
|
164 |
+
tokenizer=tokenizer,
|
165 |
+
unet=unet,
|
166 |
+
scheduler=scheduler,
|
167 |
+
safety_checker=safety_checker,
|
168 |
+
feature_extractor=feature_extractor,
|
169 |
+
)
|
170 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
171 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
172 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
173 |
+
|
174 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
175 |
+
def _encode_prompt(
|
176 |
+
self,
|
177 |
+
prompt,
|
178 |
+
device,
|
179 |
+
num_images_per_prompt,
|
180 |
+
do_classifier_free_guidance,
|
181 |
+
negative_prompt=None,
|
182 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
183 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
184 |
+
lora_scale: Optional[float] = None,
|
185 |
+
):
|
186 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
187 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
188 |
+
|
189 |
+
prompt_embeds_tuple = self.encode_prompt(
|
190 |
+
prompt=prompt,
|
191 |
+
device=device,
|
192 |
+
num_images_per_prompt=num_images_per_prompt,
|
193 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
194 |
+
negative_prompt=negative_prompt,
|
195 |
+
prompt_embeds=prompt_embeds,
|
196 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
197 |
+
lora_scale=lora_scale,
|
198 |
+
)
|
199 |
+
|
200 |
+
# concatenate for backwards comp
|
201 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
202 |
+
|
203 |
+
return prompt_embeds
|
204 |
+
|
205 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
206 |
+
def encode_prompt(
|
207 |
+
self,
|
208 |
+
prompt,
|
209 |
+
device,
|
210 |
+
num_images_per_prompt,
|
211 |
+
do_classifier_free_guidance,
|
212 |
+
negative_prompt=None,
|
213 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
214 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
215 |
+
lora_scale: Optional[float] = None,
|
216 |
+
):
|
217 |
+
# set lora scale so that monkey patched LoRA
|
218 |
+
# function of text encoder can correctly access it
|
219 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
220 |
+
self._lora_scale = lora_scale
|
221 |
+
|
222 |
+
# dynamically adjust the LoRA scale
|
223 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
224 |
+
|
225 |
+
if prompt is not None and isinstance(prompt, str):
|
226 |
+
batch_size = 1
|
227 |
+
elif prompt is not None and isinstance(prompt, list):
|
228 |
+
batch_size = len(prompt)
|
229 |
+
else:
|
230 |
+
batch_size = prompt_embeds.shape[0]
|
231 |
+
|
232 |
+
if prompt_embeds is None:
|
233 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
234 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
235 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
236 |
+
|
237 |
+
text_inputs = self.tokenizer(
|
238 |
+
prompt,
|
239 |
+
padding="max_length",
|
240 |
+
max_length=self.tokenizer.model_max_length,
|
241 |
+
truncation=True,
|
242 |
+
return_tensors="pt",
|
243 |
+
)
|
244 |
+
text_input_ids = text_inputs.input_ids
|
245 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
246 |
+
|
247 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
248 |
+
text_input_ids, untruncated_ids
|
249 |
+
):
|
250 |
+
removed_text = self.tokenizer.batch_decode(
|
251 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
252 |
+
)
|
253 |
+
logger.warning(
|
254 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
255 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
256 |
+
)
|
257 |
+
|
258 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
259 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
260 |
+
else:
|
261 |
+
attention_mask = None
|
262 |
+
|
263 |
+
prompt_embeds = self.text_encoder(
|
264 |
+
text_input_ids.to(device),
|
265 |
+
attention_mask=attention_mask,
|
266 |
+
)
|
267 |
+
prompt_embeds = prompt_embeds[0]
|
268 |
+
|
269 |
+
if self.text_encoder is not None:
|
270 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
271 |
+
elif self.unet is not None:
|
272 |
+
prompt_embeds_dtype = self.unet.dtype
|
273 |
+
else:
|
274 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
275 |
+
|
276 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
277 |
+
|
278 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
279 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
280 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
281 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
282 |
+
|
283 |
+
# get unconditional embeddings for classifier free guidance
|
284 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
285 |
+
uncond_tokens: List[str]
|
286 |
+
if negative_prompt is None:
|
287 |
+
uncond_tokens = [""] * batch_size
|
288 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
289 |
+
raise TypeError(
|
290 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
291 |
+
f" {type(prompt)}."
|
292 |
+
)
|
293 |
+
elif isinstance(negative_prompt, str):
|
294 |
+
uncond_tokens = [negative_prompt]
|
295 |
+
elif batch_size != len(negative_prompt):
|
296 |
+
raise ValueError(
|
297 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
298 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
299 |
+
" the batch size of `prompt`."
|
300 |
+
)
|
301 |
+
else:
|
302 |
+
uncond_tokens = negative_prompt
|
303 |
+
|
304 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
305 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
306 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
307 |
+
|
308 |
+
max_length = prompt_embeds.shape[1]
|
309 |
+
uncond_input = self.tokenizer(
|
310 |
+
uncond_tokens,
|
311 |
+
padding="max_length",
|
312 |
+
max_length=max_length,
|
313 |
+
truncation=True,
|
314 |
+
return_tensors="pt",
|
315 |
+
)
|
316 |
+
|
317 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
318 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
319 |
+
else:
|
320 |
+
attention_mask = None
|
321 |
+
|
322 |
+
negative_prompt_embeds = self.text_encoder(
|
323 |
+
uncond_input.input_ids.to(device),
|
324 |
+
attention_mask=attention_mask,
|
325 |
+
)
|
326 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
327 |
+
|
328 |
+
if do_classifier_free_guidance:
|
329 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
330 |
+
seq_len = negative_prompt_embeds.shape[1]
|
331 |
+
|
332 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
333 |
+
|
334 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
335 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
336 |
+
|
337 |
+
return prompt_embeds, negative_prompt_embeds
|
338 |
+
|
339 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
|
340 |
+
def check_inputs(
|
341 |
+
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
|
342 |
+
):
|
343 |
+
if strength < 0 or strength > 1:
|
344 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
345 |
+
|
346 |
+
if (callback_steps is None) or (
|
347 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
348 |
+
):
|
349 |
+
raise ValueError(
|
350 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
351 |
+
f" {type(callback_steps)}."
|
352 |
+
)
|
353 |
+
|
354 |
+
if prompt is not None and prompt_embeds is not None:
|
355 |
+
raise ValueError(
|
356 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
357 |
+
" only forward one of the two."
|
358 |
+
)
|
359 |
+
elif prompt is None and prompt_embeds is None:
|
360 |
+
raise ValueError(
|
361 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
362 |
+
)
|
363 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
364 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
365 |
+
|
366 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
367 |
+
raise ValueError(
|
368 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
369 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
370 |
+
)
|
371 |
+
|
372 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
373 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
374 |
+
raise ValueError(
|
375 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
376 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
377 |
+
f" {negative_prompt_embeds.shape}."
|
378 |
+
)
|
379 |
+
|
380 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
381 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
382 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
383 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
384 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
385 |
+
# and should be between [0, 1]
|
386 |
+
|
387 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
388 |
+
extra_step_kwargs = {}
|
389 |
+
if accepts_eta:
|
390 |
+
extra_step_kwargs["eta"] = eta
|
391 |
+
|
392 |
+
# check if the scheduler accepts generator
|
393 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
394 |
+
if accepts_generator:
|
395 |
+
extra_step_kwargs["generator"] = generator
|
396 |
+
return extra_step_kwargs
|
397 |
+
|
398 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
399 |
+
def run_safety_checker(self, image, device, dtype):
|
400 |
+
if self.safety_checker is None:
|
401 |
+
has_nsfw_concept = None
|
402 |
+
else:
|
403 |
+
if torch.is_tensor(image):
|
404 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
405 |
+
else:
|
406 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
407 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
408 |
+
image, has_nsfw_concept = self.safety_checker(
|
409 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
410 |
+
)
|
411 |
+
return image, has_nsfw_concept
|
412 |
+
|
413 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
414 |
+
def decode_latents(self, latents):
|
415 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
416 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
417 |
+
|
418 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
419 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
420 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
421 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
422 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
423 |
+
return image
|
424 |
+
|
425 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
426 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
427 |
+
# get the original timestep using init_timestep
|
428 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
429 |
+
|
430 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
431 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
432 |
+
|
433 |
+
return timesteps, num_inference_steps - t_start
|
434 |
+
|
435 |
+
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, denoise_model, generator=None):
|
436 |
+
image = image.to(device=device, dtype=dtype)
|
437 |
+
|
438 |
+
batch_size = image.shape[0]
|
439 |
+
|
440 |
+
if image.shape[1] == 4:
|
441 |
+
init_latents = image
|
442 |
+
|
443 |
+
else:
|
444 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
445 |
+
raise ValueError(
|
446 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
447 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
448 |
+
)
|
449 |
+
|
450 |
+
if isinstance(generator, list):
|
451 |
+
init_latents = [
|
452 |
+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
453 |
+
]
|
454 |
+
init_latents = torch.cat(init_latents, dim=0)
|
455 |
+
else:
|
456 |
+
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
457 |
+
|
458 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
459 |
+
|
460 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
461 |
+
# expand init_latents for batch_size
|
462 |
+
deprecation_message = (
|
463 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
464 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
465 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
466 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
467 |
+
)
|
468 |
+
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
469 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
470 |
+
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
471 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
472 |
+
raise ValueError(
|
473 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
477 |
+
|
478 |
+
# add noise to latents using the timestep
|
479 |
+
shape = init_latents.shape
|
480 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
481 |
+
|
482 |
+
# get latents
|
483 |
+
clean_latents = init_latents
|
484 |
+
if denoise_model:
|
485 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
486 |
+
latents = init_latents
|
487 |
+
else:
|
488 |
+
latents = noise
|
489 |
+
|
490 |
+
return latents, clean_latents
|
491 |
+
|
492 |
+
@torch.no_grad()
|
493 |
+
def __call__(
|
494 |
+
self,
|
495 |
+
prompt: Union[str, List[str]],
|
496 |
+
source_prompt: Union[str, List[str]],
|
497 |
+
negative_prompt: Union[str, List[str]]=None,
|
498 |
+
positive_prompt: Union[str, List[str]]=None,
|
499 |
+
image: PipelineImageInput = None,
|
500 |
+
strength: float = 0.8,
|
501 |
+
num_inference_steps: Optional[int] = 50,
|
502 |
+
original_inference_steps: Optional[int] = 50,
|
503 |
+
guidance_scale: Optional[float] = 7.5,
|
504 |
+
source_guidance_scale: Optional[float] = 1,
|
505 |
+
num_images_per_prompt: Optional[int] = 1,
|
506 |
+
eta: Optional[float] = 1.0,
|
507 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
508 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
509 |
+
output_type: Optional[str] = "pil",
|
510 |
+
return_dict: bool = True,
|
511 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
512 |
+
callback_steps: int = 1,
|
513 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
514 |
+
denoise_model: Optional[bool] = True,
|
515 |
+
):
|
516 |
+
# 1. Check inputs
|
517 |
+
self.check_inputs(prompt, strength, callback_steps)
|
518 |
+
|
519 |
+
# 2. Define call parameters
|
520 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
521 |
+
device = self._execution_device
|
522 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
523 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
524 |
+
# corresponds to doing no classifier free guidance.
|
525 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
526 |
+
|
527 |
+
# 3. Encode input prompt
|
528 |
+
text_encoder_lora_scale = (
|
529 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
530 |
+
)
|
531 |
+
prompt_embeds_tuple = self.encode_prompt(
|
532 |
+
prompt,
|
533 |
+
device,
|
534 |
+
num_images_per_prompt,
|
535 |
+
do_classifier_free_guidance,
|
536 |
+
negative_prompt=negative_prompt,
|
537 |
+
prompt_embeds=prompt_embeds,
|
538 |
+
lora_scale=text_encoder_lora_scale,
|
539 |
+
)
|
540 |
+
source_prompt_embeds_tuple = self.encode_prompt(
|
541 |
+
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, positive_prompt, None
|
542 |
+
)
|
543 |
+
if prompt_embeds_tuple[1] is not None:
|
544 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
545 |
+
else:
|
546 |
+
prompt_embeds = prompt_embeds_tuple[0]
|
547 |
+
if source_prompt_embeds_tuple[1] is not None:
|
548 |
+
source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]])
|
549 |
+
else:
|
550 |
+
source_prompt_embeds = source_prompt_embeds_tuple[0]
|
551 |
+
|
552 |
+
# 4. Preprocess image
|
553 |
+
image = self.image_processor.preprocess(image)
|
554 |
+
|
555 |
+
# 5. Prepare timesteps
|
556 |
+
self.scheduler.set_timesteps(
|
557 |
+
num_inference_steps=num_inference_steps,
|
558 |
+
device=device,
|
559 |
+
original_inference_steps=original_inference_steps)
|
560 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
561 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
562 |
+
|
563 |
+
# 6. Prepare latent variables
|
564 |
+
latents, clean_latents = self.prepare_latents(
|
565 |
+
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, denoise_model, generator
|
566 |
+
)
|
567 |
+
source_latents = latents
|
568 |
+
mutual_latents = latents
|
569 |
+
|
570 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
571 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
572 |
+
generator = extra_step_kwargs.pop("generator", None)
|
573 |
+
|
574 |
+
# 8. Denoising loop
|
575 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
576 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
577 |
+
for i, t in enumerate(timesteps):
|
578 |
+
# expand the latents if we are doing classifier free guidance
|
579 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
580 |
+
source_latent_model_input = (
|
581 |
+
torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents
|
582 |
+
)
|
583 |
+
mutual_latent_model_input = (
|
584 |
+
torch.cat([mutual_latents] * 2) if do_classifier_free_guidance else mutual_latents
|
585 |
+
)
|
586 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
587 |
+
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
|
588 |
+
mutual_latent_model_input = self.scheduler.scale_model_input(mutual_latent_model_input, t)
|
589 |
+
|
590 |
+
# predict the noise residual
|
591 |
+
if do_classifier_free_guidance:
|
592 |
+
concat_latent_model_input = torch.stack(
|
593 |
+
[
|
594 |
+
source_latent_model_input[0],
|
595 |
+
latent_model_input[0],
|
596 |
+
mutual_latent_model_input[0],
|
597 |
+
source_latent_model_input[1],
|
598 |
+
latent_model_input[1],
|
599 |
+
mutual_latent_model_input[1],
|
600 |
+
],
|
601 |
+
dim=0,
|
602 |
+
)
|
603 |
+
concat_prompt_embeds = torch.stack(
|
604 |
+
[
|
605 |
+
source_prompt_embeds[0],
|
606 |
+
prompt_embeds[0],
|
607 |
+
source_prompt_embeds[0],
|
608 |
+
source_prompt_embeds[1],
|
609 |
+
prompt_embeds[1],
|
610 |
+
source_prompt_embeds[1],
|
611 |
+
],
|
612 |
+
dim=0,
|
613 |
+
)
|
614 |
+
else:
|
615 |
+
concat_latent_model_input = torch.cat(
|
616 |
+
[
|
617 |
+
source_latent_model_input,
|
618 |
+
latent_model_input,
|
619 |
+
mutual_latent_model_input,
|
620 |
+
],
|
621 |
+
dim=0,
|
622 |
+
)
|
623 |
+
concat_prompt_embeds = torch.cat(
|
624 |
+
[
|
625 |
+
source_prompt_embeds,
|
626 |
+
prompt_embeds,
|
627 |
+
source_prompt_embeds,
|
628 |
+
],
|
629 |
+
dim=0,
|
630 |
+
)
|
631 |
+
|
632 |
+
concat_noise_pred = self.unet(
|
633 |
+
concat_latent_model_input,
|
634 |
+
t,
|
635 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
636 |
+
encoder_hidden_states=concat_prompt_embeds,
|
637 |
+
).sample
|
638 |
+
|
639 |
+
# perform guidance
|
640 |
+
if do_classifier_free_guidance:
|
641 |
+
(
|
642 |
+
source_noise_pred_uncond,
|
643 |
+
noise_pred_uncond,
|
644 |
+
mutual_noise_pred_uncond,
|
645 |
+
source_noise_pred_text,
|
646 |
+
noise_pred_text,
|
647 |
+
mutual_noise_pred_text
|
648 |
+
) = concat_noise_pred.chunk(6, dim=0)
|
649 |
+
|
650 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
651 |
+
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
|
652 |
+
source_noise_pred_text - source_noise_pred_uncond
|
653 |
+
)
|
654 |
+
mutual_noise_pred = mutual_noise_pred_uncond + source_guidance_scale * (
|
655 |
+
mutual_noise_pred_text - mutual_noise_pred_uncond
|
656 |
+
)
|
657 |
+
|
658 |
+
else:
|
659 |
+
(source_noise_pred, noise_pred, mutual_noise_pred) = concat_noise_pred.chunk(3, dim=0)
|
660 |
+
|
661 |
+
noise = torch.randn(
|
662 |
+
latents.shape, dtype=latents.dtype, device=latents.device, generator=generator
|
663 |
+
)
|
664 |
+
|
665 |
+
_, latents, pred_x0 = ddcm_sampler(
|
666 |
+
self.scheduler, source_latents,
|
667 |
+
latents, t,
|
668 |
+
source_noise_pred, noise_pred,
|
669 |
+
clean_latents, noise=noise,
|
670 |
+
eta=eta, to_next=False,
|
671 |
+
**extra_step_kwargs
|
672 |
+
)
|
673 |
+
|
674 |
+
source_latents, mutual_latents, pred_xm = ddcm_sampler(
|
675 |
+
self.scheduler, source_latents,
|
676 |
+
mutual_latents, t,
|
677 |
+
source_noise_pred, mutual_noise_pred,
|
678 |
+
clean_latents, noise=noise,
|
679 |
+
eta=eta, **extra_step_kwargs
|
680 |
+
)
|
681 |
+
|
682 |
+
# call the callback, if provided
|
683 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
684 |
+
progress_bar.update()
|
685 |
+
if callback is not None and i % callback_steps == 0:
|
686 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
687 |
+
mutual_latents, latents = callback(i, t, source_latents, latents, mutual_latents, alpha_prod_t)
|
688 |
+
|
689 |
+
# 9. Post-processing
|
690 |
+
if not output_type == "latent":
|
691 |
+
image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
|
692 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
693 |
+
else:
|
694 |
+
image = pred_x0
|
695 |
+
has_nsfw_concept = None
|
696 |
+
|
697 |
+
if has_nsfw_concept is None:
|
698 |
+
do_denormalize = [True] * image.shape[0]
|
699 |
+
else:
|
700 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
701 |
+
|
702 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
703 |
+
|
704 |
+
if not return_dict:
|
705 |
+
return (image, has_nsfw_concept)
|
706 |
+
|
707 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
ptp_utils.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from typing import Optional, Union, Tuple, Dict
|
18 |
+
from PIL import Image
|
19 |
+
|
20 |
+
def save_images(images,dest, num_rows=1, offset_ratio=0.02):
|
21 |
+
if type(images) is list:
|
22 |
+
num_empty = len(images) % num_rows
|
23 |
+
elif images.ndim == 4:
|
24 |
+
num_empty = images.shape[0] % num_rows
|
25 |
+
else:
|
26 |
+
images = [images]
|
27 |
+
num_empty = 0
|
28 |
+
|
29 |
+
pil_img = Image.fromarray(images[-1])
|
30 |
+
pil_img.save(dest)
|
31 |
+
# display(pil_img)
|
32 |
+
|
33 |
+
|
34 |
+
def save_image(images,dest, num_rows=1, offset_ratio=0.02):
|
35 |
+
print(images.shape)
|
36 |
+
pil_img = Image.fromarray(images[0])
|
37 |
+
pil_img.save(dest)
|
38 |
+
|
39 |
+
def register_attention_control(model, controller):
|
40 |
+
class AttnProcessor():
|
41 |
+
def __init__(self,place_in_unet):
|
42 |
+
self.place_in_unet = place_in_unet
|
43 |
+
|
44 |
+
def __call__(self,
|
45 |
+
attn,
|
46 |
+
hidden_states,
|
47 |
+
encoder_hidden_states=None,
|
48 |
+
attention_mask=None,
|
49 |
+
temb=None,
|
50 |
+
scale=1.0,):
|
51 |
+
# The `Attention` class can call different attention processors / attention functions
|
52 |
+
|
53 |
+
residual = hidden_states
|
54 |
+
|
55 |
+
if attn.spatial_norm is not None:
|
56 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
57 |
+
|
58 |
+
input_ndim = hidden_states.ndim
|
59 |
+
|
60 |
+
if input_ndim == 4:
|
61 |
+
batch_size, channel, height, width = hidden_states.shape
|
62 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
63 |
+
|
64 |
+
h = attn.heads
|
65 |
+
is_cross = encoder_hidden_states is not None
|
66 |
+
if encoder_hidden_states is None:
|
67 |
+
encoder_hidden_states = hidden_states
|
68 |
+
elif attn.norm_cross:
|
69 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
70 |
+
|
71 |
+
batch_size, sequence_length, _ = (
|
72 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
73 |
+
)
|
74 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
75 |
+
|
76 |
+
q = attn.to_q(hidden_states)
|
77 |
+
k = attn.to_k(encoder_hidden_states)
|
78 |
+
v = attn.to_v(encoder_hidden_states)
|
79 |
+
q = attn.head_to_batch_dim(q)
|
80 |
+
k = attn.head_to_batch_dim(k)
|
81 |
+
v = attn.head_to_batch_dim(v)
|
82 |
+
|
83 |
+
if not is_cross:
|
84 |
+
q,k,v = controller.self_attn_forward(q, k, v, attn.heads)
|
85 |
+
|
86 |
+
attention_probs = attn.get_attention_scores(q, k, attention_mask)
|
87 |
+
if is_cross:
|
88 |
+
attention_probs = controller(attention_probs , is_cross, self.place_in_unet)
|
89 |
+
# else:
|
90 |
+
# out = controller.self_attn_forward(q, k, v, sim, attention_probs , is_cross, self.place_in_unet, attn.heads, scale=attn.scale)
|
91 |
+
hidden_states = torch.bmm(attention_probs, v)
|
92 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
93 |
+
|
94 |
+
# linear proj
|
95 |
+
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
96 |
+
# dropout
|
97 |
+
hidden_states = attn.to_out[1](hidden_states)
|
98 |
+
|
99 |
+
if input_ndim == 4:
|
100 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
101 |
+
|
102 |
+
if attn.residual_connection:
|
103 |
+
hidden_states = hidden_states + residual
|
104 |
+
|
105 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
106 |
+
|
107 |
+
return hidden_states
|
108 |
+
|
109 |
+
|
110 |
+
def register_recr(net_, count, place_in_unet):
|
111 |
+
for idx, m in enumerate(net_.modules()):
|
112 |
+
# print(m.__class__.__name__)
|
113 |
+
if m.__class__.__name__ == "Attention":
|
114 |
+
count+=1
|
115 |
+
m.processor = AttnProcessor( place_in_unet)
|
116 |
+
return count
|
117 |
+
|
118 |
+
cross_att_count = 0
|
119 |
+
sub_nets = model.unet.named_children()
|
120 |
+
for net in sub_nets:
|
121 |
+
if "down" in net[0]:
|
122 |
+
cross_att_count += register_recr(net[1], 0, "down")
|
123 |
+
elif "up" in net[0]:
|
124 |
+
cross_att_count += register_recr(net[1], 0, "up")
|
125 |
+
elif "mid" in net[0]:
|
126 |
+
cross_att_count += register_recr(net[1], 0, "mid")
|
127 |
+
controller.num_att_layers = cross_att_count
|
128 |
+
|
129 |
+
|
130 |
+
def get_word_inds(text: str, word_place: int, tokenizer):
|
131 |
+
split_text = text.split(" ")
|
132 |
+
if type(word_place) is str:
|
133 |
+
word_place = [i for i, word in enumerate(split_text) if word_place == word]
|
134 |
+
elif type(word_place) is int:
|
135 |
+
word_place = [word_place]
|
136 |
+
out = []
|
137 |
+
if len(word_place) > 0:
|
138 |
+
words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
|
139 |
+
cur_len, ptr = 0, 0
|
140 |
+
|
141 |
+
for i in range(len(words_encode)):
|
142 |
+
cur_len += len(words_encode[i])
|
143 |
+
if ptr in word_place:
|
144 |
+
out.append(i + 1)
|
145 |
+
if cur_len >= len(split_text[ptr]):
|
146 |
+
ptr += 1
|
147 |
+
cur_len = 0
|
148 |
+
return np.array(out)
|
149 |
+
|
150 |
+
|
151 |
+
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None):
|
152 |
+
if type(bounds) is float:
|
153 |
+
bounds = 0, bounds
|
154 |
+
start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
|
155 |
+
if word_inds is None:
|
156 |
+
word_inds = torch.arange(alpha.shape[2])
|
157 |
+
alpha[: start, prompt_ind, word_inds] = 0
|
158 |
+
alpha[start: end, prompt_ind, word_inds] = 1
|
159 |
+
alpha[end:, prompt_ind, word_inds] = 0
|
160 |
+
return alpha
|
161 |
+
|
162 |
+
|
163 |
+
def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
|
164 |
+
tokenizer, max_num_words=77):
|
165 |
+
if type(cross_replace_steps) is not dict:
|
166 |
+
cross_replace_steps = {"default_": cross_replace_steps}
|
167 |
+
if "default_" not in cross_replace_steps:
|
168 |
+
cross_replace_steps["default_"] = (0., 1.)
|
169 |
+
alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
|
170 |
+
for i in range(len(prompts) - 1):
|
171 |
+
alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
|
172 |
+
i)
|
173 |
+
for key, item in cross_replace_steps.items():
|
174 |
+
if key != "default_":
|
175 |
+
inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
|
176 |
+
for i, ind in enumerate(inds):
|
177 |
+
if len(ind) > 0:
|
178 |
+
alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
|
179 |
+
alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words
|
180 |
+
return alpha_time_words
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
git+https://github.com/huggingface/diffusers.git
|
5 |
+
Pillow
|
6 |
+
transformers
|
7 |
+
opencv-python
|
8 |
+
openai
|
seq_aligner.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import copy
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class ScoreParams:
|
8 |
+
|
9 |
+
def __init__(self, gap, match, mismatch):
|
10 |
+
self.gap = gap
|
11 |
+
self.match = match
|
12 |
+
self.mismatch = mismatch
|
13 |
+
|
14 |
+
def mis_match_char(self, x, y):
|
15 |
+
if x != y:
|
16 |
+
return self.mismatch
|
17 |
+
else:
|
18 |
+
return self.match
|
19 |
+
|
20 |
+
|
21 |
+
def get_matrix(size_x, size_y, gap):
|
22 |
+
matrix = []
|
23 |
+
for i in range(len(size_x) + 1):
|
24 |
+
sub_matrix = []
|
25 |
+
for j in range(len(size_y) + 1):
|
26 |
+
sub_matrix.append(0)
|
27 |
+
matrix.append(sub_matrix)
|
28 |
+
for j in range(1, len(size_y) + 1):
|
29 |
+
matrix[0][j] = j*gap
|
30 |
+
for i in range(1, len(size_x) + 1):
|
31 |
+
matrix[i][0] = i*gap
|
32 |
+
return matrix
|
33 |
+
|
34 |
+
|
35 |
+
def get_matrix(size_x, size_y, gap):
|
36 |
+
matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
|
37 |
+
matrix[0, 1:] = (np.arange(size_y) + 1) * gap
|
38 |
+
matrix[1:, 0] = (np.arange(size_x) + 1) * gap
|
39 |
+
return matrix
|
40 |
+
|
41 |
+
|
42 |
+
def get_traceback_matrix(size_x, size_y):
|
43 |
+
matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
|
44 |
+
matrix[0, 1:] = 1
|
45 |
+
matrix[1:, 0] = 2
|
46 |
+
matrix[0, 0] = 4
|
47 |
+
return matrix
|
48 |
+
|
49 |
+
|
50 |
+
def global_align(x, y, score):
|
51 |
+
matrix = get_matrix(len(x), len(y), score.gap)
|
52 |
+
trace_back = get_traceback_matrix(len(x), len(y))
|
53 |
+
for i in range(1, len(x) + 1):
|
54 |
+
for j in range(1, len(y) + 1):
|
55 |
+
left = matrix[i, j - 1] + score.gap
|
56 |
+
up = matrix[i - 1, j] + score.gap
|
57 |
+
diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
|
58 |
+
matrix[i, j] = max(left, up, diag)
|
59 |
+
if matrix[i, j] == left:
|
60 |
+
trace_back[i, j] = 1
|
61 |
+
elif matrix[i, j] == up:
|
62 |
+
trace_back[i, j] = 2
|
63 |
+
else:
|
64 |
+
trace_back[i, j] = 3
|
65 |
+
return matrix, trace_back
|
66 |
+
|
67 |
+
|
68 |
+
def get_aligned_sequences(x, y, trace_back):
|
69 |
+
x_seq = []
|
70 |
+
y_seq = []
|
71 |
+
i = len(x)
|
72 |
+
j = len(y)
|
73 |
+
mapper_y_to_x = []
|
74 |
+
while i > 0 or j > 0:
|
75 |
+
if trace_back[i, j] == 3:
|
76 |
+
x_seq.append(x[i-1])
|
77 |
+
y_seq.append(y[j-1])
|
78 |
+
i = i-1
|
79 |
+
j = j-1
|
80 |
+
mapper_y_to_x.append((j, i))
|
81 |
+
elif trace_back[i][j] == 1:
|
82 |
+
x_seq.append('-')
|
83 |
+
y_seq.append(y[j-1])
|
84 |
+
j = j-1
|
85 |
+
mapper_y_to_x.append((j, -1))
|
86 |
+
elif trace_back[i][j] == 2:
|
87 |
+
x_seq.append(x[i-1])
|
88 |
+
y_seq.append('-')
|
89 |
+
i = i-1
|
90 |
+
elif trace_back[i][j] == 4:
|
91 |
+
break
|
92 |
+
mapper_y_to_x.reverse()
|
93 |
+
return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
|
94 |
+
|
95 |
+
|
96 |
+
def get_mapper(x: str, y: str, specifier, tokenizer, encoder, device, max_len=77):
|
97 |
+
locol_prompt, mutual_prompt = specifier
|
98 |
+
x_seq = tokenizer.encode(x)
|
99 |
+
y_seq = tokenizer.encode(y)
|
100 |
+
e_seq = tokenizer.encode(locol_prompt)
|
101 |
+
m_seq = tokenizer.encode(mutual_prompt)
|
102 |
+
score = ScoreParams(0, 1, -1)
|
103 |
+
matrix, trace_back = global_align(x_seq, y_seq, score)
|
104 |
+
mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
|
105 |
+
alphas = torch.ones(max_len)
|
106 |
+
alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
|
107 |
+
mapper = torch.zeros(max_len, dtype=torch.int64)
|
108 |
+
mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
|
109 |
+
mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
|
110 |
+
m = copy.deepcopy(alphas)
|
111 |
+
alpha_e = torch.zeros_like(alphas)
|
112 |
+
alpha_m = torch.zeros_like(alphas)
|
113 |
+
|
114 |
+
# print("mapper of")
|
115 |
+
# print("<begin> "+x+" <end>")
|
116 |
+
# print("<begin> "+y+" <end>")
|
117 |
+
# print(mapper[:len(y_seq)])
|
118 |
+
# print(alphas[:len(y_seq)])
|
119 |
+
|
120 |
+
x = tokenizer(
|
121 |
+
x,
|
122 |
+
padding="max_length",
|
123 |
+
max_length=max_len,
|
124 |
+
truncation=True,
|
125 |
+
return_tensors="pt",
|
126 |
+
).input_ids.to(device)
|
127 |
+
y = tokenizer(
|
128 |
+
y,
|
129 |
+
padding="max_length",
|
130 |
+
max_length=max_len,
|
131 |
+
truncation=True,
|
132 |
+
return_tensors="pt",
|
133 |
+
).input_ids.to(device)
|
134 |
+
|
135 |
+
x_latent = encoder(x)[0].squeeze(0)
|
136 |
+
y_latent = encoder(y)[0].squeeze(0)
|
137 |
+
i = 0
|
138 |
+
while i<len(y_seq):
|
139 |
+
start = None
|
140 |
+
if alphas[i] == 0:
|
141 |
+
start = i
|
142 |
+
while alphas[i] == 0:
|
143 |
+
i += 1
|
144 |
+
max_sim = float('-inf')
|
145 |
+
max_s = None
|
146 |
+
max_t = None
|
147 |
+
for i_target in range(start, i):
|
148 |
+
for i_source in range(mapper[start-1]+1, mapper[i]):
|
149 |
+
sim = F.cosine_similarity(x_latent[i_target], y_latent[i_source], dim=0)
|
150 |
+
if sim > max_sim:
|
151 |
+
max_sim = sim
|
152 |
+
max_s = i_source
|
153 |
+
max_t = i_target
|
154 |
+
if max_s is not None:
|
155 |
+
mapper[max_t] = max_s
|
156 |
+
alphas[max_t] = 1
|
157 |
+
for t in e_seq:
|
158 |
+
if x_seq[max_s] == t:
|
159 |
+
alpha_e[max_t] = 1
|
160 |
+
i += 1
|
161 |
+
|
162 |
+
# replace_alpha, replace_mapper = get_replace_inds(x_seq, y_seq, m_seq, m_seq)
|
163 |
+
# if replace_mapper != []:
|
164 |
+
# mapper[replace_alpha]=torch.tensor(replace_mapper,device=mapper.device)
|
165 |
+
# alpha_m[replace_alpha]=1
|
166 |
+
|
167 |
+
i = 1
|
168 |
+
j = 1
|
169 |
+
while (i < len(y_seq)-1) and (j < len(e_seq)-1):
|
170 |
+
found = True
|
171 |
+
while e_seq[j] != y_seq[i]:
|
172 |
+
i = i + 1
|
173 |
+
if i >= len(y_seq)-1:
|
174 |
+
print("blend word not found!")
|
175 |
+
found = False
|
176 |
+
break
|
177 |
+
raise ValueError("local prompt not found in target prompt")
|
178 |
+
if found:
|
179 |
+
alpha_e[i] = 1
|
180 |
+
j = j + 1
|
181 |
+
|
182 |
+
i = 1
|
183 |
+
j = 1
|
184 |
+
while (i < len(y_seq)-1) and (j < len(m_seq)-1):
|
185 |
+
while m_seq[j] != y_seq[i]:
|
186 |
+
i = i + 1
|
187 |
+
if m_seq[j] == x_seq[mapper[i]]:
|
188 |
+
alpha_m[i] = 1
|
189 |
+
j = j + 1
|
190 |
+
else:
|
191 |
+
raise ValueError("mutual prompt not found in target prompt")
|
192 |
+
|
193 |
+
# print("fixed mapper:")
|
194 |
+
# print(mapper[:len(y_seq)])
|
195 |
+
# print(alphas[:len(y_seq)])
|
196 |
+
# print(m[:len(y_seq)])
|
197 |
+
# print(alpha_e[:len(y_seq)])
|
198 |
+
# print(alpha_m[:len(y_seq)])
|
199 |
+
return mapper, alphas, m, alpha_e, alpha_m
|
200 |
+
|
201 |
+
|
202 |
+
def get_refinement_mapper(prompts, specifiers, tokenizer, encoder, device, max_len=77):
|
203 |
+
x_seq = prompts[0]
|
204 |
+
mappers, alphas, ms, alpha_objs, alpha_descs = [], [], [], [], []
|
205 |
+
for i in range(1, len(prompts)):
|
206 |
+
mapper, alpha, m, alpha_obj, alpha_desc = get_mapper(x_seq, prompts[i], specifiers[i-1], tokenizer, encoder, device, max_len)
|
207 |
+
mappers.append(mapper)
|
208 |
+
alphas.append(alpha)
|
209 |
+
ms.append(m)
|
210 |
+
alpha_objs.append(alpha_obj)
|
211 |
+
alpha_descs.append(alpha_desc)
|
212 |
+
return torch.stack(mappers), torch.stack(alphas), torch.stack(ms), torch.stack(alpha_objs), torch.stack(alpha_descs)
|
213 |
+
|
214 |
+
|
215 |
+
def get_replace_inds(x_seq,y_seq,source_replace_seq,target_replace_seq):
|
216 |
+
replace_mapper=[]
|
217 |
+
replace_alpha=[]
|
218 |
+
source_found=False
|
219 |
+
source_match,target_match=[],[]
|
220 |
+
for j in range(len(x_seq)):
|
221 |
+
found=True
|
222 |
+
for i in range(1,len(source_replace_seq)-1):
|
223 |
+
if x_seq[j+i-1]!=source_replace_seq[i]:
|
224 |
+
found=False
|
225 |
+
break
|
226 |
+
if found:
|
227 |
+
source_found=True
|
228 |
+
for i in range(1,len(source_replace_seq)-1):
|
229 |
+
source_match.append(j+i-1)
|
230 |
+
for j in range(len(y_seq)):
|
231 |
+
found=True
|
232 |
+
for i in range(1,len(target_replace_seq)-1):
|
233 |
+
if y_seq[j+i-1]!=target_replace_seq[i]:
|
234 |
+
found=False
|
235 |
+
break
|
236 |
+
if found:
|
237 |
+
for i in range(1,len(source_replace_seq)-1):
|
238 |
+
target_match.append(j+i-1)
|
239 |
+
if not source_found:
|
240 |
+
raise ValueError("replacing object not found in prompt")
|
241 |
+
if (len(source_match)!=len(target_match)):
|
242 |
+
raise ValueError(f"the replacement word number doesn't match for word {i}!")
|
243 |
+
replace_alpha+=source_match
|
244 |
+
replace_mapper+=target_match
|
245 |
+
return replace_alpha,replace_mapper
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
def get_word_inds(text: str, word_place: int, tokenizer):
|
250 |
+
split_text = text.split(" ")
|
251 |
+
if type(word_place) is str:
|
252 |
+
word_place = [i for i, word in enumerate(split_text) if word_place == word]
|
253 |
+
elif type(word_place) is int:
|
254 |
+
word_place = [word_place]
|
255 |
+
out = []
|
256 |
+
if len(word_place) > 0:
|
257 |
+
words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
|
258 |
+
cur_len, ptr = 0, 0
|
259 |
+
|
260 |
+
for i in range(len(words_encode)):
|
261 |
+
cur_len += len(words_encode[i])
|
262 |
+
if ptr in word_place:
|
263 |
+
out.append(i + 1)
|
264 |
+
if cur_len >= len(split_text[ptr]):
|
265 |
+
ptr += 1
|
266 |
+
cur_len = 0
|
267 |
+
return np.array(out)
|
268 |
+
|
269 |
+
|
270 |
+
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
|
271 |
+
words_x = x.split(' ')
|
272 |
+
words_y = y.split(' ')
|
273 |
+
if len(words_x) != len(words_y):
|
274 |
+
raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
|
275 |
+
f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
|
276 |
+
inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
|
277 |
+
inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
|
278 |
+
inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
|
279 |
+
mapper = np.zeros((max_len, max_len))
|
280 |
+
i = j = 0
|
281 |
+
cur_inds = 0
|
282 |
+
while i < max_len and j < max_len:
|
283 |
+
if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
|
284 |
+
inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
|
285 |
+
if len(inds_source_) == len(inds_target_):
|
286 |
+
mapper[inds_source_, inds_target_] = 1
|
287 |
+
else:
|
288 |
+
ratio = 1 / len(inds_target_)
|
289 |
+
for i_t in inds_target_:
|
290 |
+
mapper[inds_source_, i_t] = ratio
|
291 |
+
cur_inds += 1
|
292 |
+
i += len(inds_source_)
|
293 |
+
j += len(inds_target_)
|
294 |
+
elif cur_inds < len(inds_source):
|
295 |
+
mapper[i, j] = 1
|
296 |
+
i += 1
|
297 |
+
j += 1
|
298 |
+
else:
|
299 |
+
mapper[j, j] = 1
|
300 |
+
i += 1
|
301 |
+
j += 1
|
302 |
+
|
303 |
+
return torch.from_numpy(mapper).float()
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
def get_replacement_mapper(prompts, tokenizer, max_len=77):
|
308 |
+
x_seq = prompts[0]
|
309 |
+
mappers = []
|
310 |
+
for i in range(1, len(prompts)):
|
311 |
+
mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
|
312 |
+
mappers.append(mapper)
|
313 |
+
return torch.stack(mappers)
|
314 |
+
|
utils.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def is_google_colab():
|
2 |
+
try:
|
3 |
+
import google.colab
|
4 |
+
return True
|
5 |
+
except:
|
6 |
+
return False
|