jiaqili3's picture
init
addb7e5
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock
class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Depending on ``encoder_channels`` the block runs in one of two modes:

    * ``encoder_channels is None``: a single self-attention layer.
    * ``encoder_channels`` given: a cross-attention layer over the encoder
      output, optionally preceded by a self-attention layer.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. Pass ``-1`` to
            derive it from ``num_heads`` instead.
        num_heads: Number of attention heads. Only used when
            ``num_head_channels`` is ``-1``; otherwise the head count is
            ``effective_channels // num_head_channels`` and this value is ignored.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to use self-attention before cross-attention,
            only applicable if encoder_channels is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: The dimension of the height, would be applied if ``dims`` is 2.
        encoder_hdim: The dimension of the height of the encoder output, would be
            applied if ``dims`` is 2.
        p_dropout: Dropout probability.

    Raises:
        ValueError: If ``dims`` is neither 1 nor 2, or if ``num_head_channels``
            is ``-1`` and ``num_heads`` is not a positive integer.
        AssertionError: If the effective channel count is not divisible by the
            selected head size / head count.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        # Validate ``dims`` once, up front; everything below can rely on it
        # being 1 or 2.
        if dims not in (1, 2):
            raise ValueError(f"invalid number of dimensions: {dims}")

        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        else:
            # We consider the channel as product of channel and height, i.e. C x H
            # This is because we want to apply attention on the audio signal, which is 1D
            self.channels = channels * h_dim

        if num_head_channels == -1:
            # Head size is derived from the head count, so the count must be usable.
            if num_heads <= 0:
                raise ValueError(
                    f"num_heads must be positive when num_head_channels is -1, got {num_heads}"
                )
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        # Always record the flag so the attribute exists in both modes; it is
        # only consulted by ``forward`` in cross-attention mode.
        self.use_self_attention = use_self_attention

        if encoder_channels is not None:
            # Same "fold height into channels" convention as for the input.
            if dims == 1:
                self.encoder_channels = encoder_channels
            else:
                self.encoder_channels = encoder_channels * encoder_hdim

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...],
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]

        Raises:
            AssertionError: If ``encoder_output`` is passed to a self-attention
                block, or omitted for a cross-attention block.
        """
        shape = x.size()
        # Flatten to [B, effective_channels, length]; for ``dims == 2`` this
        # merges the channel and height axes.
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        # Restore the caller's original layout.
        return h.reshape(*shape).contiguous()
class BasicAttentionBlock(nn.Module):
    r"""Multi-head attention over 1D feature sequences, followed by a residual
    projection stack.

    When ``context_channels`` is ``None`` the block performs self-attention
    (queries, keys and values all come from the input). Otherwise queries come
    from the input and keys/values from a separate context tensor.

    Args:
        channels: Number of channels of the input (and output).
        num_head_channels: Number of channels per attention head. Pass ``-1`` to
            derive it from ``num_heads`` instead.
        num_heads: Number of attention heads. Only used when
            ``num_head_channels`` is ``-1``.
        context_channels: Number of channels of the context tensor for
            cross-attention. If ``None``, self-attention is performed.
        p_dropout: Dropout probability, used for the attention weights (training
            mode only) and inside the output projection stack.

    Raises:
        ValueError: If ``num_head_channels`` is ``-1`` and ``num_heads`` is not
            a positive integer.
        AssertionError: If ``channels`` is not divisible by the selected head
            size / head count.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            # Head size is derived from the head count, so the count must be usable.
            if num_heads <= 0:
                raise ValueError(
                    f"num_heads must be positive when num_head_channels is -1, got {num_heads}"
                )
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            # Cross-attention: separate projections for queries and keys/values.
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            # Self-attention: one fused projection producing q, k and v.
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Point-wise (kernel size 1) like every other projection in this block,
        # so the sequence length is preserved for the residual connection.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            # NOTE(review): presumably zero-initialised so this stack starts as
            # a no-op on top of the residual -- confirm against ``zero_module``.
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, channels, S]`` into per-head ``[B, heads, S, head_channels]``."""
        batch = t.size(0)
        return (
            t.reshape(batch, self.num_heads, self.num_head_channels, -1)
            .transpose(-1, -2)
            .contiguous()
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T], if ``None``,
                then self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]

        Raises:
            AssertionError: If the block was built for cross-attention and
                ``kv`` is not given.
        """
        N, C, L = q.size()
        # Keep the untouched input for the residual connection; ``q`` itself is
        # not rebound to the per-head query tensor below.
        residual = q

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."
            query = self.to_q(q)
            # Split along the channel axis so the batch axis stays intact.
            key, value = self.to_kv(kv).chunk(2, dim=1)
        else:
            query, key, value = self.to_qkv(q).chunk(3, dim=1)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # regardless of module mode, so it must be disabled explicitly in eval.
        attn_dropout = self.p_dropout if self.training else 0.0
        h = F.scaled_dot_product_attention(query, key, value, dropout_p=attn_dropout)

        # [B, heads, L, head_channels] -> [B, heads, head_channels, L] -> [B, C, L]
        h = h.transpose(-1, -2).reshape(N, C, L).contiguous()
        h = self.linear(h)

        x = residual + h
        h = self.proj_out(x)

        return x + h