Vineel Pratap
resampling fix
f138a14
# Creates unigram LM following KenLM
import math
import shutil, tempfile
def calculate_log_probabilities(word_counts, num_sentences, n_smoothing=0.01):
    """Return base-10 log probabilities for an additively smoothed unigram model.

    Each corpus word gets a mass of ``count + n_smoothing``. The sentence
    markers ``<s>`` and ``</s>`` each get ``num_sentences + n_smoothing``, and
    the unknown-word token ``<unk>`` gets a fixed mass of 1. Every mass is
    divided by one shared denominator: the token total (all word counts plus
    one ``<s>`` and one ``</s>`` per sentence), plus 1 for ``<unk>``, then
    scaled by ``(1 + n_smoothing)``.

    NOTE: the numerators are not constrained to add up to the denominator,
    so the resulting distribution is generally not exactly normalized.

    Args:
        word_counts: mapping from word to its corpus count.
        num_sentences: number of sentences in the corpus.
        n_smoothing: additive smoothing constant.

    Returns:
        dict mapping every word of ``word_counts``, followed by ``<unk>``,
        ``<s>`` and ``</s>``, to ``log10`` of its probability.
    """
    # Two boundary tokens (<s>, </s>) per sentence, plus one slot for <unk>.
    denom = sum(word_counts.values()) + 2 * num_sentences + 1
    denom = denom + denom * n_smoothing

    masses = {word: count + n_smoothing for word, count in word_counts.items()}
    masses["<unk>"] = 1
    boundary_mass = num_sentences + n_smoothing
    masses["<s>"] = boundary_mass
    masses["</s>"] = boundary_mass

    return {word: math.log10(mass / denom) for word, mass in masses.items()}
def maybe_generate_pseudo_bigram_arpa(arpa_fpath):
    """Add one placeholder bigram to a unigram-only ARPA file, in place.

    If the file already contains a 2-gram section it is left untouched.
    Otherwise:
      * an ``ngram 2=1`` header line is inserted right after the
        ``ngram 1=`` line, and
      * a 2-gram section holding the single entry ``</s> <s>`` with log
        probability -9.9999999 is inserted just before the end marker line.

    NOTE(review): why an order-2 model is needed is not visible here;
    presumably a downstream consumer rejects unigram-only ARPA files -- confirm.

    Args:
        arpa_fpath: path of the ARPA file to inspect and possibly rewrite.
    """
    # ARPA vocabularies may be non-ASCII; do not depend on the locale encoding.
    with open(arpa_fpath, "r", encoding="utf-8") as file:
        lines = file.readlines()

    # If the ngram order is already >= 2, do not modify. Match the section
    # header exactly: a substring test would be fooled by a vocabulary word
    # that happens to contain "2-grams:".
    if any(line.strip() == "\\2-grams:" for line in lines):
        return

    with open(arpa_fpath, "w", encoding="utf-8") as file:
        for line in lines:
            stripped = line.strip()
            if stripped == "\\end\\":
                # The new 2-gram section must come before the end marker.
                file.write("\\2-grams:\n")
                file.write("-9.9999999\t</s> <s>\n\n")
            file.write(line)
            if stripped.startswith("ngram 1="):
                # Declare the added bigram count in the data header.
                file.write("ngram 2=1\n")
def save_log_probabilities(log_probabilities, file_path):
    """Write a unigram-only language model to ``file_path`` in ARPA format.

    The output consists of the data header with the unigram count, the
    1-gram section with one ``<log10 prob><TAB><word>`` line per entry (in
    the mapping's iteration order), and the end marker. The ``<s>`` entry is
    always written with a log probability of 0, whatever value the mapping
    holds for it.

    Args:
        log_probabilities: mapping from word to its base-10 log probability,
            such as the result of ``calculate_log_probabilities``.
        file_path: destination path; an existing file is overwritten.
    """
    # Escaped backslashes: "\d" / "\e" are invalid escape sequences and
    # produce a SyntaxWarning on Python 3.12+.
    lines = ["\\data\\", f"ngram 1={len(log_probabilities)}", "", "\\1-grams:"]
    for word, log_prob in log_probabilities.items():
        if word == "<s>":
            # The start marker is written with log probability 0,
            # overriding the computed value.
            log_prob = 0
        lines.append(f"{log_prob}\t{word}")
    lines.extend(["", "\\end\\"])

    # Explicit UTF-8 so non-ASCII words do not depend on the platform locale;
    # end with a newline so the output is a well-formed text file.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("\n".join(lines) + "\n")
def create_unigram_lm(word_counts, num_sentences, file_path, n_smoothing=0.01):
    """Build a smoothed unigram LM from corpus counts and save it as ARPA.

    Args:
        word_counts: mapping from word to its corpus count.
        num_sentences: number of sentences in the corpus.
        file_path: where the ARPA file is written.
        n_smoothing: additive smoothing constant passed through to
            ``calculate_log_probabilities``.
    """
    save_log_probabilities(
        calculate_log_probabilities(word_counts, num_sentences, n_smoothing),
        file_path,
    )