Spaces:
Runtime error
Runtime error
File size: 7,481 Bytes
149cc2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from typing import Optional
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the context. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.

    If the optional `flash_attn` package is importable, self-attention calls on CUDA
    fp16/bf16 inputs with `dim_head <= 64` go through its `FlashAttention` module;
    every other call uses the plain einsum implementation (`normal_attention`).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        # 1 / sqrt(d_head), shared by the einsum and flash paths.
        self.scale = dim_head**-0.5
        # `heads` and `n_heads` are aliases: the reshape helpers read `heads`,
        # the attention methods read `n_heads`.
        self.heads = heads
        self.n_heads = heads
        self.d_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        # Output projection (index 0) followed by dropout (index 1). Kept as a
        # ModuleList, presumably so state_dict keys (`to_out.0.*`) match existing checkpoints.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # Optional fused kernel. You can install flash attention by cloning
        # https://github.com/HazyResearch/flash-attention and running
        # `python setup.py install`. When it is not installed, `self.flash` is None.
        try:
            from flash_attn.flash_attention import FlashAttention
        except ImportError:
            self.flash = None
        else:
            self.flash = FlashAttention()
            # Set the scale explicitly: `flash_attention` zero-pads the head size,
            # so the scale must reflect the unpadded `dim_head`.
            self.flash.softmax_scale = self.scale

    def reshape_heads_to_batch_dim(self, tensor):
        """Fold heads into the batch axis: `[B, S, H * D]` -> `[B * H, S, D]`, with `H = self.heads`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_batch_dim`: `[B * H, S, D]` -> `[B, S, H * D]`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Attend from `hidden_states` to `encoder_hidden_states`, or to itself when no context is given.

        :param hidden_states: query input of shape `[batch_size, seq, query_dim]`
        :param encoder_hidden_states: context of shape `[batch_size, seq_ctx, cross_attention_dim]`;
            when `None` this is self-attention over `hidden_states`
        :param mask: accepted for interface compatibility but currently ignored
        :return: tensor of shape `[batch_size, seq, query_dim]`
        :raises ValueError: if `hidden_states` is not 3-D
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                f"hidden_states must be 3-D [batch_size, seq, query_dim], got shape {tuple(hidden_states.shape)}"
            )

        is_self = encoder_hidden_states is None
        context = hidden_states if is_self else encoder_hidden_states

        query = self.to_q(hidden_states)
        key = self.to_k(context)
        value = self.to_v(context)

        # The flash path stacks q, k, v into one tensor, so it only supports
        # self-attention (equal sequence lengths). The flash_attn FlashAttention
        # module also requires CUDA fp16/bf16 inputs; everything else falls back
        # to the einsum implementation instead of failing.
        use_flash = (
            self.flash is not None
            and is_self
            and self.d_head <= 64
            and query.is_cuda
            and query.dtype in (torch.float16, torch.bfloat16)
        )
        if use_flash:
            out = self.flash_attention(query, key, value)
        else:
            out = self.normal_attention(query, key, value, is_self)

        out = self.to_out[0](out)  # linear projection back to query_dim
        return self.to_out[1](out)  # dropout

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Flash Attention

        :param q: query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :return: attention output of shape `[batch_size, seq, n_heads * d_head]`
            (the output projection is applied by `forward`)
        :raises ValueError: if `d_head` is larger than 128
        """
        batch_size, seq_len, _ = q.shape

        # Stack to `[batch_size, seq_len, 3, n_heads * d_head]`, then split heads into
        # `[batch_size, seq_len, 3, n_heads, d_head]`. k and v must share q's seq_len.
        qkv = torch.stack((q, k, v), dim=2)
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

        # Flash attention works for head sizes 32, 64 and 128, so zero-pad the
        # head dimension up to the next supported size. Zero padding leaves
        # Q K^T unchanged, and the padded value columns are sliced off below.
        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')

        if pad:
            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)

        # softmax(Q K^T / sqrt(d_head)) V -> `[batch_size, seq_len, n_heads, d_padded]`
        out, _ = self.flash(qkv)
        # Drop the padding and merge heads: `[batch_size, seq_len, n_heads * d_head]`.
        out = out[:, :, :, :self.d_head]
        return out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_self: bool):
        """
        #### Normal Attention

        :param q: query vectors before splitting heads, of shape `[batch_size, seq_q, d_attn]`
        :param k: key vectors before splitting heads, of shape `[batch_size, seq_kv, d_attn]`
        :param v: value vectors before splitting heads, of shape `[batch_size, seq_kv, d_attn]`
        :param is_self: unused; kept for call compatibility
        :return: attention output of shape `[batch_size, seq_q, d_attn]`
            (the output projection is applied by `forward`)
        """
        # Split into heads: `[batch_size, seq, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

        # Scaled scores Q K^T / sqrt(d_head): `[batch_size, n_heads, seq_q, seq_kv]`
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

        # Softmax over keys, written back in place for each half of the batch
        # separately. This equals one softmax over the last dim, but each call only
        # allocates a half-batch temporary (presumably to lower peak memory).
        half = attn.shape[0] // 2
        attn[half:] = attn[half:].softmax(dim=-1)
        attn[:half] = attn[:half].softmax(dim=-1)

        # Weighted sum of values: `[batch_size, seq_q, n_heads, d_head]`
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        # Merge heads: `[batch_size, seq_q, n_heads * d_head]`
        return out.reshape(*out.shape[:2], -1)