hobs committed
Commit 0647bb4
1 Parent(s): 5c11d69

load categories from json

Files changed (2)
  1. app.py +108 -74
  2. categories.json +1 -0
app.py CHANGED
@@ -2,17 +2,117 @@
 
 import gradio as gr
 
-import os
+import json
 from pathlib import Path
 # import random
 # import time
 import torch
 import torch.nn as nn
 
-import pandas as pd
-from nlpia2.init import SRC_DATA_DIR, maybe_download
 
-from nlpia2.string_normalizers import Asciifier, ASCII_NAME_CHARS
+import string
+import unicodedata
+from unidecode import unidecode
+
+
+ASCII_LETTERS = string.ascii_letters
+ASCII_PRINTABLE = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'
+ASCII_PRINTABLE_COMMON = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r'
+
+ASCII_VERTICAL_TAB = '\x0b'
+ASCII_PAGE_BREAK = '\x0c'
+ASCII_ALL = ''.join(chr(i) for i in range(0, 128))  # ASCII_PRINTABLE
+ASCII_DIGITS = string.digits
+ASCII_IMPORTANT_PUNCTUATION = " .?!,;'-=+)(:"
+ASCII_NAME_PUNCTUATION = " .,;'-"
+ASCII_NAME_CHARS = set(ASCII_LETTERS + ASCII_NAME_PUNCTUATION)
+ASCII_IMPORTANT_CHARS = set(ASCII_LETTERS + ASCII_IMPORTANT_PUNCTUATION)
+
+CURLY_SINGLE_QUOTES = '‘’`´'
+STRAIGHT_SINGLE_QUOTES = "'" * len(CURLY_SINGLE_QUOTES)
+CURLY_DOUBLE_QUOTES = '“”'
+STRAIGHT_DOUBLE_QUOTES = '"' * len(CURLY_DOUBLE_QUOTES)
+
+
+def normalize_newlines(s):
+    s = s.replace(ASCII_VERTICAL_TAB, '\n')
+    s = s.replace(ASCII_PAGE_BREAK, '\n\n')
+    return s
+
+
+class Asciifier:
+    """ Construct a function that filters out all non-ascii unicode characters
+
+    >>> test_str = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'
+    >>> Asciifier(include='a b c 123XYZ')(test_str)
+    '123abcXYZ '
+    """
+
+    def __init__(
+            self,
+            min_ord=1, max_ord=128,
+            exclude=None,
+            include=ASCII_PRINTABLE,
+            exclude_category='Mn',
+            normalize_quotes=True,
+    ):
+        self.include = set(sorted(include or ASCII_PRINTABLE))
+        self._include = ''.join(sorted(self.include))
+        self.exclude = set(sorted(exclude or []))
+        self._exclude = ''.join(self.exclude)
+        self.min_ord, self.max_ord = int(min_ord), int(max_ord or 128)
+        self.normalize_quotes = normalize_quotes
+
+        if self.min_ord:
+            self.include = set(c for c in self.include if ord(c) >= self.min_ord)
+        if self.max_ord:
+            self.include = set(c for c in self.include if ord(c) <= self.max_ord)
+        if exclude_category:
+            self.include = set(
+                c for c in self.include if unicodedata.category(c) != exclude_category)
+
+        self.vocab = sorted(self.include - self.exclude)
+        self._vocab = ''.join(self.vocab)
+        self.char2i = {c: i for (i, c) in enumerate(self._vocab)}
+
+        self._translate_from = self._vocab
+        self._translate_to = self._translate_from
+
+        # FIXME: self.normalize_quotes is accomplished by unidecode.unidecode!!
+        #        ’->'  ‘->'  “->"  ”->"
+        if self.normalize_quotes:
+            trans_table = str.maketrans(
+                CURLY_SINGLE_QUOTES + CURLY_DOUBLE_QUOTES,
+                STRAIGHT_SINGLE_QUOTES + STRAIGHT_DOUBLE_QUOTES)
+            self._translate_to = self._translate_to.translate(trans_table)
+            # print(self._translate_to)
+
+        # eliminate any non-translations (if from == to)
+        self._translate_from_filtered = ''
+        self._translate_to_filtered = ''
+
+        for c1, c2 in zip(self._translate_from, self._translate_to):
+            if c1 == c2:
+                continue
+            else:
+                self._translate_from_filtered += c1
+                self._translate_to_filtered += c2
+
+        self._translate_del = ''
+        for c in ASCII_ALL:
+            if c not in self.vocab:
+                self._translate_del += c
+
+        self._translate_from = self._translate_from_filtered
+        self._translate_to = self._translate_to_filtered
+        self.translation_table = str.maketrans(
+            self._translate_from,
+            self._translate_to,
+            self._translate_del)
+
+    def __call__(self, text):
+        return unidecode(unicodedata.normalize('NFD', text)).translate(self.translation_table)
+
 
 name_char_vocab_size = len(ASCII_NAME_CHARS) + 1  # Plus EOS marker
 
@@ -31,49 +131,10 @@ char2i = {c: i for i, c in enumerate(ASCII_NAME_CHARS)}
 
 print(f'asciify("O’Néàl") => {asciify("O’Néàl")}')
 
-# Build the category_lines dictionary, a list of names per language
-category_lines = {}
-all_categories = []
-labeled_lines = []
-categories = []
-for filepath in find_files(SRC_DATA_DIR / 'names', '*.txt'):
-    filename = Path(filepath).name
-    filepath = maybe_download(filename=Path('names') / filename)
-    with filepath.open() as fin:
-        lines = [asciify(line.rstrip()) for line in fin]
-    category = Path(filename).with_suffix('')
-    categories.append(category)
-    labeled_lines += list(zip(lines, [category] * len(lines)))
 
+categories = json.load(open('categories.json'))
 n_categories = len(categories)
 
-df = pd.DataFrame(labeled_lines, columns=('name', 'category'))
-
-
-def readLines(filename):
-    lines = open(filename, encoding='utf-8').read().strip().split('\n')
-    return [asciify(line) for line in lines]
-
-
-for filename in find_files(path='data/names', pattern='*.txt'):
-    category = os.path.splitext(os.path.basename(filename))[0]
-    all_categories.append(category)
-    lines = readLines(filename)
-    category_lines[category] = lines
-
-n_categories = len(all_categories)
-
-
-######################################################################
-# Now we have ``category_lines``, a dictionary mapping each category
-# (language) to a list of lines (names). We also kept track of
-# ``all_categories`` (just a list of languages) and ``n_categories`` for
-# later reference.
-#
-
-print(category_lines['Italian'][:5])
-
-
 ######################################################################
 # Turning Names into Tensors
 # --------------------------
@@ -117,33 +178,6 @@ def encode_one_hot_seq(line):
     return tensor
 
 
-print(encode_one_hot_vec('A'))
-
-print(encode_one_hot_seq('Abe').size())
-
-
-######################################################################
-# Creating the Network
-# ====================
-#
-# Before autograd, creating a recurrent neural network in Torch involved
-# cloning the parameters of a layer over several timesteps. The layers
-# held hidden state and gradients which are now entirely handled by the
-# graph itself. This means you can implement a RNN in a very "pure" way,
-# as regular feed-forward layers.
-#
-# This RNN module (mostly copied from `the PyTorch for Torch users
-# tutorial <https://pytorch.org/tutorials/beginner/former_torchies/
-# nn_tutorial.html#example-2-recurrent-net>`__)
-# is just 2 linear layers which operate on an input and hidden state, with
-# a LogSoftmax layer after the output.
-#
-# .. figure:: https://i.imgur.com/Z2xbySO.png
-#    :alt:
-#
-#
-
-
 class RNN(nn.Module):
     def __init__(self, input_size, hidden_size, output_size):
         super(RNN, self).__init__()
@@ -178,7 +212,7 @@ output, next_hidden = rnn(input, hidden)
 def categoryFromOutput(output):
     top_n, top_i = output.topk(1)
     category_i = top_i[0].item()
-    return all_categories[category_i], category_i
+    return categories[category_i], category_i
 
 
 def output_from_str(s):
@@ -222,8 +256,8 @@ def predict(input_line, n_predictions=3):
     for i in range(n_predictions):
         value = topv[0][i].item()
         category_index = topi[0][i].item()
-        print('(%.2f) %s' % (value, all_categories[category_index]))
-        predictions.append([value, all_categories[category_index]])
+        print('(%.2f) %s' % (value, categories[category_index]))
+        predictions.append([value, categories[category_index]])
 
 
 predict('Dovesky')
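The point of the commit: the category labels that were previously rebuilt at startup by scanning names/*.txt (which pulled in pandas plus the nlpia2 download helpers) are now read from a one-line JSON file. A minimal sketch of the same load with an explicit context manager, since the committed json.load(open('categories.json')) leaves closing the file handle to the garbage collector:

    import json

    with open('categories.json') as fin:
        categories = json.load(fin)

    n_categories = len(categories)  # 18

One caveat: categoryFromOutput() and predict() index straight into this list with the model's output index, so categories.json must preserve exactly the category order the checkpoint was trained on.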
categories.json ADDED
@@ -0,0 +1 @@
+["Arabic", "Irish", "Spanish", "French", "German", "English", "Korean", "Vietnamese", "Scottish", "Japanese", "Polish", "Greek", "Czech", "Italian", "Portuguese", "Russian", "Dutch", "Chinese"]