# nb-wav2vec2-kenlm / prepare.py
# Commit 3ddfd5c by versae: "Adding 5gram models"
#!/usr/bin/env python
# coding=utf-8
import glob
import os
import re
import subprocess

from datasets import load_dataset, interleave_datasets, concatenate_datasets
from tqdm import tqdm
# Dataset column names.
TEXT_COLUMN_NAME = "text"
AUDIO_COLUMN_NAME = "audio"
# Character class of symbols stripped from transcripts: ASCII punctuation,
# typographic quotes/dashes, the replacement character U+FFFD, backslash,
# plus sign, '#', '/' and the ASCII digits 0-9.
# Raw string so every backslash escape reaches the regex engine as written
# instead of being parsed as an (invalid) Python string escape.
CHARS_TO_IGNORE_REGEX = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/0-9]'
# Pre-processing dataset
# Single-character folds applied after lower-casing. Accented vowels lose their
# accent, except ä -> æ and ö -> ø. Non-breaking spaces become plain spaces.
# ö must NOT also appear under "o": folding it to "o" first would make the
# ö -> ø rule unreachable.
_CHAR_FOLD_TABLE = str.maketrans({
    "á": "a", "à": "a", "â": "a",
    "ä": "æ",
    "é": "e", "è": "e", "ë": "e", "ê": "e",
    "í": "i", "ì": "i", "ï": "i", "î": "i",
    "ó": "o", "ò": "o", "ô": "o",
    "ö": "ø",
    "ç": "c",
    "ú": "u", "ù": "u", "ü": "u", "û": "u",
    "\xa0": " ",
})

# Bracketed transcription tags and the plain-text tokens that replace them,
# applied in this order.
_TAG_REPLACEMENTS = (
    ("<ee>", "eee"),
    ("<qq>", "qqq"),
    ("<mm>", "mmm"),
    ("<inaudible>", "xxx"),
)


def replace_hatted_characters(batch):
    """Normalise the ``"text"`` field of a single dataset example.

    The steps are applied in this order:

    1. Delete characters matched by ``CHARS_TO_IGNORE_REGEX`` and lower-case.
    2. Fold accented letters (ä -> æ, ö -> ø, other accents dropped) and
       turn non-breaking spaces into plain spaces.
    3. Expand the tags ``<ee>``, ``<qq>``, ``<mm>`` and ``<inaudible>`` to
       ``eee``, ``qqq``, ``mmm`` and ``xxx``.
    4. Drop any remaining ``<`` / ``>`` and collapse whitespace runs to a
       single space.

    Args:
        batch: One example mapping with a ``"text"`` string (used with
            ``Dataset.map`` in non-batched mode).

    Returns:
        ``{"text": normalised_text}``. The text always ends with one space,
        because a space is appended before whitespace is collapsed.
    """
    text = re.sub(CHARS_TO_IGNORE_REGEX, "", batch["text"]).lower() + " "
    text = text.translate(_CHAR_FOLD_TABLE)
    for tag, token in _TAG_REPLACEMENTS:
        text = text.replace(tag, token)
    text = re.sub("[<>]", "", text)
    text = re.sub(r"\s+", " ", text)
    return {"text": text}
def _keep_text_column(dataset):
    """Return *dataset* with every column except ``TEXT_COLUMN_NAME`` removed."""
    return dataset.remove_columns(
        [name for name in dataset.column_names if name != TEXT_COLUMN_NAME]
    )


def _load_text_dataset():
    """Load the NPSC and NCC train+validation splits as one text-only dataset.

    Both are loaded with ``use_auth_token=True``, i.e. using the locally
    stored Hugging Face token.
    """
    npsc = load_dataset(
        "NbAiLab/NPSC",
        "16K_mp3",
        split="train+validation",
        use_auth_token=True,
    )
    ncc = load_dataset(
        "NbAiLab/NCC",
        split="train+validation",
        use_auth_token=True,
    )
    # The corpora have different schemas (NPSC is loaded from its "16K_mp3"
    # audio config). Depending on the `datasets` version,
    # concatenate_datasets either rejects mismatched features or pads missing
    # columns with nulls, and map() would still carry the audio column along.
    # Only the text is needed, so keep just that column from each.
    return concatenate_datasets([_keep_text_column(npsc), _keep_text_column(ncc)])


def _write_corpus(texts, path):
    """Write *texts* to *path* as a single line, joined by single spaces.

    The separator is written before every text except the first, so none is
    added after the last text.
    """
    with open(path, "w", encoding="utf-8") as text_file:
        for idx, text in enumerate(tqdm(texts, desc="Writing text")):
            if idx:
                text_file.write(" ")
            text_file.write(text)


def _add_end_of_sentence_token(src_path, dst_path):
    """Copy the ARPA file *src_path* to *dst_path*, adding a ``</s>`` unigram.

    The first line containing ``<s>`` is written twice; the second copy has
    ``<s>`` replaced by ``</s>``. The ``ngram 1=`` header count is
    incremented by one to account for the extra entry.
    """
    with open(src_path, "r", encoding="utf-8") as read_file, open(
        dst_path, "w", encoding="utf-8"
    ) as write_file:
        has_added_eos = False
        for line in read_file:
            if not has_added_eos and "ngram 1=" in line:
                # Rewrite only the number after "="; a str.replace of the
                # count could also hit the "1" of the "ngram 1" label.
                prefix, _, count = line.rstrip("\n").rpartition("=")
                write_file.write(f"{prefix}={int(count) + 1}\n")
            elif not has_added_eos and "<s>" in line:
                write_file.write(line)
                write_file.write(line.replace("<s>", "</s>"))
                has_added_eos = True
            else:
                write_file.write(line)


def main():
    """Build the KenLM 5-gram model ``5gram.bin`` from NPSC + NCC text.

    Writes ``text.txt`` and ``5gram.bin`` to the current directory and
    deletes the intermediate ``5gram.arpa*`` files. Expects the KenLM
    ``lmplz`` and ``build_binary`` executables in ``~/bin``.

    Raises:
        subprocess.CalledProcessError: If a KenLM command exits non-zero.
    """
    dataset = _load_text_dataset()
    dataset = dataset.map(
        replace_hatted_characters,
        desc="replacing hesitations and homophones",
    )

    # Create file with all text together
    _write_corpus(dataset[TEXT_COLUMN_NAME], "text.txt")

    # IPython `!` shell escapes are not valid in a plain `python` script, so
    # the KenLM tools are run through subprocess (failing fast on errors).
    workdir = os.getcwd()  # equivalent of $(pwd)
    kenlm_bin = os.path.expanduser("~/bin")

    # Create KenLM model
    subprocess.run(
        [
            os.path.join(kenlm_bin, "lmplz"),
            "-o", "5",
            "--text", "text.txt",
            "--arpa", "5gram.arpa.orig",
            "-T", workdir,
        ],
        check=True,
    )

    # Adjusting for Huggingface decoding
    _add_end_of_sentence_token("5gram.arpa.orig", "5gram.arpa")

    # Compress as binary (options placed before the positional file arguments)
    subprocess.run(
        [
            os.path.join(kenlm_bin, "build_binary"),
            "-T", workdir,
            "5gram.arpa",
            "5gram.bin",
        ],
        check=True,
    )

    # Remove the intermediate ARPA files (equivalent of `rm 5gram.arpa*`)
    for path in glob.glob("5gram.arpa*"):
        os.remove(path)
# Script entry point: run the full pipeline only when executed directly,
# never on import.
if __name__ == "__main__":
    main()