lyimo's picture
Update app.py
95facbc verified
raw
history blame
4.06 kB
import gradio as gr
import torchaudio
import torch
import os
from pydub import AudioSegment
import tempfile
from speechbrain.pretrained.separation import SepformerSeparation
class AudioDenoiser:
    """Speech-enhancement wrapper around SpeechBrain's SepFormer DNS4 16 kHz model.

    Uploaded audio is transcoded to a temporary mono WAV at ``SAMPLE_RATE``,
    run through the model, and the enhanced result is written as a new WAV
    file under ``OUTPUT_DIR``.
    """

    # Sample rate (Hz) used for the converted input and the saved output;
    # matches the "16k" checkpoint loaded in __init__.
    SAMPLE_RATE = 16000
    # Directory (relative to the working directory) that receives enhanced files.
    OUTPUT_DIR = "enhanced_audio"

    def __init__(self):
        """Load the pretrained enhancement model and ensure OUTPUT_DIR exists."""
        # Initialize the SepFormer model for audio enhancement
        self.model = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir='pretrained_models/sepformer-dns4-16k-enhancement'
        )
        # Create output directory if it doesn't exist
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)

    @staticmethod
    def _reserve_path(directory=None, prefix="tmp"):
        """Atomically create an empty, uniquely named ``.wav`` file and return its path.

        Args:
            directory (str | None): Where to create the file; None means the
                system temp directory.
            prefix (str): Filename prefix.

        Returns:
            str: Path of the newly created (empty) file.
        """
        # mkstemp guarantees a name no other request can also receive; close the
        # OS-level handle immediately so no descriptor is leaked.
        fd, path = tempfile.mkstemp(suffix=".wav", prefix=prefix, dir=directory)
        os.close(fd)
        return path

    @staticmethod
    def _discard(path):
        """Best-effort delete of ``path``; a missing or locked file is ignored.

        Cleanup runs on error paths, so it must never mask the original error.
        """
        try:
            os.remove(path)
        except OSError:
            pass

    def convert_audio_to_wav(self, input_path):
        """
        Convert any audio format to a mono WAV at SAMPLE_RATE.

        Args:
            input_path (str): Path to input audio file (any format pydub/ffmpeg can read)

        Returns:
            str: Path to the converted temporary WAV file. The caller owns it
                and is responsible for deleting it.

        Raises:
            gr.Error: If the file cannot be decoded or exported.
        """
        temp_wav_path = None
        try:
            temp_wav_path = self._reserve_path()
            # Load audio using pydub (supports multiple formats)
            audio = AudioSegment.from_file(input_path)
            # Convert to mono if stereo
            if audio.channels > 1:
                audio = audio.set_channels(1)
            # Passing ffmpeg parameters routes the export through ffmpeg, which
            # performs the resampling to SAMPLE_RATE.
            audio.export(
                temp_wav_path,
                format='wav',
                parameters=[
                    '-ar', str(self.SAMPLE_RATE),  # Set sample rate
                    '-ac', '1'  # Set channels to mono
                ]
            )
        except Exception as e:
            # Do not leave a half-written temp file behind on failure.
            if temp_wav_path is not None:
                self._discard(temp_wav_path)
            raise gr.Error(f"Error converting audio format: {str(e)}") from e
        return temp_wav_path

    def enhance_audio(self, audio_path):
        """
        Process the input audio file and return the enhanced version.

        Args:
            audio_path (str | None): Path to the input audio file, as supplied by
                a ``gr.Audio(type="filepath")`` input. None or empty when the user
                submits without uploading anything.

        Returns:
            str: Path to a newly created WAV file under OUTPUT_DIR. Each call
                gets its own file, so concurrent requests do not overwrite each
                other.

        Raises:
            gr.Error: If no file was provided, or if conversion or enhancement fails.
        """
        if not audio_path:
            raise gr.Error("Please upload an audio file.")

        wav_path = None
        output_path = None
        try:
            # Convert input audio to proper WAV format
            wav_path = self.convert_audio_to_wav(audio_path)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                est_sources = self.model.separate_file(path=wav_path)
            # NOTE(review): unique per-request files accumulate in OUTPUT_DIR;
            # add periodic cleanup if disk usage matters for this deployment.
            output_path = self._reserve_path(
                directory=self.OUTPUT_DIR, prefix="enhanced_audio_"
            )
            # Index 0 of the last axis -- presumably the single enhanced source
            # (assumes (batch, time, n_sources) layout; confirm against SpeechBrain).
            torchaudio.save(
                output_path,
                est_sources[:, :, 0].detach().cpu(),
                self.SAMPLE_RATE
            )
        except gr.Error:
            # Already user-facing (e.g. from convert_audio_to_wav); don't re-wrap.
            raise
        except Exception as e:
            if output_path is not None:
                self._discard(output_path)
            raise gr.Error(f"Error processing audio: {str(e)}") from e
        finally:
            # The intermediate WAV is never needed after this call, success or not.
            if wav_path is not None:
                self._discard(wav_path)
        return output_path
def create_gradio_interface():
    """Assemble the Gradio UI around an AudioDenoiser instance.

    The enhancement model is loaded once here, then reused for every request.

    Returns:
        gr.Interface: The configured interface, returned unlaunched; the caller
            decides when to call ``launch()``.
    """
    enhancer = AudioDenoiser()

    noisy_input = gr.Audio(type="filepath", label="Upload Noisy Audio")
    clean_output = gr.Audio(label="Enhanced Audio", type="filepath")

    # Markdown shown above the interface.
    description_text = (
        "\n"
        "This application uses the SepFormer model from SpeechBrain to enhance audio quality\n"
        "by removing background noise. Supports various audio formats including MP3 and WAV.\n"
    )
    # Markdown shown below the interface.
    article_text = (
        "\n"
        "Supported audio formats:\n"
        "- MP3\n"
        "- WAV\n"
        "- OGG\n"
        "- FLAC\n"
        "- M4A\n"
        "and more...\n"
        "The audio will automatically be converted to the correct format for processing.\n"
    )

    return gr.Interface(
        fn=enhancer.enhance_audio,
        inputs=noisy_input,
        outputs=clean_output,
        title="Audio Denoising using SepFormer",
        description=description_text,
        article=article_text,
    )
if __name__ == "__main__":
    # Script entry point. The module-level name `demo` is kept deliberately,
    # in case external tooling looks the interface up by that name.
    demo = create_gradio_interface()
    demo.launch()