# NOTE(review): the lines that previously stood here were Hugging Face Spaces
# file-viewer chrome captured by the scrape ("Spaces: Running on Zero",
# "File size: 3,396 Bytes", commit hash "9d3cb0a", and the line-number gutter
# "1 2 3 ... 80"). They were page furniture, not module source, and would be a
# syntax error in Python; converted to this comment so the file parses.
import typing as tp
from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Computes a normalized power spectrogram of the input waveform, projects it
    onto a librosa chroma (pitch-class) filterbank, normalizes each frame, and
    — when ``argmax`` is set — one-hot quantizes each frame to its dominant
    chroma bin.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction
            (power of 2, e.g. 12 -> 2^12). Only consulted when ``winlen``
            is not provided.
        nfft (int, optional): Number of FFT. Defaults to ``winlen``.
        winlen (int, optional): Window length. Defaults to ``2 ** radix2_exp``.
        winhop (int, optional): Window hop size. Defaults to ``winlen // 4``.
        argmax (bool, optional): Whether to one-hot quantize each frame via
            argmax. Defaults to True.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """
    def __init__(self,
                 sample_rate: int,
                 n_chroma: int = 12, radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None,
                 winhop: tp.Optional[int] = None, argmax: bool = True,
                 norm: float = torch.inf):
        super().__init__()
        # Window defaults: winlen = 2**radix2_exp, nfft = winlen,
        # hop = winlen // 4 (75% overlap).
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        # Chroma filterbank, registered as a non-persistent buffer so it moves
        # with .to(device) but is not stored in the state_dict.
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                       n_chroma=self.n_chroma)), persistent=False)
        # center=False: framing alignment is handled by the manual reflect pad
        # in forward() instead of the transform's own centering.
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=False,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features (one-hot per frame when ``self.argmax``).

        Args:
            wav (torch.Tensor): Waveform with time on the last dimension
                (assumes a batch-first layout such as [B, T] or [B, 1, T] —
                TODO confirm against callers).

        Returns:
            torch.Tensor: Chroma features shaped [batch, frames, n_chroma].
        """
        T = wav.shape[-1]
        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less that nfft
        if T < self.nfft:
            pad = self.nfft - T
            r = 0 if pad % 2 == 0 else 1  # odd total pad: one extra sample on the right
            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
        # Reflect-pad both ends so frames are effectively centered, since the
        # Spectrogram was built with center=False.
        wav = F.pad(wav, (int(self.nfft // 2 - self.winhop // 2 ),
                          int(self.nfft // 2 - self.winhop // 2 )), mode="reflect")
        # squeeze(1) drops a singleton channel dim when present (e.g. [B, 1, F, T]).
        spec = self.spec(wav).squeeze(1)
        # Project frequency bins onto chroma bins: (c, f) x (..., f, t) -> (..., c, t).
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        # Per-frame normalization along the chroma axis (p = self.norm, inf by default).
        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
        if self.argmax:
            # One-hot quantize each frame to its strongest chroma bin (in place).
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma[:] = 0
            norm_chroma.scatter_(dim=-1, index=idx, value=1)
        return norm_chroma
if __name__ == "__main__":
    # Smoke test: run the extractor on one second of random audio.
    # NOTE: radix2_exp=None is only safe because winlen and nfft are passed
    # explicitly, so the `2 ** radix2_exp` fallback is never evaluated.
    # (The stray trailing `|` scrape artifact on the last line was removed —
    # it was a syntax error.)
    chroma = ChromaExtractor(sample_rate=16000,
                             n_chroma=4,
                             radix2_exp=None,
                             winlen=16000,
                             nfft=16000,
                             winhop=4000)
    audio = torch.rand(1, 16000)
    c = chroma(audio)
    print(c.shape)