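# NOTE: the three lines below are page-header residue copied from the hosting
# site along with the code; they are commented out so the file parses.
# lyimo's picture
# Update app.py
# 95facbc verified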
import gradio as gr
import torchaudio
import torch
import os
from pydub import AudioSegment
import tempfile
from speechbrain.pretrained.separation import SepformerSeparation
class AudioDenoiser:
    """Denoise speech recordings with SpeechBrain's SepFormer enhancement model.

    The pretrained model is loaded once in ``__init__``. Every call to
    ``enhance_audio`` converts the input to a mono WAV at ``SAMPLE_RATE``,
    runs the model on it and writes the enhanced signal to a new file in
    ``output_dir``.
    """

    # Rate used both for the converted model input and for the saved output;
    # matches the "16k" checkpoint loaded in __init__.
    SAMPLE_RATE = 16000

    def __init__(self, output_dir="enhanced_audio"):
        """Load the pretrained model and make sure the output directory exists.

        Args:
            output_dir (str): Directory that enhanced files are written to.
                Defaults to ``"enhanced_audio"`` (the previously hard-coded path).
        """
        # Initialize the SepFormer model for audio enhancement
        self.model = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-dns4-16k-enhancement",
            savedir='pretrained_models/sepformer-dns4-16k-enhancement'
        )
        self.output_dir = output_dir
        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` on a best-effort basis.

        Used only for temporary or partial files. Failing to clean one up must
        not replace the caller's real result or error, so ``OSError`` is ignored.
        """
        try:
            os.remove(path)
        except OSError:
            pass

    def convert_audio_to_wav(self, input_path):
        """Convert an audio file readable by pydub/ffmpeg to a mono WAV.

        Args:
            input_path (str): Path to the input audio file.

        Returns:
            str: Path to a new temporary WAV file at ``SAMPLE_RATE`` Hz. The
            caller owns this file and is responsible for deleting it.

        Raises:
            gr.Error: If the file cannot be decoded or converted. No temporary
            file is left behind in that case.
        """
        temp_wav_path = None
        try:
            # mkstemp returns an open OS-level descriptor plus the path. Close
            # it immediately: only the path is needed, and the previous
            # NamedTemporaryFile handle was never closed (descriptor leak).
            fd, temp_wav_path = tempfile.mkstemp(suffix='.wav')
            os.close(fd)

            # Load audio using pydub (supports multiple formats)
            audio = AudioSegment.from_file(input_path)
            # Down-mix to mono; '-ac 1' below requests the same from ffmpeg.
            if audio.channels > 1:
                audio = audio.set_channels(1)
            # pydub forwards `parameters` to ffmpeg: resample to SAMPLE_RATE
            # and force a single channel.
            audio.export(
                temp_wav_path,
                format='wav',
                parameters=[
                    '-ar', str(self.SAMPLE_RATE),
                    '-ac', '1',
                ],
            )
            return temp_wav_path
        except Exception as e:
            # Don't leak the (empty or partial) temp file on failure.
            if temp_wav_path is not None:
                self._remove_file(temp_wav_path)
            raise gr.Error(f"Error converting audio format: {str(e)}") from e

    def enhance_audio(self, audio_path):
        """Denoise an audio file and return the path of the enhanced WAV.

        Args:
            audio_path (str | None): Path to the input audio file, as supplied
                by a ``gr.Audio(type="filepath")`` input. May be ``None`` when
                the user submits without uploading anything.

        Returns:
            str: Path to a new, uniquely named WAV file (``SAMPLE_RATE`` Hz)
            inside ``self.output_dir``. These files are not deleted by this
            class, because Gradio reads the returned path after this method
            returns.

        Raises:
            gr.Error: If no file was given, or conversion, enhancement or
            saving fails.
        """
        if not audio_path:
            raise gr.Error("Please upload an audio file to enhance.")

        wav_path = None
        output_path = None
        try:
            # Convert input audio to proper WAV format
            wav_path = self.convert_audio_to_wav(audio_path)
            # Separate and enhance the audio
            est_sources = self.model.separate_file(path=wav_path)

            # Use a unique name per request. With a single fixed filename,
            # concurrent requests could overwrite each other's result before
            # Gradio has read it.
            fd, output_path = tempfile.mkstemp(
                prefix="enhanced_audio_", suffix=".wav", dir=self.output_dir
            )
            os.close(fd)
            # Keep source 0 as a (batch, time) tensor. This assumes
            # separate_file's output is laid out [batch, time, source], as in
            # SpeechBrain's SepFormer usage examples.
            torchaudio.save(
                output_path,
                est_sources[:, :, 0].detach().cpu(),
                self.SAMPLE_RATE,
            )
            return output_path
        except Exception as e:
            # Don't leave an empty or partially written output file behind.
            if output_path is not None:
                self._remove_file(output_path)
            if isinstance(e, gr.Error):
                # Already user-facing (e.g. from convert_audio_to_wav);
                # re-raise as-is instead of wrapping it a second time.
                raise
            raise gr.Error(f"Error processing audio: {str(e)}") from e
        finally:
            # Clean up the temporary input WAV on success *and* on failure.
            if wav_path is not None:
                self._remove_file(wav_path)
def create_gradio_interface():
    """Assemble the Gradio front end for the SepFormer denoiser.

    Instantiating the interface also constructs an ``AudioDenoiser``, which
    loads the pretrained model.

    Returns:
        gr.Interface: The configured (not yet launched) application.
    """
    enhancer = AudioDenoiser()

    # Input and output are both passed to/from the model as file paths.
    noisy_input = gr.Audio(type="filepath", label="Upload Noisy Audio")
    enhanced_output = gr.Audio(type="filepath", label="Enhanced Audio")

    intro_text = (
        "\n"
        "This application uses the SepFormer model from SpeechBrain to enhance audio quality\n"
        "by removing background noise. Supports various audio formats including MP3 and WAV.\n"
    )
    formats_text = (
        "\n"
        "Supported audio formats:\n"
        "- MP3\n"
        "- WAV\n"
        "- OGG\n"
        "- FLAC\n"
        "- M4A\n"
        "and more...\n"
        "The audio will automatically be converted to the correct format for processing.\n"
    )

    return gr.Interface(
        fn=enhancer.enhance_audio,
        inputs=noisy_input,
        outputs=enhanced_output,
        title="Audio Denoising using SepFormer",
        description=intro_text,
        article=formats_text,
    )
if __name__ == "__main__":
    # Script entry point. Building the interface loads the SepFormer model;
    # launch() then starts the Gradio web server. The module-level name
    # `demo` is kept so tooling that looks it up by name still finds it.
    demo = create_gradio_interface()
    demo.launch()