# Hosting-page residue captured with the file (not part of the module):
# OpenSound's picture | Upload 211 files | 9d3cb0a verified | raw | history blame | 3.4 kB
import typing as tp
from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Computes a power spectrogram of the waveform, projects it onto a chroma
    filter bank, normalizes every frame across the chroma bins and, when
    ``argmax`` is enabled, replaces each frame by a one-hot vector that marks
    its strongest chroma bin.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int, optional): Size of stft window for the chroma extraction
            (power of 2, e.g. 12 -> 2^12). Only consulted when `winlen` is not given.
        nfft (int, optional): Number of FFT. Falls back to the window length.
        winlen (int, optional): Window length. Falls back to ``2 ** radix2_exp``.
        winhop (int, optional): Window hop size. Falls back to ``winlen // 4``.
        argmax (bool, optional): Whether to use argmax. Defaults to True.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.

    Raises:
        TypeError: If neither `winlen` nor `radix2_exp` is provided, so the
            window length cannot be derived.
    """

    def __init__(self,
                 sample_rate: int,
                 n_chroma: int = 12, radix2_exp: tp.Optional[int] = 12,
                 nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None,
                 winhop: tp.Optional[int] = None, argmax: bool = True,
                 norm: float = torch.inf):
        super().__init__()
        # Fail with an explicit message instead of the opaque `2 ** None` error.
        if not winlen and radix2_exp is None:
            raise TypeError("ChromaExtractor needs either `winlen` or `radix2_exp` "
                            "to determine the window length")
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        # Chroma filter bank; registered as a non-persistent buffer so it follows
        # the module across devices without being written to the state dict.
        fbanks = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                n_chroma=self.n_chroma)
        self.register_buffer('fbanks', torch.from_numpy(fbanks), persistent=False)
        # center=False: the framing padding is applied explicitly in `forward`.
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=False,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract (optionally one-hot quantized) chroma features from a waveform.

        Args:
            wav (torch.Tensor): Waveform with time as the last dimension.
                NOTE(review): presumably shaped (batch, time) or (batch, 1, time),
                given the `squeeze(1)` below — confirm against callers.

        Returns:
            torch.Tensor: Chroma features laid out as (batch, frames, n_chroma).
            When `argmax` is set every frame is a one-hot vector.
        """
        T = wav.shape[-1]
        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less than nfft
        if T < self.nfft:
            pad = self.nfft - T
            left = pad // 2
            # An odd deficit puts the extra zero sample on the right side.
            wav = F.pad(wav, (left, pad - left), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        # Reflect-pad both ends by hand, since the spectrogram runs with center=False.
        edge = int(self.nfft // 2 - self.winhop // 2)
        wav = F.pad(wav, (edge, edge), mode="reflect")

        spec = self.spec(wav).squeeze(1)
        # Project the frequency bins (f) onto the chroma bins (c) for every frame (t).
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        # Normalize each frame across the chroma axis.
        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')

        if self.argmax:
            # Keep only the strongest chroma bin of every frame as a one-hot vector.
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma[:] = 0
            norm_chroma.scatter_(dim=-1, index=idx, value=1)

        return norm_chroma
if __name__ == "__main__":
    # Smoke run: build an extractor whose window spans the whole input
    # and push one batch of random samples through it.
    settings = dict(
        sample_rate=16000,
        n_chroma=4,
        radix2_exp=None,  # ignored here because winlen is given explicitly
        winlen=16000,
        nfft=16000,
        winhop=4000,
    )
    chroma = ChromaExtractor(**settings)
    audio = torch.rand(1, 16000)
    c = chroma(audio)