import torch
import torch.nn.functional as nnf
import abc
import math
from torchvision.utils import save_image


LOW_RESOURCE = False
MAX_NUM_WORDS = 77
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32



class AttentionControl(abc.ABC):
    """Base class for attention hooks: tracks the current denoising step and
    attention layer, and dispatches to `forward` for layers inside
    [start_att_layers, end_att_layers).
    """

    def step_callback(self, x_t):
        """Optional hook for post-processing the latent x_t; identity by default."""
        return x_t

    def between_steps(self):
        """Optional hook run once a full pass over the attention layers has completed."""
        return

    @property
    def start_att_layers(self):
        return self.start_ac_layer  # if LOW_RESOURCE else 0

    @property
    def end_att_layers(self):
        return self.end_ac_layer

    @abc.abstractmethod
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        raise NotImplementedError

    def attn_forward(self, q, k, v, num_heads, attention_probs, attn):
        if q.shape[0] // num_heads == 3:
            # A single batch of three samples: process them together.
            h_s_re = self.forward(q, k, v, num_heads, attention_probs, attn)
        else:
            # The batch holds unconditional and conditional halves
            # (classifier-free guidance); handle them separately.
            uq, cq = q.chunk(2)
            uk, ck = k.chunk(2)
            uv, cv = v.chunk(2)
            u_attn, c_attn = attention_probs.chunk(2)

            u_h_s_re = self.forward(uq, uk, uv, num_heads, u_attn, attn)
            c_h_s_re = self.forward(cq, ck, cv, num_heads, c_attn, attn)
            h_s_re = (u_h_s_re, c_h_s_re)
        return h_s_re
    
    def __call__(self, q, k, v, num_heads, attention_probs, attn):
        # Only act on layers inside the configured range; otherwise return None
        # so the caller keeps the default attention output.
        if self.start_att_layers <= self.cur_att_layer < self.end_att_layers:
            h_s_re = self.attn_forward(q, k, v, num_heads, attention_probs, attn)
        else:
            h_s_re = None

        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers // 2:  # + self.num_uncond_att_layers
            self.cur_att_layer = 0  # self.num_uncond_att_layers
            self.cur_step += 1
            self.between_steps()
        return h_s_re

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # expected to be set externally once the attention layers are patched
        self.cur_att_layer = 0

def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
    """Increase the contrast of an attention map by rescaling it around its per-row mean."""
    mean_feat = tensor.mean(dim=-1, keepdim=True)
    adjusted_tensor = (tensor - mean_feat) * contrast_factor + mean_feat
    return adjusted_tensor
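# Hypothetical usage note for enhance_tensor (it is not used elsewhere in this
# snippet): each row's mean is preserved while the spread around it grows, e.g.
#   enhance_tensor(torch.tensor([[0.2, 0.3, 0.5]]))
#   # -> roughly tensor([[0.111, 0.278, 0.612]]); the row still sums to 1.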

class AttentionStyle(AttentionControl):
    """Controller that recomputes attention for the third sample in the batch:
    its queries attend over the concatenated keys/values of the second and
    first samples, with the similarities to the first sample's keys scaled by
    `style_guidance`.
    """

    def __init__(self,
                 num_steps,
                 start_ac_layer, end_ac_layer,
                 style_guidance=0.3,
                 mix_q_scale=1.0,
                 de_bug=False,
                 ):
        super().__init__()

        self.start_ac_layer = start_ac_layer    # first attention layer the controller acts on
        self.end_ac_layer = end_ac_layer        # end of the active layer range (exclusive)
        self.num_steps = num_steps
        self.de_bug = de_bug                    # drop into pdb inside forward when True
        self.style_guidance = style_guidance    # scale applied to part of the cross-sample similarities (see forward)
        self.coef = None
        self.mix_q_scale = mix_q_scale          # blend factor applied to the third sample's queries (see forward)
        
    def forward(self, q, k, v, num_heads, attention_probs, attn):
        if self.de_bug:
            import pdb; pdb.set_trace()

        # Optionally blend the third sample's queries with the second sample's.
        if self.mix_q_scale < 1.0:
            q[num_heads * 2:] = (q[num_heads * 2:] * self.mix_q_scale
                                 + (1 - self.mix_q_scale) * q[num_heads * 1:num_heads * 2])

        b, n, d = k.shape  # b = 3 * num_heads here
        re_q = q[num_heads * 2:]  # (num_heads, n, d): queries of the third sample
        # Keys/values of the second and first samples, concatenated along the token axis.
        re_k = torch.cat([k[num_heads * 1:num_heads * 2], k[num_heads * 0:num_heads * 1]], dim=1)  # (num_heads, 2n, d)
        v_re = torch.cat([v[num_heads * 1:num_heads * 2], v[num_heads * 0:num_heads * 1]], dim=1)  # (num_heads, 2n, d)

        re_sim = torch.bmm(re_q, re_k.transpose(-1, -2)) * attn.scale
        re_sim[:, :, n:] = re_sim[:, :, n:] * self.style_guidance  # rescale similarities to the first sample's keys
        re_attention_map = re_sim.softmax(-1)
        h_s_re = torch.bmm(re_attention_map, v_re)

        return h_s_re
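

# ---------------------------------------------------------------------------
# Minimal smoke test: a hypothetical invocation, not part of the original
# pipeline. It assumes the controller is called from a patched attention layer
# with head-flattened q/k/v of shape (3 * num_heads, tokens, head_dim), the
# three samples stacked along the batch axis, and that `attn` exposes a
# `scale` attribute like the real attention module does. All sizes below are
# made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    num_heads, tokens, head_dim = 8, 64, 40
    q = torch.randn(3 * num_heads, tokens, head_dim)
    k = torch.randn(3 * num_heads, tokens, head_dim)
    v = torch.randn(3 * num_heads, tokens, head_dim)
    attn = SimpleNamespace(scale=head_dim ** -0.5)  # stand-in for the attention module
    attention_probs = (torch.bmm(q, k.transpose(-1, -2)) * attn.scale).softmax(dim=-1)

    controller = AttentionStyle(
        num_steps=50,          # hypothetical number of denoising steps
        start_ac_layer=0,
        end_ac_layer=16,
        style_guidance=1.2,    # hypothetical value
    )
    controller.num_att_layers = 2  # normally set by the code that registers the controller

    h_s_re = controller(q, k, v, num_heads, attention_probs, attn)
    print(h_s_re.shape)  # torch.Size([8, 64, 40]): re-weighted output for the third sample
    # With a classifier-free-guidance batch of shape (6 * num_heads, ...), the
    # controller would instead return an (unconditional, conditional) tuple.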