Spaces:
Running
on
Zero
Running
on
Zero
import numpy | |
import pyloudnorm as pyln | |
import torch | |
from torchaudio.transforms import MelSpectrogram | |
from torchaudio.transforms import Resample | |
class AudioPreprocessor:
    """
    Turns raw waveforms into log-mel spectrograms. It can optionally
    normalize loudness, resample to a target rate and trim leading and
    trailing silence along the way.
    """

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require different hop_length and n_fft (e.g.
        doubling frequency --> doubling hop_length and
        doubling n_fft)

        Args:
            input_sr: sampling rate of the waveforms that will be passed in.
            output_sr: rate to resample to; None keeps input_sr.
            cut_silence: trim leading/trailing non-speech with the silero VAD
                model, loaded via torch.hub.
            do_loudnorm: apply loudness + peak normalization in normalize_audio.
            device: torch device for the resampler, spectrogram and VAD model.
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.meter = pyln.Meter(input_sr)
        # final_sr is the rate audio has after self.resample; the spectrogram is built for it
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
            # this to false globally during model loading rather than using inference mode or no_grad
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x

    def cut_leading_and_trailing_silence(self, audio):
        """
        Trim everything before the first and after the last detected speech
        segment, using silero VAD (https://github.com/snakers4/silero-vad).

        `audio` is expected to already be at self.final_sr, because that is the
        rate handed to the VAD. It is presumably a 1-D waveform tensor, since it
        is sliced along its first dimension. If no speech is detected, the input
        is returned unchanged.
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        if not speech_timestamps:
            print("Audio might be too short to cut silences from front and back.")
            return audio
        return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]

    def normalize_loudness(self, audio):
        """
        Normalize the amplitudes according to their decibels. This should
        turn any signal with different magnitudes into the same magnitude
        by analysing loudness: the signal is normalized to -30 LUFS and then
        scaled so its absolute peak is 1.

        The input is returned unchanged when it is too short to measure, or
        when its loudness is not finite (as can happen for silent input).
        Normalizing in that case would only produce NaNs.

        NOTE(review): if pyln.normalize.loudness applies one uniform gain, the
        peak normalization afterwards cancels it. Confirm that is intended.
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a value error will arise
            return audio
        if not numpy.isfinite(loudness):
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        return self._peak_normalize(loud_normed)

    @staticmethod
    def _peak_normalize(audio):
        """Scale `audio` so its largest absolute sample is 1. Empty, all-zero or non-finite input is returned as-is."""
        if numpy.size(audio) == 0:
            return audio
        peak = numpy.amax(numpy.abs(audio))
        if not (numpy.isfinite(peak) and peak > 0):
            # dividing by a zero or NaN peak would turn the whole signal into NaNs
            return audio
        return numpy.divide(audio, peak)

    def _as_float_tensor(self, audio):
        """Return `audio` as a float32 tensor on self.device. Tensors are converted rather than re-wrapped."""
        if isinstance(audio, torch.Tensor):
            return audio.to(device=self.device, dtype=torch.float32)
        return torch.tensor(audio, device=self.device, dtype=torch.float32)

    def normalize_audio(self, audio):
        """
        One function to apply them all in an order that makes sense:
        loudness normalization (if enabled), resampling to self.final_sr,
        then silence cutting (if enabled).

        `audio` is assumed to be at self.input_sr. Returns a float32 tensor
        on self.device.
        """
        if self.do_loudnorm:
            if isinstance(audio, torch.Tensor):
                # the loudness path is numpy-based (numpy.amax / numpy.divide), so hand it an array
                audio = audio.detach().cpu().numpy()
            audio = self.normalize_loudness(audio)
        audio = self._as_float_tensor(audio)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio

    def _switch_input_sr(self, new_input_sr):
        """Re-target the resampler and loudness meter at a new input sampling rate."""
        print("WARNING: different sampling rate used, this will be very slow if it happens often. Consider creating a dedicated audio processor.")
        # resample to final_sr, the rate wave_to_spectrogram was built for; this also works when output_sr is None
        self.resample = Resample(orig_freq=new_input_sr, new_freq=self.final_sr).to(self.device)
        self.input_sr = new_input_sr
        # keep the loudness meter consistent with the new input rate
        self.meter = pyln.Meter(new_input_sr)

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        Compute the log-mel spectrogram of a waveform.

        Args:
            audio: waveform as a torch tensor or anything torch.tensor accepts.
            normalize: if True, run normalize_audio first. That applies loudness
                normalization, resampling to final_sr and silence cutting, as
                configured in __init__.
            explicit_sampling_rate: the sampling rate of `audio`. There is no
                way to detect the rate of incoming audio, so pass this e.g. when
                normalization (including resampling) has already been applied.
                With None, the audio is assumed to already be at final_sr when
                normalize is False, and at input_sr when normalize is True.

        Returns:
            The tensor produced by wave_to_spectrogram.
        """
        if normalize:
            if explicit_sampling_rate is not None and explicit_sampling_rate != self.input_sr:
                self._switch_input_sr(explicit_sampling_rate)
            # normalize_audio ends at final_sr, which is what wave_to_spectrogram expects
            return self.wave_to_spectrogram(self.normalize_audio(audio))
        audio = self._as_float_tensor(audio)
        if explicit_sampling_rate is None or explicit_sampling_rate == self.final_sr:
            return self.wave_to_spectrogram(audio)
        if explicit_sampling_rate != self.input_sr:
            self._switch_input_sr(explicit_sampling_rate)
        return self.wave_to_spectrogram(self.resample(audio))
class LogMelSpec(torch.nn.Module):
    """
    Log10 mel spectrogram (HTK mel scale, power spectrum) as a torch module.

    The defaults are the values this module always used: n_fft=1024,
    win_length=n_fft, hop_length=256, and 128 mel bins from 40 Hz up to
    sr // 2. They are keyword-overridable, e.g. to scale n_fft and
    hop_length together with the sampling rate.
    """

    def __init__(self, sr, *args, n_fft=1024, win_length=None, hop_length=256,
                 n_mels=128, f_min=40.0, f_max=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.spec = MelSpectrogram(sample_rate=sr,
                                   n_fft=n_fft,
                                   win_length=n_fft if win_length is None else win_length,
                                   hop_length=hop_length,
                                   f_min=f_min,
                                   f_max=sr // 2 if f_max is None else f_max,
                                   pad=0,
                                   n_mels=n_mels,
                                   power=2.0,
                                   normalized=False,
                                   center=True,
                                   pad_mode='reflect',
                                   mel_scale='htk')

    def forward(self, audio):
        """
        Args:
            audio: waveform tensor; it is cast to float32 before the transform.

        Returns:
            log10 of the mel power spectrogram. Exact zeros are floored at
            1e-8 so the logarithm stays finite.
        """
        melspec = self.spec(audio.float())
        # Out-of-place so the caller's tensor is never mutated. Only exact
        # zeros are floored; tiny positive values pass through unchanged.
        return torch.log10(melspec.masked_fill(melspec == 0, 1e-8))
if __name__ == '__main__':
    # Demo: plot the log-mel spectrogram of one file after the full preprocessing chain.
    import soundfile
    import matplotlib.pyplot as plt
    import librosa.display as lbd

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    if wav.ndim > 1:
        # multi-channel files come back as (frames, channels); mix down to mono
        wav = wav.mean(axis=1)
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)
    # normalize_audio resamples from the file's rate to final_sr and cuts silence.
    # Passing final_sr explicitly tells audio_to_mel_spec_tensor the audio is
    # already at the spectrogram's rate, instead of spectrogramming the raw file.
    spec = ap.audio_to_mel_spec_tensor(ap.normalize_audio(wav), explicit_sampling_rate=ap.final_sr)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    lbd.specshow(spec.cpu().numpy(),
                 ax=ax,
                 sr=ap.final_sr,
                 cmap='GnBu',
                 y_axis='features',
                 x_axis=None,
                 hop_length=256)
    plt.show()