tykiww's picture
Update services/diarization.py
e29651b verified
raw
history blame
2.43 kB
import os
import torch
from pyannote.audio import Pipeline
def extract_files(files):
    """Collect the ``name`` attribute of every object in *files*.

    Args:
        files: Iterable of objects exposing a ``name`` attribute
            (presumably uploaded file handles -- confirm against the caller).

    Returns:
        A list holding one ``name`` value per input object, in input order.
    """
    paths = []
    for handle in files:
        paths.append(handle.name)
    return paths
class Diarizer:
    """Speaker-diarization service built on a pretrained pyannote pipeline.

    The pipeline is loaded once when the object is constructed and is moved
    to a compute device lazily, right before each inference call.
    """

    def __init__(self, conf):
        """Store the configuration and load the diarization pipeline.

        Args:
            conf: Mapping read as ``conf["model"]["diarizer"]`` to obtain the
                identifier of the pretrained pipeline.

        Raises:
            KeyError: If the configuration entry or the ``HUGGINGFACE_TOKEN``
                environment variable is missing.
        """
        self.conf = conf
        self.pipeline = self.pyannote_pipeline()

    def pyannote_pipeline(self):
        """Load the pretrained pipeline named in the configuration.

        The access token is read from the ``HUGGINGFACE_TOKEN`` environment
        variable.

        Returns:
            Whatever ``Pipeline.from_pretrained`` returns for the configured
            identifier.
        """
        pipeline = Pipeline.from_pretrained(
            self.conf["model"]["diarizer"],
            use_auth_token=os.environ["HUGGINGFACE_TOKEN"],
        )
        return pipeline

    def get_pipeline(self):
        """Return the pipeline object loaded at construction time."""
        return self.pipeline

    def add_device(self, pipeline):
        """Move *pipeline* to CUDA when available, otherwise to the CPU.

        Offloaded to allow for best timing when working with GPUs: the device
        is chosen at call time rather than when the pipeline is loaded.

        Args:
            pipeline: Object exposing a ``to(device)`` method.

        Returns:
            The same *pipeline* object that was passed in.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
        return pipeline

    def diarize_audio(self, temp_file, num_speakers):
        """Run the diarization pipeline on one audio input.

        Args:
            temp_file: Audio input handed straight to the pipeline
                (presumably a path to a temporary file -- confirm with caller).
            num_speakers: Forwarded to the pipeline as ``num_speakers``.

        Returns:
            The pipeline result converted with ``str()``.

        Note:
            *temp_file* is not deleted here; cleanup is left to the caller.
        """
        pipeline = self.add_device(self.pipeline)
        diarization = pipeline(temp_file, num_speakers=num_speakers)
        return str(diarization)

    def extract_seconds(self, timestamp):
        """Convert an ``H:M:S`` timestamp string into seconds.

        Args:
            timestamp: Three colon-separated numeric fields, e.g.
                ``"00:01:02.500"``.

        Returns:
            The total number of seconds as a float.

        Raises:
            ValueError: If there are not exactly three fields or a field is
                not a valid number.
        """
        hours, minutes, seconds = map(float, timestamp.split(':'))
        return 3600 * hours + 60 * minutes + seconds

    def _parse_segment_line(self, entry):
        """Split one stripped diarization line into (start, end, label).

        Raises:
            ValueError: If the line does not have the expected
                ``[start --> end] ... label`` shape or a timestamp is invalid.
        """
        pieces = entry.split(' --> ')
        if len(pieces) != 2:
            raise ValueError("Unexpected format in diarized output")
        start_text, remainder = pieces

        # Require the closing bracket and a trailing label explicitly, so a
        # truncated line is rejected instead of yielding a clipped timestamp
        # or a timestamp fragment as the label.
        end_text, bracket, trailer = remainder.partition(']')
        label_fields = trailer.split()
        if not bracket or not label_fields:
            raise ValueError("Unexpected format in diarized output")

        start_seconds = self.extract_seconds(start_text.lstrip('[').strip())
        end_seconds = self.extract_seconds(end_text.strip())
        return start_seconds, end_seconds, label_fields[-1]

    def generate_labels_from_diarization(self, diarized_output, labels_path='labels.txt'):
        """Write diarization text out as a tab-separated label file.

        Every input line is expected to look like
        ``[ HH:MM:SS.fff -->  HH:MM:SS.fff] TRACK LABEL`` (assumed to be the
        text form of the pipeline result -- confirm against the pyannote
        version in use). Each valid line becomes one output row
        ``start_seconds<TAB>end_seconds<TAB>label``, where the label is the
        last whitespace-separated token of the line.

        Blank lines are ignored. Malformed lines are reported on stdout and
        skipped, so one bad line does not discard the rest.

        Args:
            diarized_output: Multi-line diarization text.
            labels_path: Destination file; overwritten if it already exists.

        Returns:
            *labels_path*, after the file has been written.
        """
        rows = []
        for line in diarized_output.strip().split('\n'):
            entry = line.strip()
            if not entry:
                # An empty diarization produces a single blank line; that is
                # not an error worth reporting.
                continue
            try:
                start_seconds, end_seconds, label = self._parse_segment_line(entry)
            except ValueError as e:
                print(f"Error processing line: '{entry}'. Error: {e}")
                continue
            rows.append(f"{start_seconds}\t{end_seconds}\t{label}\n")

        with open(labels_path, "w", encoding="utf-8") as file:
            file.write("".join(rows))
        return labels_path

    def run(self, temp_file, num_speakers, labels_path='labels.txt'):
        """Diarize one audio input and write the matching label file.

        Args:
            temp_file: Audio input forwarded to ``diarize_audio``.
            num_speakers: Forwarded to ``diarize_audio``.
            labels_path: Destination of the label file.

        Returns:
            A ``(diarization_text, label_file_path)`` tuple.
        """
        diarization_result = self.diarize_audio(temp_file, num_speakers)
        label_file = self.generate_labels_from_diarization(diarization_result, labels_path)
        return diarization_result, label_file