File size: 4,033 Bytes
69e8a46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os

# Redirect ModelScope's download cache into the project-local .cache/ dir.
# NOTE(review): set before the other imports — presumably read at import time
# by a modelscope-backed dependency; confirm before moving.
os.environ["MODELSCOPE_CACHE"] = ".cache/"

import string
import time
from threading import Lock

import librosa
import numpy as np
import opencc
import torch
from faster_whisper import WhisperModel

# Traditional -> Simplified Chinese converter, applied to every transcript
# produced by batch_asr_internal.
t2s_converter = opencc.OpenCC("t2s")


def load_model(*, device="cuda", model_size="medium", compute_type="float16"):
    """Load a faster-whisper model and return it.

    Args:
        device: Device string passed to WhisperModel (e.g. "cuda", "cpu").
        model_size: Whisper checkpoint name. Defaults to "medium", which was
            previously hard-coded.
        compute_type: Inference precision. Defaults to "float16", which was
            previously hard-coded; use "int8" or "float32" when running on CPU.

    Returns:
        The loaded WhisperModel instance (weights cached under
        "faster_whisper/").
    """
    model = WhisperModel(
        model_size,
        device=device,
        compute_type=compute_type,
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return model


@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe a batch of audio clips with faster-whisper.

    Args:
        model: A loaded WhisperModel.
        audios: Sequence of waveforms (np.ndarray or torch.Tensor) sampled at
            *sr*; singleton extra dims are squeezed, anything else is rejected.
        sr: Sample rate shared by every clip in *audios*.

    Returns:
        One dict per clip with keys:
            "text": transcript converted to Simplified Chinese,
            "duration": clip length in milliseconds,
            "huge_gap": True when two consecutive segments are more than 3 s
                apart (accumulation stops at that point, truncating the text).

    Raises:
        ValueError: if a clip is not 1-D after squeezing.
    """
    resampled_audios = []
    for audio in audios:
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()

        if audio.dim() > 1:
            audio = audio.squeeze()

        if audio.dim() != 1:
            # Fix: explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(f"expected 1-D audio after squeeze, got {audio.dim()}-D")

        # Whisper models consume 16 kHz input.
        resampled_audios.append(
            librosa.resample(audio.numpy(), orig_sr=sr, target_sr=16000)
        )

    trans_results = []
    for resampled_audio in resampled_audios:
        segments, info = model.transcribe(
            resampled_audio,
            language=None,  # let the model auto-detect the language
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))

    results = []
    for trans_res, audio in zip(trans_results, audios):
        duration = len(audio) / sr * 1000  # milliseconds

        huge_gap = False
        max_gap = 0.0

        # Fix: start from "" instead of None so an empty transcription no
        # longer crashes t2s_converter.convert(None) below.
        text = ""
        last_tr = None

        for tr in trans_res:
            delta = tr.text.strip()
            # Fix: gate the gap check on last_tr rather than tr.id, so the
            # first segment is handled safely even if ids do not start at 1.
            if last_tr is not None:
                max_gap = max(tr.start - last_tr.end, max_gap)
            text += delta
            last_tr = tr

            if max_gap > 3.0:
                huge_gap = True
                break

        sim_text = t2s_converter.convert(text)
        results.append(
            {
                "text": sim_text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results


# Module-level lock, presumably intended to serialize concurrent access to
# the shared ASR model — NOTE(review): it is not acquired anywhere in this
# file's visible code; confirm intended usage.
global_lock = Lock()


def batch_asr(model, audios, sr):
    """Thread-safe public entry point for batch transcription.

    Fix: acquire ``global_lock`` (previously declared but never used) so
    concurrent callers do not interleave access to the shared model.
    Arguments and return value are forwarded unchanged to
    ``batch_asr_internal``.
    """
    with global_lock:
        return batch_asr_internal(model, audios, sr)


def is_chinese(text):
    """Return True unconditionally.

    Placeholder: *text* is ignored and every input is treated as Chinese.
    """
    return True


def calculate_wer(text1, text2, debug=False):
    """Return the character error rate between *text1* (reference) and *text2*.

    Punctuation (ASCII + CJK) and whitespace are stripped first, then the
    Levenshtein edit distance is computed with a two-row DP over the shorter
    string (O(min(m, n)) memory). The rate is edits / max(len1, len2).

    Args:
        text1: Reference (ground-truth) text.
        text2: Hypothesis (predicted) text.
        debug: When True, print the cleaned strings and the edit count.

    Returns:
        A float in [0, 1]; 0.0 when both strings are empty after cleaning.
    """
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    # Fix: previously raised ZeroDivisionError when both cleaned strings
    # were empty.
    tot = max(len(chars1), len(chars2))
    if tot == 0:
        return 0.0

    # Keep the originals for debug output; the DP below may swap them.
    ref, hyp = chars1, chars2

    m, n = len(chars1), len(chars2)
    if m > n:
        # Roll the DP over the shorter string to minimize memory.
        chars1, chars2 = chars2, chars1
        m, n = n, m

    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)

    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if chars1[i - 1] == chars2[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev

    edits = prev[m]
    wer = edits / tot

    if debug:
        # Fix: print the original (unswapped) strings so the gt/pred labels
        # stay correct regardless of which input was longer.
        print("            gt:   ", ref)
        print("          pred:   ", hyp)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)

    return wer


def remove_punctuation(text):
    """Delete ASCII punctuation, CJK punctuation, and whitespace from *text*."""
    cjk_punctuation = (
        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        '‛""„‟…‧﹏'
    )
    # Map every listed character to None so translate() drops it in one pass.
    delete_table = str.maketrans(dict.fromkeys(string.punctuation + cjk_punctuation))
    return text.translate(delete_table)


if __name__ == "__main__":
    # Manual smoke test: transcribe two local wav files once, then time ten
    # repeated batches through the same model.
    asr_model = load_model()
    sample_rate = 44100
    batch = [
        librosa.load(path, sr=sample_rate)[0]
        for path in ("44100.wav", "lengyue.wav")
    ]
    print(np.array(batch[0]))
    print(batch_asr(asr_model, batch, sample_rate))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(asr_model, batch, sample_rate))
    print("Time taken:", time.time() - start_time)