|
|
|
|
|
import gradio as gr |
|
|
|
import json |
|
from pathlib import Path |
|
|
|
|
|
import string |
|
|
|
import torch |
|
import torch.nn as nn |
|
import unicodedata |
|
from unidecode import unidecode |
|
|
|
|
|
ASCII_LETTERS = string.ascii_letters |
|
ASCII_PRINTABLE = string.printable  # digits + letters + punctuation + ' \t\n\r\x0b\x0c'

ASCII_PRINTABLE_COMMON = ASCII_PRINTABLE[:-2]  # printable minus \x0b (vertical tab) and \x0c (page break)
|
|
|
ASCII_VERTICAL_TAB = '\x0b' |
|
ASCII_PAGE_BREAK = '\x0c' |
|
ASCII_ALL = ''.join(chr(i) for i in range(0, 128)) |
|
ASCII_DIGITS = string.digits |
|
ASCII_IMPORTANT_PUNCTUATION = " .?!,;'-=+)(:" |
|
ASCII_NAME_PUNCTUATION = " .,;'-" |
|
ASCII_NAME_CHARS = set(ASCII_LETTERS + ASCII_NAME_PUNCTUATION) |
|
ASCII_IMPORTANT_CHARS = set(ASCII_LETTERS + ASCII_IMPORTANT_PUNCTUATION) |
|
|
|
CURLY_SINGLE_QUOTES = '‘’`´' |
|
STRAIGHT_SINGLE_QUOTES = "'" * len(CURLY_SINGLE_QUOTES) |
|
CURLY_DOUBLE_QUOTES = '“”' |
|
STRAIGHT_DOUBLE_QUOTES = '"' * len(CURLY_DOUBLE_QUOTES) |
|
|
|
|
|
def normalize_newlines(s):
    """ Convert vertical tabs to newlines and page breaks to double newlines """
    s = s.replace(ASCII_VERTICAL_TAB, '\n')
    s = s.replace(ASCII_PAGE_BREAK, '\n\n')
    return s
|
|
|
|
|
class Asciifier: |
|
""" Construct a function that filters out all non-ascii unicode characters |
|
|
|
>>> test_str = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c' |
|
>>> Asciifier(include='a b c 123XYZ')(test_str): |
|
'123abcXYZ ' |
|
""" |
|
|
|
def __init__( |
|
self, |
|
min_ord=1, max_ord=128, |
|
exclude=None, |
|
include=ASCII_PRINTABLE, |
|
exclude_category='Mn', |
|
normalize_quotes=True, |
|
): |
|
        self.include = set(include or ASCII_PRINTABLE)
        self.exclude = set(exclude or [])
        self._exclude = ''.join(sorted(self.exclude))
        self.min_ord, self.max_ord = int(min_ord), int(max_ord or 128)
        self.normalize_quotes = normalize_quotes

        if self.min_ord:
            self.include = set(c for c in self.include if ord(c) >= self.min_ord)
        if self.max_ord:
            self.include = set(c for c in self.include if ord(c) <= self.max_ord)
        if exclude_category:
            self.include = set(
                c for c in self.include if unicodedata.category(c) != exclude_category)
        self._include = ''.join(sorted(self.include))
|
|
|
self.vocab = sorted(self.include - self.exclude) |
|
self._vocab = ''.join(self.vocab) |
|
self.char2i = {c: i for (i, c) in enumerate(self._vocab)} |
|
|
|
self._translate_from = self._vocab |
|
self._translate_to = self._translate_from |
|
|
|
|
|
|
|
if self.normalize_quotes: |
|
trans_table = str.maketrans( |
|
CURLY_SINGLE_QUOTES + CURLY_DOUBLE_QUOTES, |
|
STRAIGHT_SINGLE_QUOTES + STRAIGHT_DOUBLE_QUOTES) |
|
self._translate_to = self._translate_to.translate(trans_table) |
|
|
|
|
|
|
|
self._translate_from_filtered = '' |
|
self._translate_to_filtered = '' |
|
|
|
        # Keep only the character pairs that actually change under translation.
        for c1, c2 in zip(self._translate_from, self._translate_to):
            if c1 != c2:
                self._translate_from_filtered += c1
                self._translate_to_filtered += c2
|
|
|
        # Delete every ASCII character that is not part of the vocabulary.
        vocab_set = set(self.vocab)
        self._translate_del = ''.join(c for c in ASCII_ALL if c not in vocab_set)
|
|
|
self._translate_from = self._translate_from_filtered |
|
self._translate_to = self._translate_to_filtered |
|
self.translation_table = str.maketrans( |
|
self._translate_from, |
|
self._translate_to, |
|
self._translate_del) |
|
|
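    # NFD decomposition splits accented characters into base + combining marks;
    # unidecode then transliterates to plain ASCII, and the translation table
    # keeps only characters in the vocabulary.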
|
def __call__(self, text): |
|
return unidecode(unicodedata.normalize('NFD', text)).translate(self.translation_table) |
|
|
|
|
|
name_char_vocab_size = len(ASCII_NAME_CHARS) + 1 |
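# (The +1 would leave room for an extra token such as end-of-name; the
# classifier below uses len(ASCII_NAME_CHARS) directly.)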
|
|
|
|
|
asciify = Asciifier(include=ASCII_NAME_CHARS) |
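# e.g. asciify("O’Néàl") -> "O'Neal": transliterate to ASCII, then drop
# anything outside the name vocabulary.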
|
|
|
|
|
def find_files(path, pattern): |
|
return Path(path).glob(pattern) |
|
|
|
|
|
|
|
# Sorted so the char -> index mapping is deterministic across interpreter runs
# (ASCII_NAME_CHARS is a set); it must match the mapping used at training time.
char2i = {c: i for i, c in enumerate(sorted(ASCII_NAME_CHARS))}
|
|
|
|
|
|
|
print(f'asciify("O’Néàl") => {asciify("O’Néàl")}') |
|
|
|
|
|
# categories.json holds the ordered list of category labels; the order must
# match the output indices the model was trained with.
with open('categories.json') as f:
    categories = json.load(f)
n_categories = len(categories)
|
|
|
|
def letterToIndex(c): |
|
return char2i[c] |
|
|
|
|
|
|
|
|
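# A single character becomes a 1 x vocab-size one-hot row vector.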
|
def encode_one_hot_vec(letter): |
|
tensor = torch.zeros(1, len(ASCII_NAME_CHARS)) |
|
tensor[0][letterToIndex(letter)] = 1 |
|
return tensor |
|
|
|
|
|
|
|
|
|
|
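# A whole name becomes a (sequence_length, 1, vocab_size) tensor: one one-hot
# row per character, with a batch dimension of 1 in the middle.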
|
def encode_one_hot_seq(line): |
|
tensor = torch.zeros(len(line), 1, len(ASCII_NAME_CHARS)) |
|
for li, letter in enumerate(line): |
|
tensor[li][0][letterToIndex(letter)] = 1 |
|
return tensor |
|
|
|
|
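# Minimal Elman-style RNN cell: each step concatenates the current character's
# one-hot vector with the previous hidden state, then projects to the next
# hidden state (i2h) and to log-softmax category scores (i2o).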
|
class RNN(nn.Module): |
|
def __init__(self, input_size, hidden_size, output_size): |
|
super(RNN, self).__init__() |
|
|
|
self.hidden_size = hidden_size |
|
|
|
self.i2h = nn.Linear(input_size + hidden_size, hidden_size) |
|
self.i2o = nn.Linear(input_size + hidden_size, output_size) |
|
self.softmax = nn.LogSoftmax(dim=1) |
|
|
|
def forward(self, char_tens, hidden): |
|
combined = torch.cat((char_tens, hidden), 1) |
|
hidden = self.i2h(combined) |
|
output = self.i2o(combined) |
|
output = self.softmax(output) |
|
return output, hidden |
|
|
|
def initHidden(self): |
|
return torch.zeros(1, self.hidden_size) |
|
|
|
|
|
n_hidden = 128 |
|
rnn = RNN(len(ASCII_NAME_CHARS), n_hidden, n_categories) |
|
|
|
|
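# Smoke test: one forward step on a single character with a zero hidden state.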
|
input_tensor = encode_one_hot_vec('A')
hidden = torch.zeros(1, n_hidden)

output, next_hidden = rnn(input_tensor, hidden)
|
|
|
|
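# Map the 1 x n_categories output row back to the most likely category label.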
|
def categoryFromOutput(output): |
|
top_n, top_i = output.topk(1) |
|
category_i = top_i[0].item() |
|
return categories[category_i], category_i |
|
|
|
|
|
def output_from_str(s):
    input_tensor = encode_one_hot_seq(s)
    hidden = torch.zeros(1, n_hidden)

    # Single forward step on the first character only.
    output, next_hidden = rnn(input_tensor[0], hidden)

    return categoryFromOutput(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Restore the trained weights; the RNN architecture above must match the checkpoint.
state_dict = torch.load('rnn_from_scratch_name_nationality.state_dict.pickle')
rnn.load_state_dict(state_dict)
|
|
|
|
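# Feed the name through the RNN one character at a time; the output after the
# final character is the prediction for the whole name.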
|
def evaluate(line_tensor): |
|
hidden = rnn.initHidden() |
|
|
|
for i in range(line_tensor.size()[0]): |
|
output, hidden = rnn(line_tensor[i], hidden) |
|
|
|
return output |
|
|
|
|
|
def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(encode_one_hot_seq(input_line))

        # Collect the top N categories and their log-probability scores.
        topv, topi = output.topk(n_predictions, 1, True)
        predictions = []

        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, categories[category_index]))
            predictions.append([value, categories[category_index]])

    return predictions
|
|
|
|
|
predict('Dovesky') |
|
predict('Jackson') |
|
predict('Satoshi') |
|
|
|
|
|
|
|
|
|
|
|
def greet_nationality(name):
    predictions = predict(name)
    nationality = predictions[0][1]  # label of the top-ranked category
    return f"Hello {name}!!\nYour name seems to be from {nationality}. Am I right?"
|
|
|
|
|
iface = gr.Interface(fn=greet_nationality, inputs="text", outputs="text") |
|
iface.launch() |
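# launch() serves the app locally; pass share=True for a temporary public URL.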
|
|