File size: 11,385 Bytes
f2cd0c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 |
from typing import Optional,Union
try:
from typing import Literal
except Exception as e:
from typing_extensions import Literal
import numpy as np
import torch
import torchcrepe
from torch import nn
from torch.nn import functional as F
import scipy
# Source: https://github.com/fishaudio/fish-diffusion
def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
):
    """Resample ``content`` along its last axis to ``target_len`` samples.

    Thin wrapper around :func:`torch.nn.functional.interpolate` that also
    accepts NumPy arrays and 1-D / 2-D inputs.

    Args:
        content (torch.Tensor | np.ndarray): 1-D, 2-D or 3-D data. 1-D and 2-D
            inputs get leading axes added so ``interpolate`` sees a 3-D
            ``(N, C, L)`` tensor. NumPy input is converted to a tensor and the
            result is converted back to NumPy.
        target_len (int): Length of the last axis after resampling.
        mode (str, optional): Interpolation mode passed to ``interpolate``.
            Defaults to "nearest".

    Returns:
        torch.Tensor | np.ndarray: Same container type and rank as ``content``,
        with the last axis resized to ``target_len``.
    """
    ndim = content.ndim
    # interpolate() expects (N, C, L); add the missing leading axes.
    if ndim == 1:
        content = content[None, None]
    elif ndim == 2:
        content = content[None]
    assert content.ndim == 3

    is_np = isinstance(content, np.ndarray)
    if is_np:
        content = torch.from_numpy(content)

    results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)

    if is_np:
        results = results.numpy()

    # Remove the axes added above so the output rank matches the input rank.
    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    # 3-D input: previously fell through and returned None.
    return results
class BasePitchExtractor:
    """Common configuration and f0 post-processing for pitch extractors."""

    def __init__(
        self,
        hop_length: int = 512,
        f0_min: float = 50.0,
        f0_max: float = 1100.0,
        keep_zeros: bool = True,
    ):
        """Base pitch extractor.

        Args:
            hop_length (int, optional): Hop length in samples between f0
                frames. Defaults to 512.
            f0_min (float, optional): Minimum f0. Defaults to 50.0.
            f0_max (float, optional): Maximum f0. Defaults to 1100.0.
            keep_zeros (bool, optional): If True, ``post_process`` returns the
                resampled f0 with zero (unvoiced) frames untouched; if False,
                zero frames are filled by linear interpolation and a
                voiced/unvoiced vector is returned as well. Defaults to True.
        """
        self.hop_length = hop_length
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.keep_zeros = keep_zeros

    def __call__(self, x, sampling_rate=44100, pad_to=None):
        raise NotImplementedError("BasePitchExtractor is not callable.")

    def post_process(self, x, sampling_rate, f0, pad_to):
        """Resample an f0 track to ``pad_to`` frames and optionally fill gaps.

        Args:
            x (torch.Tensor): Input audio; only its ``device`` is used here.
            sampling_rate (int): Audio sampling rate, used together with
                ``self.hop_length`` to place frames on a time axis.
            f0 (torch.Tensor | np.ndarray): f0 track. Assumed 1-D (frames);
                the gap filling below indexes along dim 0 — TODO confirm for
                other callers.
            pad_to (int | None): Target number of frames. ``None`` skips all
                post-processing.

        Returns:
            * ``pad_to is None``: ``f0`` as a tensor (not a tuple).
            * ``self.keep_zeros``: the resampled ``f0`` tensor (not a tuple).
            * otherwise a ``(f0, vuv)`` pair. These are float tensors on
              ``x.device`` when at most one frame is voiced. Otherwise they
              are NumPy arrays, with f0 produced by ``np.interp``.
        """
        if isinstance(f0, np.ndarray):
            f0 = torch.from_numpy(f0).float().to(x.device)

        if pad_to is None:
            return f0

        f0 = repeat_expand(f0, pad_to)

        if self.keep_zeros:
            return f0

        # Voiced/unvoiced flags (1.0 where f0 > 0), taken before gaps are filled.
        vuv_vector = (f0 > 0.0).to(f0.dtype)

        # Drop zero-frequency frames and linearly interpolate across them.
        # as_tuple=True always yields a 1-D index, even for a single nonzero frame.
        nzindex = torch.nonzero(f0, as_tuple=True)[0]

        if nzindex.numel() == 0:
            return (
                torch.zeros(pad_to, dtype=torch.float, device=x.device),
                torch.zeros(pad_to, dtype=torch.float, device=x.device),
            )

        voiced_f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()

        if voiced_f0.shape[0] == 1:
            # Nothing to interpolate between: hold the single value everywhere,
            # but flag only the frame that is actually voiced (as in the
            # general path below) instead of marking every frame voiced.
            return (
                torch.ones(pad_to, dtype=torch.float, device=x.device) * voiced_f0[0],
                vuv_vector.to(device=x.device, dtype=torch.float),
            )

        time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
        time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
        # TODO: this interpolation could probably be rewritten in torch.
        interpolated = np.interp(
            time_frame, time_org, voiced_f0, left=voiced_f0[0], right=voiced_f0[-1]
        )
        # vuv_vector was derived from the already-resampled f0, so it has
        # exactly pad_to frames and needs no further resizing.
        return interpolated, vuv_vector.cpu().numpy()
class MaskedAvgPool1d(nn.Module):
    def __init__(
        self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
    ):
        """Mean pooling over the last axis that ignores masked values.

        Args:
            kernel_size (int): The size of the pooling window.
            stride (int, optional): The stride of the pooling window.
                Defaults to None, meaning ``kernel_size``.
            padding (int, optional): Number of positions padded on each side
                of the input; padded positions never count as valid.
                ``None`` is treated as 0. Defaults to 0.
        """
        super(MaskedAvgPool1d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding or 0

    def forward(self, x, mask=None):
        """Average the valid values inside each pooling window.

        Args:
            x (torch.Tensor): Input of shape (batch, width) or
                (batch, channels, width).
            mask (torch.Tensor, optional): Boolean tensor, True where ``x`` is
                valid. Must match ``x``'s shape (a 2-D mask is also accepted
                for 2-D ``x``). Defaults to ``~torch.isnan(x)``.

        Returns:
            torch.Tensor: Pooled values with the same rank as ``x``. Windows
            that contain no valid element are NaN.
        """
        ndim = x.dim()
        if ndim == 2:
            x = x.unsqueeze(1)
            if mask is not None and mask.dim() == 2:
                mask = mask.unsqueeze(1)
        assert (
            x.dim() == 3
        ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"

        # Default mask: every non-NaN element is valid.
        if mask is None:
            mask = ~torch.isnan(x)
        assert x.shape == mask.shape, "Input tensor and mask must have the same shape"

        channels = x.size(1)
        # Zero out masked elements (including NaNs) so they add nothing to the sum.
        masked_x = torch.where(mask, x, torch.zeros_like(x))
        # Depthwise all-ones kernel (groups=channels): conv1d then yields a
        # per-channel window sum. Match x's dtype so float64/half inputs work.
        ones_kernel = torch.ones(
            channels, 1, self.kernel_size, dtype=x.dtype, device=x.device
        )
        sum_pooled = nn.functional.conv1d(
            masked_x,
            ones_kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        # The same convolution over the 0/1 mask counts valid elements per
        # window; zero padding makes padded positions count as invalid.
        valid_count = nn.functional.conv1d(
            mask.to(x.dtype),
            ones_kernel,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
        )
        avg_pooled = sum_pooled / valid_count.clamp(min=1)  # avoid 0 / 0
        # Only windows without any valid element are undefined; a genuine
        # average of 0 is kept as 0.
        avg_pooled[valid_count == 0] = float("nan")

        if ndim == 2:
            return avg_pooled.squeeze(1)
        return avg_pooled
class MaskedMedianPool1d(nn.Module):
    def __init__(
        self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
    ):
        """Median pooling over the last axis that ignores masked values.

        Inspired by the median pooling implementation in
        https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598

        Args:
            kernel_size (int): The size of the median pooling window.
            stride (int, optional): The stride of the median pooling window.
                Defaults to None, meaning ``kernel_size``.
            padding (int, optional): Number of positions padded on each side
                of the input; padded positions never count as valid.
                ``None`` is treated as 0. Defaults to 0.
        """
        super(MaskedMedianPool1d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding or 0

    def forward(self, x, mask=None):
        """Take the median of the valid values inside each pooling window.

        Args:
            x (torch.Tensor): Input of shape (batch, width) or
                (batch, channels, width).
            mask (torch.Tensor, optional): Boolean tensor, True where ``x`` is
                valid. Must match ``x``'s shape (a 2-D mask is also accepted
                for 2-D ``x``). Defaults to ``~torch.isnan(x)``.

        Returns:
            torch.Tensor: Pooled values with the same rank as ``x``. For an
            even number of valid elements the lower median is returned;
            windows with no valid element are NaN.
        """
        ndim = x.dim()
        if ndim == 2:
            x = x.unsqueeze(1)
            if mask is not None and mask.dim() == 2:
                mask = mask.unsqueeze(1)
        assert (
            x.dim() == 3
        ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"

        # Default mask: every non-NaN element is valid.
        if mask is None:
            mask = ~torch.isnan(x)
        assert x.shape == mask.shape, "Input tensor and mask must have the same shape"

        masked_x = torch.where(mask, x, torch.zeros_like(x))
        # Padded positions are marked invalid below, so their values never
        # reach the median. Constant padding is therefore equivalent to
        # "reflect", and unlike "reflect" it also works when the input is not
        # longer than the padding (e.g. a single frame).
        x = F.pad(masked_x, (self.padding, self.padding), mode="constant", value=0)
        mask = F.pad(
            mask.float(), (self.padding, self.padding), mode="constant", value=0
        )

        # Sliding windows: (batch, channels, num_windows, kernel_size).
        x = x.unfold(2, self.kernel_size, self.stride)
        mask = mask.unfold(2, self.kernel_size, self.stride)

        # Push masked entries to +inf so they sort after every valid value.
        x_masked = torch.where(mask.bool(), x, torch.full_like(x, float("inf")))
        x_sorted, _ = torch.sort(x_masked, dim=-1)

        valid_count = mask.sum(dim=-1)
        # Index of the (lower) median among the valid, front-sorted entries.
        median_idx = torch.div(valid_count - 1, 2, rounding_mode="trunc").clamp(min=0)
        median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
        # Windows without any valid element picked a +inf filler; report NaN.
        median_pooled[valid_count == 0] = float("nan")

        if ndim == 2:
            return median_pooled.squeeze(1)
        return median_pooled
class CrepePitchExtractor(BasePitchExtractor):
    """Pitch extractor backed by ``torchcrepe``."""

    def __init__(
        self,
        hop_length: int = 512,
        f0_min: float = 50.0,
        f0_max: float = 1100.0,
        threshold: float = 0.05,
        keep_zeros: bool = False,
        device=None,
        model: Literal["full", "tiny"] = "full",
        use_fast_filters: bool = True,
        decoder="viterbi",
    ):
        """Configure the crepe-based extractor.

        Args:
            hop_length (int, optional): Hop length in samples. Defaults to 512.
            f0_min (float, optional): Minimum f0. Defaults to 50.0.
            f0_max (float, optional): Maximum f0. Defaults to 1100.0.
            threshold (float, optional): Periodicity threshold passed to
                ``torchcrepe.threshold.At``. Defaults to 0.05.
            keep_zeros (bool, optional): See ``BasePitchExtractor``.
                Defaults to False.
            device (optional): Device to run on. ``None`` selects CUDA when
                available, otherwise CPU.
            model (str, optional): torchcrepe model, "full" or "tiny".
                Defaults to "full".
            use_fast_filters (bool, optional): Use the local masked
                median/mean pooling modules instead of ``torchcrepe.filter``.
                Defaults to True.
            decoder (str, optional): "viterbi", "argmax" or "weighted_argmax".
                Defaults to "viterbi".

        Raises:
            ValueError: If ``decoder`` is not one of the supported names.
        """
        super().__init__(hop_length, f0_min, f0_max, keep_zeros)
        if decoder == "viterbi":
            self.decoder = torchcrepe.decode.viterbi
        elif decoder == "argmax":
            self.decoder = torchcrepe.decode.argmax
        elif decoder == "weighted_argmax":
            self.decoder = torchcrepe.decode.weighted_argmax
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError(f"Unknown decoder: {decoder!r}")
        self.threshold = threshold
        self.model = model
        self.use_fast_filters = use_fast_filters
        if device is None:
            self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.dev = torch.device(device)
        if self.use_fast_filters:
            self.median_filter = MaskedMedianPool1d(3, 1, 1).to(self.dev)
            self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(self.dev)

    def __call__(self, x, sampling_rate=44100, pad_to=None):
        """Extract pitch using crepe.

        Args:
            x (torch.Tensor): Audio signal, shape (1, T).
            sampling_rate (int, optional): Sampling rate. Defaults to 44100.
            pad_to (int, optional): Number of output frames. Defaults to None.

        Returns:
            If every frame ends up unvoiced: ``(rtn, rtn)``, the same NumPy
            array twice. It is the all-zero f0 track, or ``np.zeros(pad_to)``
            when ``pad_to`` is given. Otherwise the result of
            ``self.post_process`` applied to the frame-level f0 track.
        """
        assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor."
        assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels."
        x = x.to(self.dev)
        f0, pd = torchcrepe.predict(
            x,
            sampling_rate,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            pad=True,
            model=self.model,
            batch_size=1024,
            device=x.device,
            return_periodicity=True,
            decoder=self.decoder,
        )

        # Post-processing as recommended in the torchcrepe README: smooth the
        # periodicity, zero it in silent regions, then unvoice frames whose
        # periodicity is below the threshold.
        if self.use_fast_filters:
            pd = self.median_filter(pd)
        else:
            pd = torchcrepe.filter.median(pd, 3)
        # Silence detection must use the same hop as predict() so the loudness
        # frames line up with the periodicity frames (was hard-coded to 512).
        pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length)
        f0 = torchcrepe.threshold.At(self.threshold)(f0, pd)

        if self.use_fast_filters:
            f0 = self.mean_filter(f0)
        else:
            f0 = torchcrepe.filter.mean(f0, 3)

        # Replace NaN frames with 0 and drop the leading batch axis.
        f0 = torch.where(torch.isnan(f0), torch.zeros_like(f0), f0)[0]

        if torch.all(f0 == 0):
            rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to)
            return rtn, rtn

        return self.post_process(x, sampling_rate, f0, pad_to)
|