# services/asr.py — uploaded by tykiww (commit 18a0c49, 810 bytes).
# NOTE(review): the original lines here were Hugging Face file-viewer chrome
# ("raw", "history blame", avatar caption) accidentally captured with the
# source; preserved as a comment so the file parses as valid Python.
import torch
from transformers import pipeline
class Transcriber:
    """Thin wrapper around a Hugging Face ASR pipeline.

    Expects a nested config dict with (at least) the keys used below:
        conf["model"]["asr"]["type"]           -- pipeline task name (e.g. "automatic-speech-recognition")
        conf["model"]["asr"]["transcriber"]    -- model id or path
        conf["model"]["asr"]["max_new_tokens"] -- generation cap passed to run()
    """

    def __init__(self, conf):
        # Build the pipeline eagerly so model-loading errors surface at
        # construction time rather than on the first run() call.
        self.conf = conf
        self.pipeline = self.asr_pipeline()

    def asr_pipeline(self):
        """Construct and return the transformers ASR pipeline from config."""
        asr_conf = self.conf["model"]["asr"]
        return pipeline(
            asr_conf["type"],
            model=asr_conf["transcriber"],
            device=0 if torch.cuda.is_available() else -1,  # 0 = first GPU, -1 = CPU
        )

    def get_pipeline(self):
        """Return the underlying pipeline object (kept for API compatibility)."""
        return self.pipeline

    def run(self, file_path):
        """Transcribe the audio at *file_path*.

        Returns the pipeline's "chunks" list (timestamped segments) when
        available, otherwise the raw pipeline output.
        """
        kwargs = {"max_new_tokens": self.conf["model"]["asr"]["max_new_tokens"]}
        output = self.pipeline(
            file_path,
            generate_kwargs=kwargs,
            return_timestamps=True,
        )
        # HF pipelines return a dict for a single input but a list for a
        # batch; only dicts support .get(), so guard before extracting chunks.
        if isinstance(output, dict):
            return output.get("chunks", output)
        return output