# Page residue from the web file viewer this copy was taken from (not code):
# author StevenChen16, commit 31ba7c5 ("first commit"), file size 55.5 kB.
# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
# License can be found in LICENSES/LICENSE_ADP.txt
import math
from inspect import isfunction
from math import ceil, floor, log, pi, log2
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from packaging import version
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
from torch.backends.cuda import sdp_kernel
from torch.nn import functional as F
from dac.nn.layers import Snake1d
import pdb
"""
Utils
"""
class ConditionedSequential(nn.Module):
    """Run modules in order, passing the same ``mapping`` to every one of them.

    Unlike ``nn.Sequential``, each child is called as ``module(x, mapping)``.
    Modules may be given individually (``ConditionedSequential(a, b)``) or as a
    single iterable (``ConditionedSequential([a, b])``).
    """

    def __init__(self, *modules):
        super().__init__()
        # Previously ``nn.ModuleList(*modules)``. nn.ModuleList takes one iterable,
        # so passing several modules (or one plain module) raised a TypeError.
        # A single iterable or container argument is still flattened as before.
        if len(modules) == 1 and (
            not isinstance(modules[0], nn.Module)
            or isinstance(modules[0], (nn.ModuleList, nn.Sequential))
        ):
            modules = tuple(modules[0])
        self.module_list = nn.ModuleList(modules)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
        """Feed ``x`` through every module in order with the shared ``mapping``."""
        for module in self.module_list:
            x = module(x, mapping)
        return x
T = TypeVar("T")


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` unless it is None; otherwise return ``d``.

    ``d`` is called first only when it is a plain Python function
    (``inspect.isfunction``). Classes and builtins are returned uncalled.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def exists(val: Optional[T]) -> bool:
    """Return True when ``val`` is not None (falsy values such as 0 count as existing)."""
    # Annotation fixed: this returns a bool, not a ``T``.
    return val is not None
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to ``x``; ties go to the smaller power.

    ``x`` must be positive, because ``log2`` is undefined otherwise. For
    ``x < 1`` the result is a float such as ``0.25``.
    """
    lower_exp = floor(log2(x))
    upper_exp = ceil(log2(x))
    lower_gap = abs(x - 2 ** lower_exp)
    upper_gap = abs(x - 2 ** upper_exp)
    # Prefer the lower exponent unless the upper one is strictly closer.
    best_exp = upper_exp if upper_gap < lower_gap else lower_exp
    return 2 ** int(best_exp)
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split ``d`` into ``(keys starting with prefix, all other keys)``.

    Keys are copied unchanged, and each output keeps the iteration order of ``d``.
    """
    matching: Dict = {}
    others: Dict = {}
    for key, value in d.items():
        target = matching if key.startswith(prefix) else others
        target[key] = value
    return matching, others


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split ``d`` by key prefix, stripping the prefix unless ``keep_prefix``.

    Returns ``(prefixed, rest)``. For example, with prefix ``"stft_"`` the key
    ``"stft_num_fft"`` becomes ``"num_fft"`` when the prefix is stripped.
    """
    with_prefix, rest = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        cut = len(prefix)
        with_prefix = {key[cut:]: value for key, value in with_prefix.items()}
    return with_prefix, rest
"""
Convolutional Blocks
"""
import typing as tp
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
# License available in LICENSES/LICENSE_META.txt
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples to append to ``x`` so the last conv window is full.

    See `pad_for_conv1d` for why this is needed.
    """
    length = x.shape[-1]
    # Number of windows needed to cover the whole (padded) input, rounded up.
    frames = math.ceil((length - kernel_size + padding_total) / stride + 1)
    target_length = (frames - 1) * stride + kernel_size - padding_total
    return target_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad ``x`` with zeros so that a convolution's last window is complete.

    Without this, a matching transposed convolution cannot rebuild an output of
    the original length. For example, with total padding = 4, kernel size = 4,
    stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
            1   2   3       # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    return F.pad(x, (0, get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad the last dimension of ``x`` with ``F.pad``, allowing reflect mode on short inputs.

    Reflect padding needs the input to be longer than the pad width. When it is
    not, zeros are appended on the right first, the reflection is applied, and
    the temporary zeros are trimmed off the end of the result.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    widest = max(left, right)
    filler = widest - length + 1 if length <= widest else 0
    if filler:
        x = F.pad(x, (0, filler))
    out = F.pad(x, paddings, mode, value)
    return out[..., :out.shape[-1] - filler]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``paddings = (left, right)`` samples from the last dimension of ``x``.

    Slicing with an explicit stop index, rather than ``-right``, keeps
    ``right == 0`` correct.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert (left + right) <= x.shape[-1]
    stop = x.shape[-1] - right
    return x[..., left:stop]
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that pads its own input on the last dimension before convolving.

    The fixed padding is the dilated kernel span minus the stride. It goes
    entirely on the left when ``causal=True``. Otherwise it is split between
    both sides, with the odd sample on the left. Extra right padding from
    `get_extra_padding_for_conv1d` is added so the last window is full.
    """

    def forward(self, x: Tensor, causal=False) -> Tensor:
        stride = self.stride[0]
        # Span covered by one dilated kernel application.
        span = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        padding_total = span - stride
        extra = get_extra_padding_for_conv1d(x, span, stride, padding_total)
        if causal:
            # All fixed padding on the left, so no output looks ahead.
            left, right = padding_total, 0
        else:
            # Asymmetric split for odd totals: the extra sample goes on the left.
            right = padding_total // 2
            left = padding_total - right
        x = pad1d(x, (left, right + extra))
        return super().forward(x)
class ConvTranspose1d(nn.ConvTranspose1d):
    """``nn.ConvTranspose1d`` that trims ``kernel_size - stride`` samples of fixed padding.

    With ``causal=True`` the whole amount is trimmed from the right. Otherwise
    it is split between both sides, with the odd sample taken from the left.
    """

    def forward(self, x: Tensor, causal=False) -> Tensor:
        padding_total = self.kernel_size[0] - self.stride[0]
        y = super().forward(x)
        # Only the fixed padding is trimmed here. Extra padding from `pad_for_conv1d`
        # must be removed later, once the target length is known, because this
        # layer never receives the length from the matching encoder layer.
        if causal:
            trim = (0, padding_total)
        else:
            right = padding_total // 2
            trim = (padding_total - right, right)
        return unpad1d(y, trim)
def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Build a strided `Conv1d` that shortens the sequence by ``factor``.

    The kernel length is ``factor * kernel_multiplier + 1``. Requiring an even
    ``kernel_multiplier`` makes that length odd for any ``factor``.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
    kernel_size = factor * kernel_multiplier + 1
    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
    )
def Upsample1d(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Build a layer that lengthens the sequence by ``factor``.

    - ``factor == 1``: a plain kernel-3 `Conv1d` (no resampling).
    - ``use_nearest``: nearest-neighbour upsampling followed by a kernel-3 `Conv1d`.
    - Otherwise: a `ConvTranspose1d` with kernel ``2 * factor`` and stride ``factor``.
    """
    if factor == 1:
        return Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
    if not use_nearest:
        return ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
        )
    resize = nn.Upsample(scale_factor=factor, mode="nearest")
    smooth = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
    return nn.Sequential(resize, smooth)
class ConvBlock1d(nn.Module):
    """Pre-activation conv unit: GroupNorm -> optional scale/shift -> activation -> `Conv1d`.

    The activation is ``Snake1d`` when ``use_snake`` is set, otherwise SiLU.
    GroupNorm is replaced by identity when ``use_norm`` is False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
        use_snake: bool = False
    ) -> None:
        super().__init__()
        if use_norm:
            self.groupnorm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        else:
            self.groupnorm = nn.Identity()
        self.activation = Snake1d(in_channels) if use_snake else nn.SiLU()
        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
    ) -> Tensor:
        h = self.groupnorm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            # Modulate the normalised features; a zero scale leaves them unchanged.
            h = h * (scale + 1) + shift
        return self.project(self.activation(h), causal=causal)
class MappingToScaleShift(nn.Module):
    """Project a conditioning vector to per-channel ``(scale, shift)`` tensors.

    ``mapping`` must be 2-D ``(batch, features)``, as the ``"b c -> b c 1"``
    rearrange enforces. Each output is ``(batch, channels, 1)`` so that it
    broadcasts over the last (length) axis of the features it modulates.
    """

    def __init__(
        self,
        features: int,
        channels: int,
    ):
        super().__init__()
        self.to_scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=features, out_features=channels * 2),
        )

    def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
        projected = rearrange(self.to_scale_shift(mapping), "b c -> b c 1")
        scale, shift = projected.chunk(2, dim=1)
        return scale, shift
class ResnetBlock1d(nn.Module):
    """Two `ConvBlock1d`s with a residual connection.

    When ``context_mapping_features`` is set, a ``mapping`` tensor is required at
    call time. It is turned into a scale/shift applied inside the second block.
    The residual path is a 1x1 `Conv1d` if the channel count changes, else identity.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        use_norm: bool = True,
        use_snake: bool = False,
        num_groups: int = 8,
        context_mapping_features: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.use_mapping = exists(context_mapping_features)
        # NOTE(review): only block1 receives `stride`/`dilation`, and the residual path
        # is never strided. A stride other than 1 would therefore make the two
        # summed tensors differ in length.
        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )
        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )
        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )
        if in_channels == out_channels:
            self.to_out = nn.Identity()
        else:
            self.to_out = Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        # A mapping must be given exactly when the block was built to use one.
        assert self.use_mapping == exists(mapping), "context mapping required if context_mapping_features > 0"
        h = self.block1(x, causal=causal)
        modulation = self.to_scale_shift(mapping) if self.use_mapping else None
        h = self.block2(h, scale_shift=modulation, causal=causal)
        return h + self.to_out(x)
class Patcher(nn.Module):
    """Entry layer: a `ResnetBlock1d`, then fold time patches into the channel axis.

    The resnet block maps ``in_channels`` to ``out_channels // patch_size``.
    The rearrange then moves each run of ``patch_size`` consecutive time steps
    into the channel axis. The result has ``out_channels`` channels and a
    sequence ``patch_size`` times shorter; the length must divide evenly.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert out_channels % patch_size == 0, f"out_channels must be divisible by patch_size ({patch_size})"
        self.patch_size = patch_size
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=out_channels // patch_size,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        h = self.block(x, mapping, causal=causal)
        return rearrange(h, "b c (l p) -> b (c p) l", p=self.patch_size)
class Unpatcher(nn.Module):
    """Exit layer: unfold channel patches back into time, then a `ResnetBlock1d`.

    This is the inverse rearrangement of `Patcher`. ``in_channels`` is split
    into ``in_channels // patch_size`` channels and a sequence ``patch_size``
    times longer, which the resnet block maps to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False
    ):
        super().__init__()
        assert in_channels % patch_size == 0, f"in_channels must be divisible by patch_size ({patch_size})"
        self.patch_size = patch_size
        self.block = ResnetBlock1d(
            in_channels=in_channels // patch_size,
            out_channels=out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        unfolded = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
        return self.block(unfolded, mapping, causal=causal)
"""
Attention Components
"""
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Two-layer MLP: ``Linear(features -> features * multiplier) -> GELU -> Linear(back)``."""
    hidden = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden),
        nn.GELU(),
        nn.Linear(in_features=hidden, out_features=features),
    ]
    return nn.Sequential(*layers)
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
    """Fill ``sim`` with the most negative finite value of its dtype where ``mask`` is False.

    ``mask`` uses True = keep. A 3-D ``(b, n, m)`` mask gets a singleton head
    axis. A 2-D ``(n, m)`` mask is also tiled over the batch size of ``sim``.
    Masks of any other rank are used as-is.
    """
    if mask.ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    elif mask.ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=sim.shape[0])
    fill_value = -torch.finfo(sim.dtype).max
    return sim.masked_fill(~mask, fill_value)
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
    """Build a boolean ``(batch, i, j)`` mask that is True where a query may attend a key.

    Here ``i = q.shape[-2]`` and ``j = k.shape[-2]``. The mask is aligned to the
    bottom-right corner, so query row ``r`` may see key columns ``0 .. r + (j - i)``.
    """
    batch = q.shape[0]
    n_q, n_k = q.shape[-2], k.shape[-2]
    future = torch.ones((n_q, n_k), dtype=torch.bool, device=q.device).triu(n_k - n_q + 1)
    return repeat(~future, "n m -> b n m", b=batch)
class AttentionBase(nn.Module):
    """Multi-head scaled dot-product attention over already-projected q, k, v.

    Inputs are ``(batch, tokens, num_heads * head_features)``. Heads are split
    internally, merged back afterwards, and projected by ``to_out``.

    When CUDA is available and PyTorch >= 2.0, ``F.scaled_dot_product_attention``
    is used:
    - only the flash kernel on compute capability 8.0 devices (e.g. A100);
    - the math / memory-efficient kernels on other devices.
    Otherwise an explicit einsum + softmax implementation is used.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)
        self.to_out = nn.Linear(
            in_features=mid_features, out_features=out_features
        )
        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
        if not self.use_flash:
            return
        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
    ) -> Tensor:
        """Attend ``q`` over ``k`` / ``v``.

        Args:
            q: ``(b, n, h*d)`` queries.
            k, v: ``(b, m, h*d)`` keys and values.
            mask: optional boolean mask, True = may attend. Accepted shapes are
                ``(n, m)``, ``(b, n, m)``, or a 4-D mask broadcastable to
                ``(b, h, n, m)``.
            is_causal: restrict queries to earlier keys. When a ``mask`` is also
                given, the two are combined.

        Returns:
            ``(b, n, out_features)`` tensor.
        """
        # Split heads: (b, n, h*d) -> (b, h, n, d)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        if exists(mask):
            # Give 2-D / 3-D masks explicit head (and batch) axes so they broadcast
            # against (b, h, n, m) scores. Without this, F.scaled_dot_product_attention
            # would align a (b, n, m) mask's batch axis with the head axis.
            if mask.ndim == 2:
                mask = rearrange(mask, "n m -> 1 1 n m")
            elif mask.ndim == 3:
                mask = rearrange(mask, "b n m -> b 1 n m")
            if is_causal:
                # Previously the einsum path evaluated `not mask`, which raises for
                # multi-element tensors, and F.scaled_dot_product_attention rejects
                # attn_mask together with is_causal=True. Fold causality into the mask.
                mask = mask & rearrange(causal_mask(q, k), "b n m -> b 1 n m")
                is_causal = False

        if not self.use_flash:
            if is_causal:
                # Only reached when no explicit mask was given (see above).
                mask = causal_mask(q, k)
            # Compute similarity matrix and add eventual mask
            sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
            sim = add_mask(sim, mask) if exists(mask) else sim
            # Softmax in float32 for stability, then cast back to v's dtype so the
            # einsum below does not mix dtypes (e.g. float32 weights with half v).
            attn = sim.softmax(dim=-1, dtype=torch.float32).to(v.dtype)
            out = einsum("... n m, ... m d -> ... n d", attn, v)
        else:
            # SDPA's default scale is 1/sqrt(d), which equals self.scale.
            with sdp_kernel(*self.sdp_kernel_config):
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class Attention(nn.Module):
    """Pre-norm attention with its own q / kv projections (self- or cross-attention).

    ``x`` supplies the queries. Keys and values come from ``context`` when it is
    given, otherwise from ``x``. Both inputs are LayerNorm-ed before projection.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        causal: bool = False,
    ):
        super().__init__()
        # Stored as given (possibly None) so forward knows whether a context is mandatory.
        self.context_features = context_features
        self.causal = causal
        inner = head_features * num_heads
        kv_features = default(context_features, features)
        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_features)
        self.to_q = nn.Linear(in_features=features, out_features=inner, bias=False)
        self.to_kv = nn.Linear(in_features=kv_features, out_features=inner * 2, bias=False)
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(
        self,
        x: Tensor,  # [b, n, c]
        context: Optional[Tensor] = None,  # [b, m, d]
        context_mask: Optional[Tensor] = None,  # [b, m], false is masked,
        causal: Optional[bool] = False,
    ) -> Tensor:
        assert not self.context_features or exists(context), "You must provide a context when using context_features"
        source = default(context, x)
        x = self.norm(x)
        source = self.norm_context(source)
        q = self.to_q(x)
        k, v = torch.chunk(self.to_kv(source), chunks=2, dim=-1)
        if exists(context_mask):
            # NOTE(review): masked positions get zeroed keys/values instead of being
            # excluded from the softmax. A zero key scores 0, so these positions
            # still receive attention weight. Confirm this approximation is intended.
            keep = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
            k, v = k * keep, v * keep
        return self.attention(q, k, v, is_causal=self.causal or causal)
# NOTE(review): this is a second definition of `FeedForward`, building the same
# layers as the one in the "Attention Components" section. Because it comes later
# in the module it rebinds the name, so this is the version `TransformerBlock`
# resolves when it is constructed. Consider deleting one of the two copies.
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Two-layer MLP: ``Linear(features -> features * multiplier) -> GELU -> Linear(back)``."""
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )
"""
Transformer Blocks
"""
class TransformerBlock(nn.Module):
    """Residual layer: self-attention, optional cross-attention, then a feed-forward MLP.

    Cross-attention is built only when ``context_features`` is a positive int.
    Every sub-layer is applied as ``sublayer(x) + x``. The attention modules
    normalise their own inputs; the feed-forward MLP does not.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_cross_attention = context_features is not None and context_features > 0
        self.attention = Attention(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
        )
        if self.use_cross_attention:
            self.cross_attention = Attention(
                features=features,
                num_heads=num_heads,
                head_features=head_features,
                context_features=context_features,
            )
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
        h = self.attention(x, causal=causal) + x
        if self.use_cross_attention:
            # Causality applies only to self-attention; cross-attention sees the full context.
            h = self.cross_attention(h, context=context, context_mask=context_mask) + h
        return self.feed_forward(h) + h
"""
Transformers
"""
class Transformer1d(nn.Module):
    """Stack of `TransformerBlock`s applied to ``(batch, channels, time)`` feature maps.

    Input path: GroupNorm (32 groups, so ``channels`` must be divisible by 32),
    then a 1x1 `Conv1d`, then a transpose to ``(batch, time, channels)`` for the
    blocks. Output path: transpose back, then another 1x1 `Conv1d`. No residual
    connection is added around the whole stack.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()
        self.to_in = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
            Conv1d(in_channels=channels, out_channels=channels, kernel_size=1),
            Rearrange("b c t -> b t c"),
        )
        layers = []
        for _ in range(num_layers):
            layers.append(
                TransformerBlock(
                    features=channels,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    context_features=context_features,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(in_channels=channels, out_channels=channels, kernel_size=1),
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
        h = self.to_in(x)
        for layer in self.blocks:
            h = layer(h, context=context, context_mask=context_mask, causal=causal)
        return self.to_out(h)
"""
Time Embeddings
"""
class SinusoidalEmbedding(nn.Module):
    """Fixed sin/cos embedding of a 1-D tensor of positions or timesteps.

    The output has ``2 * (dim // 2)`` features: sines of all frequencies, then
    cosines. Frequencies are geometrically spaced from 1 down to 1/10000.
    ``dim`` must be at least 4, since ``dim // 2 - 1`` is used as a divisor.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        half = self.dim // 2
        step = torch.tensor(log(10000) / (half - 1), device=x.device)
        freqs = torch.exp(torch.arange(half, device=x.device) * -step)
        angles = rearrange(x, "i -> i 1") * rearrange(freqs, "j -> 1 j")
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LearnedPositionalEmbedding(nn.Module):
    """Random-Fourier embedding of continuous time with learnable frequencies.

    For a 1-D input of length ``b``, the output is ``(b, dim + 1)``: the raw
    value, then ``dim // 2`` sines, then ``dim // 2`` cosines. ``dim`` must be even.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        t = rearrange(x, "b -> b 1")
        freqs = t * rearrange(self.weights, "d -> 1 d") * 2 * pi
        return torch.cat((t, freqs.sin(), freqs.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a `LearnedPositionalEmbedding` (``dim + 1`` features) followed by a Linear to ``out_features``."""
    fourier = LearnedPositionalEmbedding(dim)
    project = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, project)
"""
Encoder/Decoder Components
"""
class DownsampleBlock1d(nn.Module):
    """Encoder stage: strided downsampling plus a stack of `ResnetBlock1d`s.

    Optional features:
    - concatenate extra context channels before the first resnet block;
    - run a `Transformer1d` after the resnet blocks;
    - collect intermediate activations as skips;
    - emit an extra ``extract_channels`` projection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_groups: int,
        num_layers: int,
        kernel_multiplier: int = 2,
        use_pre_downsample: bool = True,
        use_skip: bool = False,
        use_snake: bool = False,
        extract_channels: int = 0,
        context_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_pre_downsample = use_pre_downsample
        self.use_skip = use_skip
        self.use_transformer = num_transformer_blocks > 0
        self.use_extract = extract_channels > 0
        self.use_context = context_channels > 0

        # Width of the resnet stack: the downsampled width when downsampling runs
        # first, otherwise the input width.
        width = out_channels if use_pre_downsample else in_channels

        self.downsample = Downsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            kernel_multiplier=kernel_multiplier,
        )

        stack = []
        for idx in range(num_layers):
            # Only the first block sees the concatenated context channels.
            block_in = width + context_channels if idx == 0 else width
            stack.append(
                ResnetBlock1d(
                    in_channels=block_in,
                    out_channels=width,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
            )
        self.blocks = nn.ModuleList(stack)

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )
            heads, head_dim = attention_heads, attention_features
            # Derive whichever of heads / per-head features was not given.
            if head_dim is None and heads is not None:
                head_dim = width // heads
            if heads is None and head_dim is not None:
                heads = width // head_dim
            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=width,
                num_heads=heads,
                head_features=head_dim,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        channels: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
        """Run the stage.

        Returns ``(x, extracted)`` when extraction is enabled; otherwise
        ``(x, skips)`` when skips are enabled; otherwise ``x``.
        """
        # NOTE(review): `causal` is not forwarded to self.downsample or to_extracted,
        # even though `Conv1d.forward` accepts it. Confirm this is intended.
        if self.use_pre_downsample:
            x = self.downsample(x)
        if self.use_context and exists(channels):
            x = torch.cat([x, channels], dim=1)
        skips = []
        for block in self.blocks:
            x = block(x, mapping=mapping, causal=causal)
            if self.use_skip:
                skips.append(x)
        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
            if self.use_skip:
                skips.append(x)
        if not self.use_pre_downsample:
            x = self.downsample(x)
        if self.use_extract:
            return x, self.to_extracted(x)
        return (x, skips) if self.use_skip else x
class UpsampleBlock1d(nn.Module):
    """Decoder stage: a skip-fed `ResnetBlock1d` stack plus upsampling.

    Before each resnet block, one skip tensor is popped from the end of
    ``skips`` and concatenated on the channel axis. It is scaled by ``2 ** -0.5``
    when ``use_skip_scale`` is set. This is why every block expects
    ``channels + skip_channels`` inputs. When ``skips`` is None nothing is
    concatenated, which only lines up if ``skip_channels == 0``. Optionally runs
    a `Transformer1d` and emits an extra ``extract_channels`` projection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_layers: int,
        num_groups: int,
        use_nearest: bool = False,
        use_pre_upsample: bool = False,
        use_skip: bool = False,
        use_snake: bool = False,
        skip_channels: int = 0,
        use_skip_scale: bool = False,
        extract_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_extract = extract_channels > 0
        self.use_pre_upsample = use_pre_upsample
        self.use_transformer = num_transformer_blocks > 0
        self.use_skip = use_skip
        self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0

        # Width of the resnet stack: the upsampled width when upsampling runs
        # first, otherwise the input width.
        width = out_channels if use_pre_upsample else in_channels

        self.blocks = nn.ModuleList(
            ResnetBlock1d(
                in_channels=width + skip_channels,
                out_channels=width,
                num_groups=num_groups,
                context_mapping_features=context_mapping_features,
                use_snake=use_snake,
            )
            for _ in range(num_layers)
        )

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )
            heads, head_dim = attention_heads, attention_features
            # Derive whichever of heads / per-head features was not given.
            if head_dim is None and heads is not None:
                head_dim = width // heads
            if heads is None and head_dim is not None:
                heads = width // head_dim
            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=width,
                num_heads=heads,
                head_features=head_dim,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.upsample = Upsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            use_nearest=use_nearest,
        )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
        """Concatenate a (scaled) skip tensor onto ``x`` along the channel axis."""
        scaled = skip * self.skip_scale
        return torch.cat([x, scaled], dim=1)

    def forward(
        self,
        x: Tensor,
        *,
        skips: Optional[List[Tensor]] = None,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Run the stage, consuming ``skips`` from the end.

        The caller's ``skips`` list is mutated: one entry is popped per block.
        Returns ``(x, extracted)`` when extraction is enabled, else ``x``.
        """
        if self.use_pre_upsample:
            x = self.upsample(x)
        for block in self.blocks:
            if exists(skips):
                x = self.add_skip(x, skip=skips.pop())
            x = block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
        if not self.use_pre_upsample:
            x = self.upsample(x)
        if self.use_extract:
            return x, self.to_extracted(x)
        return x
class BottleneckBlock1d(nn.Module):
    """Innermost UNet stage at constant width: resnet -> optional transformer -> resnet."""

    def __init__(
        self,
        channels: int,
        *,
        num_groups: int,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        self.use_transformer = num_transformer_blocks > 0

        def make_resnet() -> nn.Module:
            # Both resnet blocks share an identical, width-preserving configuration.
            return ResnetBlock1d(
                in_channels=channels,
                out_channels=channels,
                num_groups=num_groups,
                context_mapping_features=context_mapping_features,
                use_snake=use_snake,
            )

        self.pre_block = make_resnet()

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )
            heads, head_dim = attention_heads, attention_features
            # Derive whichever of heads / per-head features was not given.
            if head_dim is None and heads is not None:
                head_dim = channels // heads
            if heads is None and head_dim is not None:
                heads = channels // head_dim
            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=heads,
                head_features=head_dim,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.post_block = make_resnet()

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Tensor:
        h = self.pre_block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            h = self.transformer(h, context=embedding, context_mask=embedding_mask, causal=causal)
        return self.post_block(h, mapping=mapping, causal=causal)
"""
UNet
"""
class UNet1d(nn.Module):
def __init__(
self,
in_channels: int,
channels: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
attentions: Sequence[int],
patch_size: int = 1,
resnet_groups: int = 8,
use_context_time: bool = True,
kernel_multiplier_downsample: int = 2,
use_nearest_upsample: bool = False,
use_skip_scale: bool = True,
use_snake: bool = False,
use_stft: bool = False,
use_stft_context: bool = False,
out_channels: Optional[int] = None,
context_features: Optional[int] = None,
context_features_multiplier: int = 4,
context_channels: Optional[Sequence[int]] = None,
context_embedding_features: Optional[int] = None,
**kwargs,
):
super().__init__()
out_channels = default(out_channels, in_channels)
context_channels = list(default(context_channels, []))
num_layers = len(multipliers) - 1
use_context_features = exists(context_features)
use_context_channels = len(context_channels) > 0
context_mapping_features = None
attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
self.num_layers = num_layers
self.use_context_time = use_context_time
self.use_context_features = use_context_features
self.use_context_channels = use_context_channels
self.use_stft = use_stft
self.use_stft_context = use_stft_context
self.context_features = context_features
context_channels_pad_length = num_layers + 1 - len(context_channels)
context_channels = context_channels + [0] * context_channels_pad_length
self.context_channels = context_channels
self.context_embedding_features = context_embedding_features
if use_context_channels:
has_context = [c > 0 for c in context_channels]
self.has_context = has_context
self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
assert (
len(factors) == num_layers
and len(attentions) >= num_layers
and len(num_blocks) == num_layers
)
if use_context_time or use_context_features:
context_mapping_features = channels * context_features_multiplier
self.to_mapping = nn.Sequential(
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
)
if use_context_time:
assert exists(context_mapping_features)
self.to_time = nn.Sequential(
TimePositionalEmbedding(
dim=channels, out_features=context_mapping_features
),
nn.GELU(),
)
if use_context_features:
assert exists(context_features) and exists(context_mapping_features)
self.to_features = nn.Sequential(
nn.Linear(
in_features=context_features, out_features=context_mapping_features
),
nn.GELU(),
)
if use_stft:
stft_kwargs, kwargs = groupby("stft_", kwargs)
assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
in_channels *= stft_channels
out_channels *= stft_channels
context_channels[0] *= stft_channels if use_stft_context else 1
assert exists(in_channels) and exists(out_channels)
self.stft = STFT(**stft_kwargs)
assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
self.to_in = Patcher(
in_channels=in_channels + context_channels[0],
out_channels=channels * multipliers[0],
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
self.downsamples = nn.ModuleList(
[
DownsampleBlock1d(
in_channels=channels * multipliers[i],
out_channels=channels * multipliers[i + 1],
context_mapping_features=context_mapping_features,
context_channels=context_channels[i + 1],
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i],
factor=factors[i],
kernel_multiplier=kernel_multiplier_downsample,
num_groups=resnet_groups,
use_pre_downsample=True,
use_skip=True,
use_snake=use_snake,
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in range(num_layers)
]
)
self.bottleneck = BottleneckBlock1d(
channels=channels * multipliers[-1],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_groups=resnet_groups,
num_transformer_blocks=attentions[-1],
use_snake=use_snake,
**attention_kwargs,
)
self.upsamples = nn.ModuleList(
[
UpsampleBlock1d(
in_channels=channels * multipliers[i + 1],
out_channels=channels * multipliers[i],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
factor=factors[i],
use_nearest=use_nearest_upsample,
num_groups=resnet_groups,
use_skip_scale=use_skip_scale,
use_pre_upsample=False,
use_skip=True,
use_snake=use_snake,
skip_channels=channels * multipliers[i + 1],
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in reversed(range(num_layers))
]
)
self.to_out = Unpatcher(
in_channels=channels * multipliers[0],
out_channels=out_channels,
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def get_channels(
self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Optional[Tensor]:
"""Gets context channels at `layer` and checks that shape is correct"""
use_context_channels = self.use_context_channels and self.has_context[layer]
if not use_context_channels:
return None
assert exists(channels_list), "Missing context"
# Get channels index (skipping zero channel contexts)
channels_id = self.channels_ids[layer]
# Get channels
channels = channels_list[channels_id]
message = f"Missing context for layer {layer} at index {channels_id}"
assert exists(channels), message
# Check channels
num_channels = self.context_channels[layer]
message = f"Expected context with {num_channels} channels at idx {channels_id}"
assert channels.shape[1] == num_channels, message
# STFT channels if requested
channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
return channels
def get_mapping(
    self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
    """Build the conditioning mapping from time and/or feature inputs.

    Each enabled source is projected (``self.to_time`` / ``self.to_features``),
    the projections are stacked and summed over the source axis, and the sum
    is passed through ``self.to_mapping``. Returns None when neither
    ``use_context_time`` nor ``use_context_features`` is enabled.
    """
    if not (self.use_context_time or self.use_context_features):
        return None
    projections = []
    if self.use_context_time:
        assert exists(time), "use_context_time=True but no time features provided"
        projections.append(self.to_time(time))
    if self.use_context_features:
        assert exists(features), "context_features exists but no features provided"
        projections.append(self.to_features(features))
    # Sum the per-source projections element-wise
    summed = reduce(torch.stack(projections), "n b m -> b m", "sum")
    return self.to_mapping(summed)
def forward(
    self,
    x: Tensor,
    time: Optional[Tensor] = None,
    *,
    features: Optional[Tensor] = None,
    channels_list: Optional[Sequence[Tensor]] = None,
    embedding: Optional[Tensor] = None,
    embedding_mask: Optional[Tensor] = None,
    causal: Optional[bool] = False,
) -> Tensor:
    """Run the 1D UNet.

    Pipeline: optional STFT encode -> concat layer-0 context channels on dim 1
    -> ``to_in`` -> downsample blocks (each may receive its own context
    channels and contributes skips) -> bottleneck -> upsample blocks (consuming
    skips in reverse order) -> add the ``to_in`` output as a long skip ->
    ``to_out`` -> optional STFT decode.

    Args:
        x: Input signal.
        time: Optional time input forwarded to ``get_mapping``.
        features: Optional feature input forwarded to ``get_mapping``.
        channels_list: Optional per-layer context channels (see ``get_channels``).
        embedding: Optional cross-attention embedding passed to every block.
        embedding_mask: Optional mask for ``embedding``.
        causal: Forwarded to every block.
    """
    channels = self.get_channels(channels_list, layer=0)
    # Apply stft if required
    x = self.stft.encode1d(x) if self.use_stft else x  # type: ignore
    # Concat context channels at layer 0 if provided
    if channels is not None:
        x = torch.cat([x, channels], dim=1)
    # Compute mapping from time and features
    mapping = self.get_mapping(time, features)
    x = self.to_in(x, mapping, causal=causal)
    # The first entry is the to_in output, consumed as the long skip at the end
    skips_list = [x]
    for i, downsample in enumerate(self.downsamples):
        channels = self.get_channels(channels_list, layer=i + 1)
        x, skips = downsample(
            x, mapping=mapping, channels=channels, embedding=embedding,
            embedding_mask=embedding_mask, causal=causal,
        )
        skips_list.append(skips)
    x = self.bottleneck(
        x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal
    )
    for upsample in self.upsamples:
        skips = skips_list.pop()
        x = upsample(
            x, skips=skips, mapping=mapping, embedding=embedding,
            embedding_mask=embedding_mask, causal=causal,
        )
    x += skips_list.pop()
    x = self.to_out(x, mapping, causal=causal)
    x = self.stft.decode1d(x) if self.use_stft else x
    return x
""" Conditioning Modules """
class FixedEmbedding(nn.Module):
    """Learned position embedding table broadcast over the batch.

    ``forward`` only reads the batch size, sequence length and device of its
    input; the input values are ignored. The result has shape
    ``(batch, length, features)``.
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, length = x.shape[:2]
        assert length <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        table = self.embedding(positions)  # (length, features)
        # Broadcast the same table to every batch element (expanded, not copied)
        return table.unsqueeze(0).expand(batch_size, -1, -1)
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Return a bool tensor of ``shape`` whose entries are True with probability ``proba``.

    ``proba`` of exactly 1 or 0 short-circuits to an all-True / all-False
    tensor without drawing random numbers.
    """
    if proba == 1 or proba == 0:
        constant = torch.ones if proba == 1 else torch.zeros
        return constant(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    return torch.bernoulli(probabilities).to(torch.bool)
class UNetCFG1d(UNet1d):
    """UNet1d with Classifier-Free Guidance

    A learned ``FixedEmbedding`` plays the role of the "unconditional"
    cross-attention embedding: it can randomly replace the real embedding
    (``embedding_mask_proba``) and, when ``embedding_scale != 1.0``, provides
    the unconditional pass that guidance extrapolates away from.
    """

    def __init__(
        self,
        context_embedding_max_length: int,
        context_embedding_features: int,
        use_xattn_time: bool = False,
        **kwargs,
    ):
        """
        Args:
            context_embedding_max_length: Maximum number of embedding tokens
                (one extra slot is reserved when ``use_xattn_time`` is True).
            context_embedding_features: Feature size of embedding tokens.
            use_xattn_time: Append a time-derived token to the embedding.
                Requires ``kwargs["channels"]``.
            **kwargs: Forwarded to ``UNet1d``.
        """
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )
        self.use_xattn_time = use_xattn_time
        if use_xattn_time:
            assert exists(context_embedding_features)
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
                ),
                nn.GELU(),
            )
            context_embedding_max_length += 1  # Add one for time embedding
        self.fixed_embedding = FixedEmbedding(
            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        embedding: Tensor,
        embedding_mask: Optional[Tensor] = None,
        embedding_scale: float = 1.0,
        embedding_mask_proba: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        scale_phi: float = 0.4,
        negative_embedding: Optional[Tensor] = None,
        negative_embedding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Run the UNet, optionally with classifier-free guidance.

        Args:
            embedding: Cross-attention conditioning tokens (dim 0 is batch).
            embedding_mask: Optional mask over ``embedding`` tokens.
            embedding_scale: Guidance scale. At 1.0 the UNet runs once with
                ``embedding``. Otherwise the conditional output ``out`` and the
                unconditional output ``out_masked`` are combined as
                ``out_masked + (out - out_masked) * embedding_scale``.
            embedding_mask_proba: Per-sample probability of replacing the whole
                embedding with the fixed embedding.
            batch_cfg: Compute both guidance passes in one doubled batch.
            rescale_cfg: Blend the guided output with a copy rescaled to the
                conditional output's std over dim 1, weighted by ``scale_phi``.
            negative_embedding: Only used when ``batch_cfg`` is True; replaces
                the fixed embedding in the unconditional pass.
            negative_embedding_mask: Positions where it is False fall back to
                the fixed embedding.
            **kwargs: Forwarded to ``UNet1d.forward``.
        """
        b, device = embedding.shape[0], embedding.device

        if self.use_xattn_time:
            embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
            if embedding_mask is not None:
                # Build the extra "keep" entry in the mask's own dtype so the
                # concatenation neither fails nor promotes a bool mask to float.
                time_token_mask = torch.ones((b, 1), device=device, dtype=embedding_mask.dtype)
                embedding_mask = torch.cat([embedding_mask, time_token_mask], dim=1)
            # NOTE(review): negative_embedding is not extended with the time
            # token, so with batch_cfg its length differs from fixed_embedding's
            # — confirm callers account for this.

        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            # Randomly mask embedding (whole sample at a time)
            dropout_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(dropout_mask, fixed_embedding, embedding)

        if embedding_scale == 1.0:
            return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)

        if batch_cfg:
            batch_x = torch.cat([x, x], dim=0)
            batch_time = torch.cat([time, time], dim=0)

            if negative_embedding is not None:
                if negative_embedding_mask is not None:
                    negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
                    negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
                batch_embed = torch.cat([embedding, negative_embedding], dim=0)
            else:
                batch_embed = torch.cat([embedding, fixed_embedding], dim=0)

            batch_embedding_mask = None
            if embedding_mask is not None:
                batch_embedding_mask = torch.cat([embedding_mask, embedding_mask], dim=0)

            batch_features = None
            features = kwargs.pop("features", None)
            if self.use_context_features:
                batch_features = torch.cat([features, features], dim=0)

            batch_channels = None
            channels_list = kwargs.pop("channels_list", None)
            if self.use_context_channels:
                batch_channels = []
                for channels in channels_list:
                    batch_channels += [torch.cat([channels, channels], dim=0)]

            # Compute both normal and fixed embedding outputs in one pass
            batch_out = super().forward(
                batch_x,
                batch_time,
                embedding=batch_embed,
                embedding_mask=batch_embedding_mask,
                features=batch_features,
                channels_list=batch_channels,
                **kwargs,
            )
            out, out_masked = batch_out.chunk(2, dim=0)
        else:
            # Compute both normal and fixed embedding outputs
            out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
            out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)

        out_cfg = out_masked + (out - out_masked) * embedding_scale

        if not rescale_cfg:
            return out_cfg

        out_std = out.std(dim=1, keepdim=True)
        out_cfg_std = out_cfg.std(dim=1, keepdim=True)
        return scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg
class UNetNCCA1d(UNet1d):
    """UNet1d with Noise Channel Conditioning Augmentation

    Each context-channel tensor is blended with Gaussian noise by a per-sample
    scale; the scales themselves are embedded (``NumberEmbedder``), summed over
    the context items and passed to ``UNet1d`` as ``features``.
    """

    def __init__(self, context_features: int, **kwargs):
        super().__init__(context_features=context_features, **kwargs)
        self.embedder = NumberEmbedder(features=context_features)

    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
        """Convert ``x`` to a tensor (if needed) and broadcast it to ``shape``."""
        x = x if torch.is_tensor(x) else torch.tensor(x)
        return x.expand(shape)

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: Union[
            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
        ] = False,
        channels_scale: Union[
            float, Sequence[float], Sequence[Sequence[float]], Tensor
        ] = 0,
        **kwargs,
    ) -> Tensor:
        """Augment context channels with noise, then run ``UNet1d``.

        Args:
            channels_list: Context tensors, one per conditioned layer. The
                caller's sequence is not modified.
            channels_augmentation: Whether to augment each (sample, item);
                broadcast to ``(batch, len(channels_list))``.
            channels_scale: Noise blend factor per (sample, item); broadcast to
                ``(batch, len(channels_list))``.
            **kwargs: Forwarded to ``UNet1d.forward``.
        """
        b, n = x.shape[0], len(channels_list)
        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

        # Augmentation (for each channel list item). A new list is built instead
        # of writing into channels_list, so the caller's list is left untouched
        # and tuples are accepted.
        augmented_channels = []
        for i, item in enumerate(channels_list):
            scale = channels_scale[:, i] * channels_augmentation[:, i]
            # assumes each item is 3-D (batch, channels, length) — TODO confirm
            scale = rearrange(scale, "b -> b 1 1")
            augmented_channels.append(
                torch.randn_like(item) * scale + item * (1 - scale)  # type: ignore # noqa
            )

        # Scale embedding (sum reduction if more than one channel list item)
        channels_scale_emb = self.embedder(channels_scale)
        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=augmented_channels,
            features=channels_scale_emb,
            **kwargs,
        )
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
    """UNet1d combining classifier-free guidance and noise channel conditioning augmentation.

    Construction follows the cooperative ``super()`` chain of both parents;
    ``forward`` is dispatched explicitly to ``UNetCFG1d.forward``.
    """

    def __init__(self, *init_args, **init_kwargs):
        super().__init__(*init_args, **init_kwargs)

    def forward(self, *call_args, **call_kwargs):  # type: ignore
        cfg_forward = UNetCFG1d.forward
        return cfg_forward(self, *call_args, **call_kwargs)
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
    """Construct a UNet1d variant selected by name.

    Args:
        type: One of "base", "all", "cfg" or "ncca".
        **kwargs: Forwarded to the selected class's constructor.

    Raises:
        ValueError: if ``type`` names no known variant.
    """
    variants = (
        ("base", UNet1d),
        ("all", UNetAll1d),
        ("cfg", UNetCFG1d),
        ("ncca", UNetNCCA1d),
    )
    for variant_name, variant_cls in variants:
        if type == variant_name:
            return variant_cls(**kwargs)
    raise ValueError(f"Unknown XUNet1d type: {type}")
class NumberEmbedder(nn.Module):
    """Embed numbers of any shape with a ``TimePositionalEmbedding``.

    The output keeps the input's shape and appends a trailing axis of size
    ``features``.
    """

    def __init__(self, features: int, dim: int = 256):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            # Place the new tensor on the same device as the embedding weights
            param_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=param_device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        # Embed every number independently, then restore the original layout
        embedded = self.embedding(x.flatten())
        return embedded.view(*original_shape, self.features)  # type: ignore
"""
Audio Transforms
"""
class STFT(nn.Module):
    """Helper for torch stft and istft

    ``encode``/``decode`` convert between waveforms ``(b, c, t)`` and a pair of
    spectrogram tensors ``(b, c, f, l)``. The pair holds real/imaginary parts
    when ``use_complex`` is True, otherwise magnitude/phase.
    ``encode1d``/``decode1d`` fold channels and frequency into one axis and
    stack the pair along dim 1.
    """

    def __init__(
        self,
        num_fft: int = 1023,
        hop_length: int = 256,
        window_length: Optional[int] = None,
        length: Optional[int] = None,
        use_complex: bool = False,
    ):
        """
        Args:
            num_fft: FFT size.
            hop_length: Hop between frames (``num_fft // 4`` if None).
            window_length: Hann window length (``num_fft`` if None).
            length: Fixed output length for ``decode``; if None, the length is
                the power of two closest to ``frames * hop_length``.
            use_complex: Represent spectra as real/imag instead of mag/phase.
        """
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = default(hop_length, floor(num_fft // 4))
        self.window_length = default(window_length, num_fft)
        self.length = length
        self.register_buffer("window", torch.hann_window(self.window_length))
        self.use_complex = use_complex

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """STFT of ``wave`` (b, c, t) into a pair of (b, c, f, l) tensors."""
        b = wave.shape[0]
        wave = rearrange(wave, "b c t -> (b c) t")
        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
            normalized=True,
        )
        if self.use_complex:
            # Returns real and imaginary
            stft_a, stft_b = stft.real, stft.imag
        else:
            # Returns magnitude and phase matrices
            magnitude, phase = torch.abs(stft), torch.angle(stft)
            stft_a, stft_b = magnitude, phase
        return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)

    def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
        """Inverse of ``encode``: a pair of (b, c, f, l) tensors back to (b, c, t)."""
        b, l = stft_a.shape[0], stft_a.shape[-1]  # noqa
        length = closest_power_2(l * self.hop_length)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
        if self.use_complex:
            real, imag = stft_a, stft_b
        else:
            magnitude, phase = stft_a, stft_b
            real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
        # torch.istft requires a complex input (real (..., 2) inputs were
        # removed in PyTorch 2.0), so build the complex spectrogram directly.
        stft = torch.complex(real, imag)
        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=default(self.length, length),
            normalized=True,
        )
        return rearrange(wave, "(b c) t -> b c t", b=b)

    def encode1d(
        self, wave: Tensor, stacked: bool = True
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Encode to (b, c*f, l) tensors; concatenated on dim 1 if ``stacked``."""
        stft_a, stft_b = self.encode(wave)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
        return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)

    def decode1d(self, stft_pair: Tensor) -> Tensor:
        """Inverse of ``encode1d(stacked=True)``."""
        # Number of frequency bins produced by torch.stft for num_fft
        f = self.num_fft // 2 + 1
        stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
        return self.decode(stft_a, stft_b)