Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,793 Bytes
a6cec16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import torch
import torch.nn.functional as nnf
import abc
import math
from torchvision.utils import save_image
# Not read by the code in this module (it only appears in a comment); kept
# for any importers that reference it.
LOW_RESOURCE = False
# Presumably the text encoder's maximum token count -- confirm.
MAX_NUM_WORDS = 77

# Run on the GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Half precision only on the GPU; the CPU path uses float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
class AttentionControl(abc.ABC):
    """Base class for controllers invoked once per hooked attention layer.

    Each ``__call__`` optionally runs ``forward`` (via ``attn_forward``) and
    then advances ``cur_att_layer``. When that counter reaches
    ``num_att_layers // 2`` it wraps to 0, ``cur_step`` is incremented and
    ``between_steps`` is called.

    Subclasses must implement ``forward`` and set ``start_ac_layer`` /
    ``end_ac_layer`` (the half-open layer range ``[start, end)`` in which
    ``forward`` is applied). ``num_att_layers`` defaults to -1. With that
    value the counter never wraps (``-1 // 2 == -1`` and the counter only
    grows), so it is presumably assigned by the code that registers the
    controller -- confirm against the caller.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Return ``x_t`` unchanged; subclasses may override."""
        return x_t

    def between_steps(self):
        """Hook run by ``__call__`` whenever the layer counter wraps; no-op here."""
        return

    @property
    def start_att_layers(self):
        """First layer index (inclusive) on which ``forward`` is applied."""
        return self.start_ac_layer

    @property
    def end_att_layers(self):
        """Layer index (exclusive) at which ``forward`` stops being applied."""
        return self.end_ac_layer

    @abc.abstractmethod
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        """Compute the controller's output for one group of samples.

        ``attn_forward`` always calls this with these six positional
        arguments. ``attention_probs`` is the (possibly halved) tensor passed
        to ``__call__``, and ``attn`` is passed through untouched.
        """
        raise NotImplementedError

    def attn_forward(self, q, k, v, num_heads, attention_probs, attn):
        """Dispatch to ``forward``, splitting a doubled batch in two.

        If ``q.shape[0] // num_heads == 3``, ``forward`` is called once on the
        full inputs and its result is returned. Otherwise ``q``, ``k``, ``v``
        and ``attention_probs`` are each split in half along dim 0.
        ``forward`` then runs on each half, and the pair ``(first, second)``
        is returned. The variable names suggest the halves are the
        unconditional and conditional batches of classifier-free guidance --
        confirm.
        """
        if q.shape[0] // num_heads == 3:
            return self.forward(q, k, v, num_heads, attention_probs, attn)
        uq, cq = q.chunk(2)
        uk, ck = k.chunk(2)
        uv, cv = v.chunk(2)
        u_attn, c_attn = attention_probs.chunk(2)
        u_h_s_re = self.forward(uq, uk, uv, num_heads, u_attn, attn)
        c_h_s_re = self.forward(cq, ck, cv, num_heads, c_attn, attn)
        return (u_h_s_re, c_h_s_re)

    def __call__(self, q, k, v, num_heads, attention_probs, attn):
        """Run the controller for the current layer and advance the counters.

        Returns the result of ``attn_forward`` when ``cur_att_layer`` lies in
        ``[start_att_layers, end_att_layers)``, otherwise ``None``.
        """
        if self.start_att_layers <= self.cur_att_layer < self.end_att_layers:
            h_s_re = self.attn_forward(q, k, v, num_heads, attention_probs, attn)
        else:
            h_s_re = None
        self.cur_att_layer += 1
        # Wrap after num_att_layers // 2 calls and count that as one step.
        if self.cur_att_layer == self.num_att_layers // 2:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return h_s_re

    def reset(self):
        """Reset the step and layer counters to 0."""
        self.cur_step = 0
        self.cur_att_layer = 0
def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
    """Contrast-adjust an (attention) map along its last dimension.

    Each element's deviation from the mean of its last-dim slice is scaled
    by ``contrast_factor`` and the mean is added back. Values > 1 increase
    contrast; values in (0, 1) reduce it.

    Args:
        tensor: input tensor; the mean is taken over the last dimension.
        contrast_factor: multiplier applied to ``tensor - mean``.

    Returns:
        A new tensor of the same shape as ``tensor``.
    """
    centre = torch.mean(tensor, -1, keepdim=True)
    deviation = tensor - centre
    scaled = deviation * contrast_factor
    return scaled + centre
class AttentionStyle(AttentionControl):
    """Controller that recomputes attention for the third group of a batch.

    In ``forward``, q/k/v are treated as three consecutive groups of
    ``num_heads`` rows along dim 0 (groups 0, 1, 2). Group 2's queries attend
    over group 1's keys/values concatenated (along dim 1) with group 0's.
    The logits against group 0 are multiplied by ``style_guidance`` before
    the softmax. Presumably group 0 is the style reference and group 2 the
    sample being stylized -- confirm against the caller.
    """

    def __init__(self,
                 num_steps,
                 start_ac_layer, end_ac_layer,
                 style_guidance=0.3,
                 mix_q_scale=1.0,
                 de_bug=False,
                 ):
        """
        Args:
            num_steps: stored as ``self.num_steps``; not read by this class.
            start_ac_layer: first layer index (inclusive) on which ``forward`` runs.
            end_ac_layer: layer index (exclusive) at which ``forward`` stops running.
            style_guidance: multiplier on the logits against group 0's keys.
            mix_q_scale: weight of group 2's own queries when blending them
                with group 1's; values >= 1.0 disable blending.
            de_bug: if true, ``forward`` enters ``pdb`` on every call.
        """
        super().__init__()
        self.start_ac_layer = start_ac_layer
        self.end_ac_layer = end_ac_layer
        self.num_steps = num_steps
        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.coef = None  # not read by this class
        self.mix_q_scale = mix_q_scale

    def forward(self, q, k, v, num_heads, attention_probs, attn):
        """Return group 2's attention output over the keys/values of groups 1 and 0.

        Args:
            q, k, v: 3-D tensors (``torch.bmm`` requires it) whose dim 0 holds
                three groups of ``num_heads`` rows; presumably
                ``(3 * num_heads, seq, head_dim)`` -- confirm.
            num_heads: number of rows per group.
            attention_probs: not used by this implementation.
            attn: object whose ``scale`` multiplies the q.k logits.

        Returns:
            ``softmax(logits) @ values`` for group 2's queries.

        The caller's ``q`` is not modified. The query blend is computed out of
        place, because ``q`` may be a ``chunk`` view of a larger batch.
        """
        if self.de_bug:
            import pdb
            pdb.set_trace()
        group_q = q[num_heads * 2:]
        if self.mix_q_scale < 1.0:
            # Blend group 2's queries toward group 1's.
            group_q = (group_q * self.mix_q_scale
                       + (1 - self.mix_q_scale) * q[num_heads * 1:num_heads * 2])
        n = k.shape[1]
        # Keys/values of group 1 first, then group 0, along the token axis.
        re_k = torch.cat([k[num_heads * 1:num_heads * 2], k[num_heads * 0:num_heads * 1]], dim=1)
        v_re = torch.cat([v[num_heads * 1:num_heads * 2], v[num_heads * 0:num_heads * 1]], dim=1)
        re_sim = torch.bmm(group_q, re_k.transpose(-1, -2)) * attn.scale
        # Re-weight the logits against group 0's tokens (the second half).
        re_sim[:, :, n:] = re_sim[:, :, n:] * self.style_guidance
        re_attention_map = re_sim.softmax(-1)
        return torch.bmm(re_attention_map, v_re)
|