akera committed on
Commit
27b508c
1 Parent(s): 6ce3643

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -24,37 +24,42 @@ def format_and_group_timestamps(chunks, interval=5.0):
24
  if interval_start not in grouped:
25
  grouped[interval_start] = []
26
  grouped[interval_start].append((start, end, word))
27
-
28
  formatted_output = f"Transcript: {transcript.strip()}'\n\n-------\n\nword-stamped transcripts (every 5 seconds):\n\n"
29
  for interval_start, words in grouped.items():
30
  formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
31
  return formatted_output
32
 
33
- # Modified transcribe_audio function
34
  def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
35
- device = "cuda" if torch.cuda.is_available() else "cpu"
36
  target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
37
  target_lang_code = target_lang_options[language]
38
-
39
- # Determine the model_id based on the language
40
  if target_lang_code == "eng":
41
- model_id = "facebook/mms-1b-all"
 
 
 
 
42
  else:
 
 
43
  model_id = "Sunbird/sunbird-mms"
44
-
45
- pipe = pipeline(model=model_id, device=device, token=auth_token)
46
- pipe.tokenizer.set_target_lang(target_lang_code)
47
- pipe.model.load_adapter(target_lang_code)
48
 
49
- output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
50
- formatted_output = format_and_group_timestamps(output['chunks'])
51
- return formatted_output
 
 
 
 
 
52
 
53
  # Interface setup remains the same
54
  description = '''ASR with salt-mms'''
55
  iface = gr.Interface(fn=transcribe_audio,
56
  inputs=[
57
- gr.Audio(source="upload", type="filepath", label="upload file to transcribe"),
58
  gr.Dropdown(choices=list(target_lang_options.keys()), label="Language", value="English")
59
  ],
60
  outputs=gr.Textbox(label="Transcription"),
 
24
  if interval_start not in grouped:
25
  grouped[interval_start] = []
26
  grouped[interval_start].append((start, end, word))
27
+
28
  formatted_output = f"Transcript: {transcript.strip()}'\n\n-------\n\nword-stamped transcripts (every 5 seconds):\n\n"
29
  for interval_start, words in grouped.items():
30
  formatted_output += f"({interval_start}, {interval_start + interval}) -- {' '.join([w[2] for w in words])}\n"
31
  return formatted_output
32
 
33
+ # Modified transcribe_audio function to use Whisper for English
34
  def transcribe_audio(input_file, language, chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
 
35
  target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
36
  target_lang_code = target_lang_options[language]
37
+
 
38
  if target_lang_code == "eng":
39
+ # Use Whisper for English
40
+ model = whisper.load_model("small")
41
+ result = model.transcribe(input_file)
42
+ # NOTE: returns Whisper's plain text as-is; format_and_group_timestamps is not applied on this path
43
+ return result["text"]
44
  else:
45
+ # Use specified model for other languages
46
+ device = "cuda" if torch.cuda.is_available() else "cpu"
47
  model_id = "Sunbird/sunbird-mms"
 
 
 
 
48
 
49
+ auth_token = os.environ.get("HF_TOKEN")
50
+ pipe = pipeline(model=model_id, device=device, token=auth_token)
51
+ pipe.tokenizer.set_target_lang(target_lang_code)
52
+ pipe.model.load_adapter(target_lang_code)
53
+
54
+ output = pipe(input_file, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, return_timestamps=return_timestamps)
55
+ formatted_output = format_and_group_timestamps(output['chunks'])
56
+ return formatted_output
57
 
58
  # Interface setup remains the same
59
  description = '''ASR with salt-mms'''
60
  iface = gr.Interface(fn=transcribe_audio,
61
  inputs=[
62
+ gr.Audio(sources="upload", type="filepath", label="upload file to transcribe"),
63
  gr.Dropdown(choices=list(target_lang_options.keys()), label="Language", value="English")
64
  ],
65
  outputs=gr.Textbox(label="Transcription"),