import torch
import torch.nn.functional as nnf
import abc
import math
from torchvision.utils import save_image

LOW_RESOURCE = False
MAX_NUM_WORDS = 77

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def start_att_layers(self):
        return self.start_ac_layer  # if LOW_RESOURCE else 0

    @property
    def end_att_layers(self):
        return self.end_ac_layer

    def forward(self, q, k, v, num_heads, attention_probs, attn):
        raise NotImplementedError

    def attn_forward(self, q, k, v, num_heads, attention_probs, attn):
        if q.shape[0] // num_heads == 3:
            # Batch of three sub-batches (no classifier-free-guidance doubling): one pass.
            h_s_re = self.forward(q, k, v, num_heads, attention_probs, attn)
        else:
            # With classifier-free guidance the batch is doubled; process the
            # unconditional and conditional halves separately.
            uq, cq = q.chunk(2)
            uk, ck = k.chunk(2)
            uv, cv = v.chunk(2)
            u_attn, c_attn = attention_probs.chunk(2)
            u_h_s_re = self.forward(uq, uk, uv, num_heads, u_attn, attn)
            c_h_s_re = self.forward(cq, ck, cv, num_heads, c_attn, attn)
            h_s_re = (u_h_s_re, c_h_s_re)
        return h_s_re

    def __call__(self, q, k, v, num_heads, attention_probs, attn):
        # Only rewrite attention inside the configured layer window.
        if self.start_att_layers <= self.cur_att_layer < self.end_att_layers:
            h_s_re = self.attn_forward(q, k, v, num_heads, attention_probs, attn)
        else:
            h_s_re = None
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers // 2:  # + self.num_uncond_att_layers
            self.cur_att_layer = 0  # self.num_uncond_att_layers
            self.cur_step += 1
            self.between_steps()
        return h_s_re

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
    """Increase the contrast of an attention map by scaling deviations from its per-row mean."""
    mean_feat = tensor.mean(dim=-1, keepdim=True)
    adjusted_tensor = (tensor - mean_feat) * contrast_factor + mean_feat
    return adjusted_tensor
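
# A minimal worked example (not part of the original code) of what enhance_tensor
# does: values are pushed away from their per-row mean by `contrast_factor`,
# while the mean itself is preserved.
#
#   >>> t = torch.tensor([[0.2000, 0.3000, 0.5000]])
#   >>> enhance_tensor(t, contrast_factor=2.0)
#   tensor([[0.0667, 0.2667, 0.6667]])   # mean stays at 1/3, deviations doubled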


class AttentionStyle(AttentionControl):

    def __init__(self,
                 num_steps,
                 start_ac_layer, end_ac_layer,
                 style_guidance=0.3,
                 mix_q_scale=1.0,
                 de_bug=False,
                 ):
        super().__init__()
        self.start_ac_layer = start_ac_layer
        self.end_ac_layer = end_ac_layer
        self.num_steps = num_steps
        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.coef = None
        self.mix_q_scale = mix_q_scale
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        if self.de_bug:
            import pdb; pdb.set_trace()
        # Expects three sub-batches stacked along the batch*heads axis:
        # style reference first, then content, then the image being stylised.
        if self.mix_q_scale < 1.0:
            # Blend the stylised query towards the content query.
            q[num_heads * 2:] = q[num_heads * 2:] * self.mix_q_scale \
                + (1 - self.mix_q_scale) * q[num_heads * 1:num_heads * 2]
        b, n, d = k.shape
        re_q = q[num_heads * 2:]  # (heads, n, d)
        # Let the stylised query attend to the content keys/values concatenated
        # with the style keys/values along the token axis.
        re_k = torch.cat([k[num_heads * 1:num_heads * 2], k[num_heads * 0:num_heads * 1]], dim=1)  # (heads, 2n, d)
        v_re = torch.cat([v[num_heads * 1:num_heads * 2], v[num_heads * 0:num_heads * 1]], dim=1)  # (heads, 2n, d)
        re_sim = torch.bmm(re_q, re_k.transpose(-1, -2)) * attn.scale
        # Scale the logits pointing at the style tokens by the style-guidance factor.
        re_sim[:, :, n:] = re_sim[:, :, n:] * self.style_guidance
        re_attention_map = re_sim.softmax(-1)
        h_s_re = torch.bmm(re_attention_map, v_re)
        return h_s_re
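

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It only illustrates the
# calling convention the controller expects: three sub-batches of `num_heads`
# rows each, plus an attention module exposing a `.scale` attribute. The
# `_DummyAttn` class and the shapes below are illustrative assumptions; in the
# real pipeline the controller is wired into the UNet's attention layers
# elsewhere in the repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyAttn:
        # Stands in for an attention block; only `.scale` is read by the controller.
        scale = 1.0 / math.sqrt(40)

    heads, tokens, dim = 8, 16, 40
    controller = AttentionStyle(
        num_steps=50,
        start_ac_layer=0,
        end_ac_layer=1,
        style_guidance=1.2,
        mix_q_scale=1.0,
    )
    controller.num_att_layers = 2  # normally set when the controller is registered

    # Sub-batch layout: [style, content, stylised] x heads.
    q = torch.randn(3 * heads, tokens, dim)
    k = torch.randn(3 * heads, tokens, dim)
    v = torch.randn(3 * heads, tokens, dim)
    probs = torch.randn(3 * heads, tokens, tokens).softmax(-1)

    out = controller(q, k, v, heads, probs, _DummyAttn())
    print(out.shape)  # torch.Size([8, 16, 40]): re-styled hidden states for the third sub-batch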