spellingdragon committed
Commit 3119948
1 Parent(s): efbf3eb

Update handler.py

Files changed (1): handler.py (+41 -9)
handler.py CHANGED
@@ -1,21 +1,53 @@
-from typing import Dict, Any
-from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer, pipeline
+from typing import Dict, List, Any
+import torch
+from transformers.pipelines.audio_utils import ffmpeg_read
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline
 
 class EndpointHandler():
     def __init__(self, path=""):
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
         model_id = "openai/whisper-large-v3"
-        task = "automatic-speech-recognition"
-        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.pipeline = pipeline(task, model=self.model, tokenizer=self.tokenizer)
+
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+        )
+        model.to(device)
+
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        self.pipeline = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            max_new_tokens=128,
+            chunk_length_s=30,
+            batch_size=16,
+            return_timestamps=True,
+            torch_dtype=torch_dtype,
+            device=device,
+        )
+        self.model = model
+
 
     def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
+        """
+        Args:
+            data (:obj:`dict`):
+                Includes the input audio data and, optionally, parameters for the inference.
+        Return:
+            A :obj:`dict` with a single "chunks" key: a list of segment dicts, each containing:
+            - "timestamp": A (start, end) pair of floats giving the segment bounds in seconds.
+            - "text": The transcribed (or translated) text for that segment.
+        """
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
 
+        # pass inputs with all kwargs in data
         if parameters is not None:
-            result = self.pipeline(inputs, return_timestamps=True, **parameters)
+            result = self.pipeline(inputs, return_timestamps=True, **parameters)
         else:
-            result = self.pipeline(inputs, return_timestamps=True)
-
+            result = self.pipeline(inputs, return_timestamps=True, generate_kwargs={"task": "translate"})
+        # postprocess the prediction
         return {"chunks": result["chunks"]}