import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
def MaskedAvgPool1d(x, kernel_size):
    """Average-pool a [batch, time] signal with reflect padding, ignoring NaN entries."""
    x = x.unsqueeze(1)
    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
    mask = ~torch.isnan(x)
    masked_x = torch.where(mask, x, torch.zeros_like(x))
    ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)

    # Perform sum pooling
    sum_pooled = F.conv1d(
        masked_x,
        ones_kernel,
        stride=1,
        padding=0,
        groups=x.size(1),
    )

    # Count the non-masked (valid) elements in each pooling window
    valid_count = F.conv1d(
        mask.float(),
        ones_kernel,
        stride=1,
        padding=0,
        groups=x.size(1),
    )
    valid_count = valid_count.clamp(min=1)  # Avoid division by zero

    # Perform masked average pooling
    avg_pooled = sum_pooled / valid_count

    return avg_pooled.squeeze(1)
def MedianPool1d(x, kernel_size):
    """Median-pool a [batch, time] signal with reflect padding."""
    x = x.unsqueeze(1)
    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
    x = x.squeeze(1)
    x = x.unfold(1, kernel_size, 1)
    x, _ = torch.sort(x, dim=-1)
    return x[:, :, (kernel_size - 1) // 2]
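# Example (illustrative, not part of the original file): both helpers operate on
# [batch, time] tensors; MaskedAvgPool1d skips NaN entries (e.g. unvoiced frames
# of an f0 curve), while MedianPool1d simply takes the window median.
#   f0 = torch.tensor([[100.0, 101.0, float("nan"), 103.0, 104.0]])
#   f0_avg = MaskedAvgPool1d(f0, kernel_size=3)  # NaNs excluded from each window
#   f0_med = MedianPool1d(f0, kernel_size=3)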
def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
    """Calculate final size for efficient FFT.
    Args:
        frame_size: Size of the audio frame.
        ir_size: Size of the convolving impulse response.
        power_of_2: Constrain to be a power of 2. If False, the exact convolved
            frame size is used. TPU requires a power of 2, while GPU is more
            flexible.
    Returns:
        fft_size: Size for efficient FFT.
    """
    convolved_frame_size = ir_size + frame_size - 1
    if power_of_2:
        # Next power of 2.
        fft_size = int(2 ** np.ceil(np.log2(convolved_frame_size)))
    else:
        fft_size = convolved_frame_size
    return fft_size
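# Worked example (illustrative): get_fft_size(frame_size=320, ir_size=128) gives
# 320 + 128 - 1 = 447, rounded up to the next power of two -> 512; with
# power_of_2=False it returns 447 directly.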
def upsample(signal, factor):
    """Linearly interpolate a frame-level signal [batch, n_frames, channels] to [batch, n_frames * factor, channels]."""
    signal = signal.permute(0, 2, 1)
    signal = nn.functional.interpolate(
        torch.cat((signal, signal[:, :, -1:]), 2),
        size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
    signal = signal[:, :, :-1]
    return signal.permute(0, 2, 1)
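# Example (illustrative): stretch a frame-level control signal to sample rate.
#   ctrl = torch.randn(1, 100, 8)   # [batch, n_frames, channels]
#   ctrl_up = upsample(ctrl, 160)   # -> torch.Size([1, 16000, 8])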
def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
    """Suppress harmonic amplitudes whose harmonic frequency (k * pitch) exceeds fmax."""
    n_harm = amplitudes.shape[-1]
    pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
    aa = (pitches < fmax).float() + 1e-7
    return amplitudes * aa
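# Example (illustrative): suppress harmonics above the Nyquist frequency.
#   amplitudes: [batch, n_frames, n_harmonics], f0 in Hz: [batch, n_frames, 1]
#   amplitudes = remove_above_fmax(amplitudes, f0, fmax=sample_rate / 2)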
def crop_and_compensate_delay(audio, audio_size, ir_size,
                              padding='same',
                              delay_compensation=-1):
    """Crop audio output from convolution to compensate for group delay.
    Args:
        audio: Audio after convolution. Tensor of shape [batch, time_steps].
        audio_size: Initial size of the audio before convolution.
        ir_size: Size of the convolving impulse response.
        padding: Either 'valid' or 'same'. For 'same' the final output is the
            same size as the input audio (audio_timesteps). For 'valid' the
            audio is extended to include the tail of the impulse response
            (audio_timesteps + ir_timesteps - 1).
        delay_compensation: Samples to crop from the start of the output audio
            to compensate for the group delay of the impulse response. If
            delay_compensation < 0 it defaults to automatically calculating the
            constant group delay of the windowed linear-phase filter from
            frequency_impulse_response().
    Returns:
        Tensor of cropped and shifted audio.
    Raises:
        ValueError: If padding is not either 'valid' or 'same'.
    """
    # Crop the output.
    if padding == 'valid':
        crop_size = ir_size + audio_size - 1
    elif padding == 'same':
        crop_size = audio_size
    else:
        raise ValueError('Padding must be \'valid\' or \'same\', instead '
                         'of {}.'.format(padding))

    # Compensate for the group delay of the filter by trimming the front.
    # For an impulse response produced by frequency_impulse_response(),
    # the group delay is constant because the filter is linear phase.
    total_size = int(audio.shape[-1])
    crop = total_size - crop_size
    start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
    end = crop - start
    return audio[:, start:-end]
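# Note (illustrative): with padding='same' the function keeps exactly audio_size
# samples, dropping ir_size // 2 samples from the front (the constant group delay
# of the linear-phase IR) and the remainder from the tail, e.g. for ir_size=128
# the first 64 samples are trimmed.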
def fft_convolve(audio,
                 impulse_response):  # B, n_frames, 2*(n_mags-1)
    """Filter audio with frames of time-varying impulse responses.
    Time-varying filter. Given audio [batch, n_samples] and a series of impulse
    responses [batch, n_frames, n_impulse_response], splits the audio into
    frames, applies filters, and then overlap-adds the audio back together.
    Applies a non-windowed, non-overlapping STFT/ISTFT to efficiently compute
    convolution for large impulse response sizes.
    Args:
        audio: Input audio. Tensor of shape [batch, audio_timesteps].
        impulse_response: Finite impulse response to convolve. Can either be a
            2-D Tensor of shape [batch, ir_size], or a 3-D Tensor of shape
            [batch, ir_frames, ir_size]. A 2-D tensor applies a single linear
            time-invariant filter to the audio. A 3-D Tensor applies a linear
            time-varying filter. Automatically chops the audio into equally
            shaped blocks to match ir_frames.
    Returns:
        audio_out: Convolved audio. Tensor of shape [batch, audio_timesteps].
    """
    # Add a frame dimension to impulse response if it doesn't have one.
    ir_shape = impulse_response.size()
    if len(ir_shape) == 2:
        impulse_response = impulse_response.unsqueeze(1)
        ir_shape = impulse_response.size()

    # Get shapes of audio and impulse response.
    batch_size_ir, n_ir_frames, ir_size = ir_shape
    batch_size, audio_size = audio.size()  # B, T

    # Validate that batch sizes match.
    if batch_size != batch_size_ir:
        raise ValueError('Batch size of audio ({}) and impulse response ({}) '
                         'must be the same.'.format(batch_size, batch_size_ir))

    # Cut audio into 50% overlapped frames (center padding).
    hop_size = int(audio_size / n_ir_frames)
    frame_size = 2 * hop_size
    audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)

    # Apply Bartlett (triangular) window
    window = torch.bartlett_window(frame_size).to(audio_frames)
    audio_frames = audio_frames * window

    # Pad and FFT the audio and impulse responses.
    fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
    audio_fft = torch.fft.rfft(audio_frames, fft_size)
    ir_fft = torch.fft.rfft(torch.cat((impulse_response, impulse_response[:, -1:, :]), 1), fft_size)

    # Multiply the FFTs (same as convolution in time).
    audio_ir_fft = torch.multiply(audio_fft, ir_fft)

    # Take the IFFT to resynthesize audio.
    audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)

    # Overlap-add the frames back into a single signal.
    batch_size, n_audio_frames, frame_size = audio_frames_out.size()  # B, n_frames+1, 2*(hop_size+n_mags-1)-1
    fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),
                         kernel_size=(1, frame_size),
                         stride=(1, hop_size))
    output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)

    # Crop and shift the output audio.
    output_signal = crop_and_compensate_delay(output_signal[:, hop_size:], audio_size, ir_size)
    return output_signal
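# Example (illustrative): apply 100 frame-wise FIR filters of length 128 to one
# second of audio at 16 kHz; the output keeps the input length.
#   out = fft_convolve(torch.randn(1, 16000), torch.randn(1, 100, 128))
#   # out.shape -> torch.Size([1, 16000])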
def apply_window_to_impulse_response(impulse_response,  # B, n_frames, 2*(n_mag-1)
                                     window_size: int = 0,
                                     causal: bool = False):
    """Apply a window to an impulse response and put it in causal form.
    Args:
        impulse_response: A series of impulse response frames to window, of
            shape [batch, n_frames, ir_size] (ir_size is the length of the
            time-domain filter).
        window_size: Size of the window to apply in the time domain. If
            window_size is less than 1, it defaults to the impulse_response size.
        causal: Impulse response input is in causal form (peak in the middle).
    Returns:
        impulse_response: Windowed impulse response in causal form, with last
            dimension cropped to window_size if window_size is greater than 0
            and less than ir_size.
    """
    # If IR is in causal form, put it in zero-phase form.
    if causal:
        impulse_response = torch.fft.fftshift(impulse_response, dim=-1)

    # Get a window for better time/frequency resolution than rectangular.
    # Window defaults to IR size, cannot be bigger.
    ir_size = int(impulse_response.size(-1))
    if (window_size <= 0) or (window_size > ir_size):
        window_size = ir_size
    window = torch.hann_window(window_size).to(impulse_response)

    # Zero pad the window and put it in zero-phase form.
    padding = ir_size - window_size
    if padding > 0:
        half_idx = (window_size + 1) // 2
        window = torch.cat([window[half_idx:],
                            torch.zeros([padding]),
                            window[:half_idx]], dim=0)
    else:
        window = window.roll(window.size(-1) // 2, -1)

    # Apply the window, to get the new IR (both in zero-phase form).
    window = window.unsqueeze(0)
    impulse_response = impulse_response * window

    # Put IR in causal form and trim zero padding.
    if padding > 0:
        first_half_start = (ir_size - (half_idx - 1)) + 1
        second_half_end = half_idx + 1
        impulse_response = torch.cat([impulse_response[..., first_half_start:],
                                      impulse_response[..., :second_half_end]],
                                     dim=-1)
    else:
        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
    return impulse_response
def apply_dynamic_window_to_impulse_response(impulse_response,  # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
                                             half_width_frames):  # B, n_frames, 1
    ir_size = int(impulse_response.size(-1))  # 2*(n_mag-1) or 2*n_mag-1

    window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
    window[window > 1] = 0
    window = (1 + torch.cos(np.pi * window)) / 2  # B, n_frames, 2*(n_mag-1) or 2*n_mag-1

    impulse_response = impulse_response.roll(ir_size // 2, -1)
    impulse_response = impulse_response * window

    return impulse_response
def frequency_impulse_response(magnitudes,
                               hann_window=True,
                               half_width_frames=None):
    """Get time-domain impulse responses (in causal form) from frame-wise filter magnitudes."""
    # Get the IR
    impulse_response = torch.fft.irfft(magnitudes)  # B, n_frames, 2*(n_mags-1)

    # Window and put in causal form.
    if hann_window:
        if half_width_frames is None:
            impulse_response = apply_window_to_impulse_response(impulse_response)
        else:
            impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
    else:
        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)

    return impulse_response
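# Example (illustrative): 65 magnitude bins per frame give an impulse response of
# length 2 * (65 - 1) = 128 per frame.
#   ir = frequency_impulse_response(torch.rand(1, 100, 65))
#   # ir.shape -> torch.Size([1, 100, 128])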
def frequency_filter(audio,
                     magnitudes,
                     hann_window=True,
                     half_width_frames=None):
    """Filter audio with a linear time-varying FIR filter given by frame-wise frequency magnitudes."""
    impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
    return fft_convolve(audio, impulse_response)
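

if __name__ == "__main__":
    # Minimal smoke test (illustrative values only, not part of the original
    # DDSP-SVC file): filter one second of noise at 16 kHz with 100 frames of
    # random filter magnitudes (65 bins each).
    audio = torch.randn(1, 16000)
    magnitudes = torch.rand(1, 100, 65)
    filtered = frequency_filter(audio, magnitudes)
    print(filtered.shape)  # expected: torch.Size([1, 16000])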