import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
def MaskedAvgPool1d(x, kernel_size):
    # x: [batch, time]. NaN entries are excluded from the average.
    x = x.unsqueeze(1)
    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
    mask = ~torch.isnan(x)
    masked_x = torch.where(mask, x, torch.zeros_like(x))
    ones_kernel = torch.ones(x.size(1), 1, kernel_size, device=x.device)

    # Perform sum pooling
    sum_pooled = F.conv1d(
        masked_x,
        ones_kernel,
        stride=1,
        padding=0,
        groups=x.size(1),
    )

    # Count the non-masked (valid) elements in each pooling window
    valid_count = F.conv1d(
        mask.float(),
        ones_kernel,
        stride=1,
        padding=0,
        groups=x.size(1),
    )
    valid_count = valid_count.clamp(min=1)  # Avoid division by zero

    # Perform masked average pooling
    avg_pooled = sum_pooled / valid_count
    return avg_pooled.squeeze(1)
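# Usage sketch (illustrative shapes, not from the original file): smooth an f0
# track while ignoring NaN (e.g. unvoiced) samples.
#   >>> f0 = torch.randn(2, 100)        # [batch, time]
#   >>> f0[:, 10:20] = float("nan")
#   >>> MaskedAvgPool1d(f0, kernel_size=3).shape
#   torch.Size([2, 100])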
def MedianPool1d(x, kernel_size):
    # x: [batch, time]. Sliding-window median with reflect padding, so the
    # output length matches the input length.
    x = x.unsqueeze(1)
    x = F.pad(x, ((kernel_size - 1) // 2, kernel_size // 2), mode="reflect")
    x = x.squeeze(1)
    x = x.unfold(1, kernel_size, 1)
    x, _ = torch.sort(x, dim=-1)
    return x[:, :, (kernel_size - 1) // 2]
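# Usage sketch (illustrative shapes):
#   >>> MedianPool1d(torch.randn(2, 100), kernel_size=5).shape
#   torch.Size([2, 100])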
def get_fft_size(frame_size: int, ir_size: int, power_of_2: bool = True):
    """Calculate final size for efficient FFT.

    Args:
        frame_size: Size of the audio frame.
        ir_size: Size of the convolving impulse response.
        power_of_2: Constrain to be a power of 2. If False, allow other 5-smooth
            numbers. TPU requires power of 2, while GPU is more flexible.

    Returns:
        fft_size: Size for efficient FFT.
    """
    convolved_frame_size = ir_size + frame_size - 1
    if power_of_2:
        # Next power of 2.
        fft_size = int(2 ** np.ceil(np.log2(convolved_frame_size)))
    else:
        fft_size = convolved_frame_size
    return fft_size
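# Example (worked numbers): a 320-sample frame convolved with a 256-tap IR needs
# at least 320 + 256 - 1 = 575 FFT points, rounded up to 1024 when power_of_2=True.
#   >>> get_fft_size(320, 256, power_of_2=True)
#   1024
#   >>> get_fft_size(320, 256, power_of_2=False)
#   575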
def upsample(signal, factor):
    # signal: [batch, n_frames, channels] -> [batch, n_frames * factor, channels]
    signal = signal.permute(0, 2, 1)
    # Repeat the last frame as a right endpoint for linear interpolation, then drop the extra sample.
    signal = nn.functional.interpolate(torch.cat((signal, signal[:, :, -1:]), 2), size=signal.shape[-1] * factor + 1, mode='linear', align_corners=True)
    signal = signal[:, :, :-1]
    return signal.permute(0, 2, 1)
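# Usage sketch (illustrative shapes): upsample frame-rate controls to sample rate,
# e.g. 100 frames with a hop of 160 samples.
#   >>> ctrl = torch.randn(1, 100, 8)   # [batch, n_frames, channels]
#   >>> upsample(ctrl, 160).shape
#   torch.Size([1, 16000, 8])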
def remove_above_fmax(amplitudes, pitch, fmax, level_start=1):
    # Zero out (up to a 1e-7 floor) harmonics whose frequency exceeds fmax.
    # amplitudes: [batch, n_frames, n_harmonics], pitch: [batch, n_frames, 1] in Hz.
    n_harm = amplitudes.shape[-1]
    pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
    aa = (pitches < fmax).float() + 1e-7
    return amplitudes * aa
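# Usage sketch (illustrative values): for a 440 Hz pitch and fmax=8000 Hz, harmonics
# above the 18th (440 * 19 = 8360 Hz) are suppressed; the shape is unchanged.
#   >>> amp = torch.rand(1, 100, 64)          # [batch, n_frames, n_harmonics]
#   >>> f0 = torch.full((1, 100, 1), 440.0)   # Hz
#   >>> remove_above_fmax(amp, f0, fmax=8000.0).shape
#   torch.Size([1, 100, 64])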
def crop_and_compensate_delay(audio, audio_size, ir_size,
                              padding='same',
                              delay_compensation=-1):
    """Crop audio output from convolution to compensate for group delay.

    Args:
        audio: Audio after convolution. Tensor of shape [batch, time_steps].
        audio_size: Initial size of the audio before convolution.
        ir_size: Size of the convolving impulse response.
        padding: Either 'valid' or 'same'. For 'same' the final output is the
            same size as the input audio (audio_timesteps). For 'valid' the audio
            is extended to include the tail of the impulse response
            (audio_timesteps + ir_timesteps - 1).
        delay_compensation: Samples to crop from start of output audio to
            compensate for group delay of the impulse response. If
            delay_compensation < 0 it defaults to automatically calculating a
            constant group delay of the windowed linear phase filter from
            frequency_impulse_response().

    Returns:
        Tensor of cropped and shifted audio.

    Raises:
        ValueError: If padding is not either 'valid' or 'same'.
    """
    # Crop the output.
    if padding == 'valid':
        crop_size = ir_size + audio_size - 1
    elif padding == 'same':
        crop_size = audio_size
    else:
        raise ValueError('Padding must be \'valid\' or \'same\', instead '
                         'of {}.'.format(padding))

    # Compensate for the group delay of the filter by trimming the front.
    # For an impulse response produced by frequency_impulse_response(),
    # the group delay is constant because the filter is linear phase.
    total_size = int(audio.shape[-1])
    crop = total_size - crop_size
    start = (ir_size // 2 if delay_compensation < 0 else delay_compensation)
    end = crop - start
    return audio[:, start:-end]
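# Worked example (illustrative numbers): with audio_size=16000, ir_size=256 and
# padding='same', crop_size=16000. If the convolved signal has 16415 samples, then
# start = 256 // 2 = 128 and end = (16415 - 16000) - 128 = 287, so the slice
# audio[:, 128:-287] returns exactly 16000 samples.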
def fft_convolve(audio,
                 impulse_response):  # B, n_frames, 2*(n_mags-1)
    """Filter audio with frames of time-varying impulse responses.

    Time-varying filter. Given audio [batch, n_samples], and a series of impulse
    responses [batch, n_frames, n_impulse_response], splits the audio into frames,
    applies filters, and then overlap-and-adds audio back together.

    The audio is cut into 50%-overlapping, Bartlett-windowed frames and each frame
    is convolved by FFT multiplication, which is efficient for large impulse
    response sizes.

    Args:
        audio: Input audio. Tensor of shape [batch, audio_timesteps].
        impulse_response: Finite impulse response to convolve. Can either be a 2-D
            Tensor of shape [batch, ir_size], or a 3-D Tensor of shape [batch,
            ir_frames, ir_size]. A 2-D tensor will apply a single linear
            time-invariant filter to the audio. A 3-D Tensor will apply a linear
            time-varying filter. Automatically chops the audio into equally shaped
            blocks to match ir_frames.

    Returns:
        audio_out: Convolved audio. Tensor of shape [batch, audio_timesteps].
    """
    # Add a frame dimension to impulse response if it doesn't have one.
    ir_shape = impulse_response.size()
    if len(ir_shape) == 2:
        impulse_response = impulse_response.unsqueeze(1)
        ir_shape = impulse_response.size()

    # Get shapes of audio and impulse response.
    batch_size_ir, n_ir_frames, ir_size = ir_shape
    batch_size, audio_size = audio.size()  # B, T

    # Validate that batch sizes match.
    if batch_size != batch_size_ir:
        raise ValueError('Batch size of audio ({}) and impulse response ({}) must '
                         'be the same.'.format(batch_size, batch_size_ir))

    # Cut audio into 50% overlapped frames (center padding).
    hop_size = int(audio_size / n_ir_frames)
    frame_size = 2 * hop_size
    audio_frames = F.pad(audio, (hop_size, hop_size)).unfold(1, frame_size, hop_size)

    # Apply Bartlett (triangular) window
    window = torch.bartlett_window(frame_size).to(audio_frames)
    audio_frames = audio_frames * window

    # Pad and FFT the audio and impulse responses.
    fft_size = get_fft_size(frame_size, ir_size, power_of_2=False)
    audio_fft = torch.fft.rfft(audio_frames, fft_size)
    ir_fft = torch.fft.rfft(torch.cat((impulse_response, impulse_response[:, -1:, :]), 1), fft_size)

    # Multiply the FFTs (same as convolution in time).
    audio_ir_fft = torch.multiply(audio_fft, ir_fft)

    # Take the IFFT to resynthesize audio.
    audio_frames_out = torch.fft.irfft(audio_ir_fft, fft_size)

    # Overlap-add the frames back into a signal.
    batch_size, n_audio_frames, frame_size = audio_frames_out.size()  # B, n_frames+1, 2*(hop_size+n_mags-1)-1
    fold = torch.nn.Fold(output_size=(1, (n_audio_frames - 1) * hop_size + frame_size),
                         kernel_size=(1, frame_size),
                         stride=(1, hop_size))
    output_signal = fold(audio_frames_out.transpose(1, 2)).squeeze(1).squeeze(1)

    # Crop and shift the output audio.
    output_signal = crop_and_compensate_delay(output_signal[:, hop_size:], audio_size, ir_size)
    return output_signal
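# Usage sketch (illustrative shapes): filter 1 s of 16 kHz audio with 100 frames of
# 256-tap impulse responses (hop_size = 160). audio_size should be an integer
# multiple of the number of IR frames.
#   >>> audio = torch.randn(1, 16000)
#   >>> ir = torch.randn(1, 100, 256)
#   >>> fft_convolve(audio, ir).shape
#   torch.Size([1, 16000])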
def apply_window_to_impulse_response(impulse_response,  # B, n_frames, 2*(n_mag-1)
                                     window_size: int = 0,
                                     causal: bool = False):
    """Apply a window to an impulse response and put in causal form.

    Args:
        impulse_response: A series of impulse response frames to window, of shape
            [batch, n_frames, ir_size].
        window_size: Size of the window to apply in the time domain. If window_size
            is less than 1, it defaults to the impulse_response size.
        causal: Impulse response input is in causal form (peak in the middle).

    Returns:
        impulse_response: Windowed impulse response in causal form, with last
            dimension cropped to window_size if window_size is greater than 0 and
            less than ir_size.
    """
    # If IR is in causal form, put it in zero-phase form.
    if causal:
        impulse_response = torch.fft.fftshift(impulse_response, dim=-1)

    # Get a window for better time/frequency resolution than rectangular.
    # Window defaults to IR size, cannot be bigger.
    ir_size = int(impulse_response.size(-1))
    if (window_size <= 0) or (window_size > ir_size):
        window_size = ir_size
    window = torch.hann_window(window_size).to(impulse_response)

    # Zero pad the window and put it in zero-phase form.
    padding = ir_size - window_size
    if padding > 0:
        half_idx = (window_size + 1) // 2
        window = torch.cat([window[half_idx:],
                            torch.zeros([padding], device=window.device),
                            window[:half_idx]], dim=0)
    else:
        window = window.roll(window.size(-1) // 2, -1)

    # Apply the window, to get new IR (both in zero-phase form).
    window = window.unsqueeze(0)
    impulse_response = impulse_response * window

    # Put IR in causal form and trim zero padding.
    if padding > 0:
        first_half_start = (ir_size - (half_idx - 1)) + 1
        second_half_end = half_idx + 1
        impulse_response = torch.cat([impulse_response[..., first_half_start:],
                                      impulse_response[..., :second_half_end]],
                                     dim=-1)
    else:
        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)
    return impulse_response
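# Shape note (illustrative): with the default window_size=0 the Hann window spans
# the whole IR, so the output keeps the input shape.
#   >>> apply_window_to_impulse_response(torch.randn(1, 100, 256)).shape
#   torch.Size([1, 100, 256])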
def apply_dynamic_window_to_impulse_response(impulse_response,   # B, n_frames, 2*(n_mag-1) or 2*n_mag-1
                                             half_width_frames):  # B, n_frames, 1
    ir_size = int(impulse_response.size(-1))  # 2*(n_mag-1) or 2*n_mag-1

    window = torch.arange(-(ir_size // 2), (ir_size + 1) // 2).to(impulse_response) / half_width_frames
    window[window > 1] = 0
    window = (1 + torch.cos(np.pi * window)) / 2  # B, n_frames, 2*(n_mag-1) or 2*n_mag-1

    impulse_response = impulse_response.roll(ir_size // 2, -1)
    impulse_response = impulse_response * window
    return impulse_response
def frequency_impulse_response(magnitudes,
                               hann_window=True,
                               half_width_frames=None):
    """Convert frame-wise filter magnitudes into a time-domain impulse response.

    magnitudes: real magnitude response of shape [batch, n_frames, n_mags].
    Returns an impulse response of shape [batch, n_frames, 2 * (n_mags - 1)].
    """
    # Get the IR
    impulse_response = torch.fft.irfft(magnitudes)  # B, n_frames, 2*(n_mags-1)

    # Window and put in causal form.
    if hann_window:
        if half_width_frames is None:
            impulse_response = apply_window_to_impulse_response(impulse_response)
        else:
            impulse_response = apply_dynamic_window_to_impulse_response(impulse_response, half_width_frames)
    else:
        impulse_response = impulse_response.roll(impulse_response.size(-1) // 2, -1)

    return impulse_response
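# Usage sketch (illustrative shapes): 129 magnitude bins per frame yield a
# 2 * (129 - 1) = 256-tap impulse response per frame.
#   >>> mags = torch.rand(1, 100, 129)   # [batch, n_frames, n_mags]
#   >>> frequency_impulse_response(mags).shape
#   torch.Size([1, 100, 256])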
def frequency_filter(audio,
                     magnitudes,
                     hann_window=True,
                     half_width_frames=None):
    # Design a (time-varying) FIR filter from `magnitudes` and apply it to `audio`.
    impulse_response = frequency_impulse_response(magnitudes, hann_window, half_width_frames)
    return fft_convolve(audio, impulse_response)
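# Minimal smoke test (a sketch with assumed shapes, not part of the original API):
# filter white noise with random frame-wise magnitude responses.
if __name__ == "__main__":
    audio = torch.randn(2, 16000)          # [batch, n_samples]
    magnitudes = torch.rand(2, 100, 129)   # [batch, n_frames, n_mags]
    out = frequency_filter(audio, magnitudes)
    print(out.shape)                       # expected: torch.Size([2, 16000])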