File size: 2,031 Bytes
7838411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import pdb
from functools import lru_cache
from pathlib import Path

import jiwer
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
import yaml
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC


_DEFAULT_CTC_MODEL = "facebook/wav2vec2-base-960h"


@lru_cache(maxsize=4)
def _load_ctc_components(model_name):
    """Load and memoize the CTC model, tokenizer and feature extractor for *model_name*.

    ``from_pretrained`` is expensive (disk/network I/O plus weight
    deserialisation), so repeated calls with the same checkpoint id reuse the
    objects loaded the first time instead of reloading them on every call.
    """
    model = AutoModelForCTC.from_pretrained(model_name)
    model.eval()  # inference only: make sure dropout etc. are disabled
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    return model, tokenizer, feature_extractor


def TOKENLIZER(audio_path, model_name=_DEFAULT_CTC_MODEL):
    """Transcribe an audio file with a CTC speech model and return word timestamps.

    Args:
        audio_path: Path (``str`` or ``pathlib.Path``) to an audio file that
            ``torchaudio.load`` can read.
        model_name: Hugging Face checkpoint id of a CTC model. Defaults to
            ``"facebook/wav2vec2-base-960h"``.

    Returns:
        A list with one dict per decoded word, each having the keys
        ``"word"``, ``"start_time"`` and ``"end_time"``; times are in seconds,
        rounded to 2 decimals.
    """
    token_model, tokenizer, feature_extractor = _load_ctc_components(model_name)
    target_sr = feature_extractor.sampling_rate

    # torchaudio.load returns a (channels, frames) tensor and its sample rate.
    waveform, sr = torchaudio.load(audio_path)
    # Downmix to mono. Passing the channels as separate batch items wasted
    # compute, and since only logits[0] was read, all but the first channel
    # were silently ignored.
    waveform = waveform.mean(dim=0)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # Apply the checkpoint's own preprocessing (e.g. normalisation when its
    # config enables `do_normalize`) instead of feeding raw samples.
    input_values = feature_extractor(
        waveform.numpy(), sampling_rate=target_sr, return_tensors="pt"
    ).input_values

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = token_model(input_values).logits[0]
    # Greedy CTC decoding: most likely token per logit frame.
    pred_ids = torch.argmax(logits, dim=-1)

    # Word-level offsets (in logit frames) alongside the decoded text.
    outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
    # Seconds per logit frame = input samples per frame / samples per second.
    time_offset = token_model.config.inputs_to_logits_ratio / target_sr

    return [
        {
            "word": entry["word"],
            "start_time": round(entry["start_offset"] * time_offset, 2),
            "end_time": round(entry["end_offset"] * time_offset, 2),
        }
        for entry in outputs.word_offsets
    ]