import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock


class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Adapted from the Stability AI `generative-models repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head.
        num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, self-attention is performed.
        use_self_attention: Whether to apply self-attention before cross-attention.
            Only used when ``encoder_channels`` is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: Height of the input; only used when ``dims`` is 2.
        encoder_hdim: Height of the encoder output; only used when ``dims`` is 2.
        p_dropout: Dropout probability.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        elif dims == 2:
            self.channels = channels * h_dim
        else:
            raise ValueError(f"invalid number of dimensions: {dims}")

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is not None:
            self.use_self_attention = use_self_attention

            if dims == 1:
                self.encoder_channels = encoder_channels
            elif dims == 2:
                self.encoder_channels = encoder_channels * encoder_hdim
            else:
                raise ValueError(f"invalid number of dimensions: {dims}")

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...];
                if ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        return h.reshape(*shape).contiguous()
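
# Usage sketch for AttentionBlock (hedged; the channel counts, head sizes, and
# tensor shapes below are illustrative assumptions, not values used elsewhere
# in this codebase):
#
#   # Self-attention over a 1-D feature map [B, channels, T]
#   block = AttentionBlock(channels=256, num_head_channels=32)
#   y = block(torch.randn(1, 256, 100))                    # -> [1, 256, 100]
#
#   # Cross-attention against an encoder output [B, encoder_channels, T_enc]
#   block = AttentionBlock(channels=256, encoder_channels=512,
#                          use_self_attention=True)
#   y = block(torch.randn(1, 256, 100), torch.randn(1, 512, 60))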


class BasicAttentionBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Pointwise (kernel size 1) projection applied to the attention output,
        # matching the other Conv1d calls in this module.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T];
                if ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()

        # Keep the unprojected input for the residual connection below;
        # ``q`` is overwritten by its per-head projection in both branches.
        residual = q

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            q = (
                self.to_q(q)
                .reshape(self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .contiguous()
            )
            kv = (
                self.to_kv(kv)
                .reshape(2, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(2)
            )
            k, v = (
                kv[0].squeeze(0).contiguous(),
                kv[1].squeeze(0).contiguous(),
            )
        else:
            qkv = (
                self.to_qkv(q)
                .reshape(3, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(3)
            )
            q, k, v = (
                qkv[0].squeeze(0).contiguous(),
                qkv[1].squeeze(0).contiguous(),
                qkv[2].squeeze(0).contiguous(),
            )

        h = F.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout).transpose(
            -1, -2
        )
        h = h.reshape(N, -1, L).contiguous()
        h = self.linear(h)

        # Residual connection with the original input; ``q`` now holds the
        # reshaped per-head query, so it cannot be used for the skip here.
        x = residual + h
        h = self.proj_out(x)

        return x + h
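
# Shape walkthrough for BasicAttentionBlock (hedged sketch; the concrete sizes
# below are illustrative assumptions, and a batch size of 1 is used because the
# per-head reshapes fold the batch dimension into the sequence axis):
#
#   attn = BasicAttentionBlock(channels=256, num_head_channels=32,
#                              context_channels=512)
#   q = torch.randn(1, 256, 100)    # [N, channels, L]
#   kv = torch.randn(1, 512, 60)    # [N, context_channels, T]
#   # to_q(q):   [1, 256, 100] -> per-head query  [8, 100, 32]
#   # to_kv(kv): [1, 512, 60]  -> per-head k and v, each [8, 60, 32]
#   # scaled_dot_product_attention -> [8, 100, 32], merged back to [1, 256, 100]
#   out = attn(q, kv)               # -> [1, 256, 100]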