hobs committed on
Commit
cc9cfea
1 Parent(s): ca70154

load state_dict

Browse files
Files changed (1) hide show
  1. app.py +232 -1
app.py CHANGED
@@ -2,9 +2,240 @@
2
 
3
  import gradio as gr
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def greet_nationality(name):
7
- nationality = 'somewhere'
8
  return f"Hello {name}!!\n Your name seems to be from {nationality}. Am I right?"
9
 
10
 
 
2
 
3
  import gradio as gr
4
 
5
+ import os
6
+ from pathlib import Path
7
+ # import random
8
+ # import time
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ import pandas as pd
13
+ from nlpia2.init import SRC_DATA_DIR, maybe_download
14
+
15
+ from nlpia2.string_normalizers import Asciifier, ASCII_NAME_CHARS
16
+
17
+ name_char_vocab_size = len(ASCII_NAME_CHARS) + 1 # Plus EOS marker
18
+
19
+ # Transcode Unicode str to ASCII without embellishments or diacritics (https://stackoverflow.com/a/518232/2809427)
20
+ asciify = Asciifier(include=ASCII_NAME_CHARS)
21
+
22
+
23
def find_files(path, pattern):
    """Lazily yield the entries of directory ``path`` whose names match ``pattern``.

    ``path`` may be a string or a ``Path``; ``pattern`` is a glob expression
    such as ``'*.txt'``. The result is the generator produced by
    ``Path.glob``, so it can be iterated only once.
    """
    directory = Path(path)
    return directory.glob(pattern)
25
+
26
+
27
+ # all_letters = ''.join(set(ASCII_NAME_CHARS).union(set(" .,;'")))
28
+ char2i = {c: i for i, c in enumerate(ASCII_NAME_CHARS)}
29
+
30
+ # !curl -O https://download.pytorch.org/tutorial/data.zip; unzip data.zip
31
+
32
+ print(f'asciify("O’Néàl") => {asciify("O’Néàl")}')
33
+
34
# Containers filled by the loading loops in this module.
category_lines = {}
all_categories = []
labeled_lines = []
categories = []

# Collect one (name, category) pair per line of every names/*.txt file.
for filepath in find_files(SRC_DATA_DIR / 'names', '*.txt'):
    filename = Path(filepath).name
    # Re-resolve the file through maybe_download using its path relative to 'names'.
    filepath = maybe_download(filename=Path('names') / filename)
    with filepath.open() as fin:
        lines = [asciify(line.rstrip()) for line in fin]
    # The category label is the file name without its suffix (kept as a Path object).
    category = Path(filename).with_suffix('')
    categories.append(category)
    labeled_lines.extend((line, category) for line in lines)

n_categories = len(categories)

df = pd.DataFrame(labeled_lines, columns=('name', 'category'))
51
+
52
+
53
def readLines(filename):
    """Read a UTF-8 text file and return its lines passed through ``asciify``.

    Leading and trailing whitespace of the whole file is stripped before it
    is split on newlines, so a trailing newline does not produce an empty
    final entry.

    Args:
        filename: path of the text file to read.

    Returns:
        A list with one ``asciify``-ed string per line.
    """
    # The context manager closes the handle deterministically; the previous
    # bare open(...).read() left that to the garbage collector.
    with open(filename, encoding='utf-8') as fin:
        text = fin.read()
    return [asciify(line) for line in text.strip().split('\n')]
56
+
57
+
58
# Fill category_lines (category -> list of names) and all_categories from
# every *.txt file under data/names; the category is the file's base name
# without its extension.
for filename in find_files(path='data/names', pattern='*.txt'):
    category = Path(filename).stem
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)
65
+
66
+
67
+ ######################################################################
68
+ # Now we have ``category_lines``, a dictionary mapping each category
69
+ # (language) to a list of lines (names). We also kept track of
70
+ # ``all_categories`` (just a list of languages) and ``n_categories`` for
71
+ # later reference.
72
+ #
73
+
74
+ print(category_lines['Italian'][:5])
75
+
76
+
77
+ ######################################################################
78
+ # Turning Names into Tensors
79
+ # --------------------------
80
+ #
81
+ # Now that we have all the names organized, we need to turn them into
82
+ # Tensors to make any use of them.
83
+ #
84
+ # To represent a single letter, we use a "one-hot vector" of size
85
+ # ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
86
+ # at index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
87
+ #
88
+ # To make a word we join a bunch of those into a 2D matrix
89
+ # ``<line_length x 1 x n_letters>``.
90
+ #
91
+ # That extra 1 dimension is because PyTorch assumes everything is in
92
+ # batches - we're just using a batch size of 1 here.
93
+ #
94
+
95
+ # Find letter index from all_letters, e.g. "a" = 0
96
+
97
+
98
def letterToIndex(c):
    """Return the integer index stored for character ``c`` in ``char2i``.

    Raises ``KeyError`` when ``c`` is not a key of ``char2i``.
    """
    index = char2i[c]
    return index
100
+
101
+ # Just for demonstration, turn a letter into a <1 x n_letters> Tensor
102
+
103
+
104
def encode_one_hot_vec(letter):
    """Encode a single character as a ``<1 x len(ASCII_NAME_CHARS)>`` one-hot tensor.

    Every entry is 0 except the column given by ``letterToIndex(letter)``,
    which is set to 1.
    """
    one_hot = torch.zeros(1, len(ASCII_NAME_CHARS))
    one_hot[0, letterToIndex(letter)] = 1
    return one_hot
108
+
109
+ # Turn a line into a <line_length x 1 x n_letters>,
110
+ # or an array of one-hot letter vectors
111
+
112
+
113
def encode_one_hot_seq(line):
    """Encode a string as a ``<len(line) x 1 x len(ASCII_NAME_CHARS)>`` tensor.

    Row ``i`` is the one-hot vector of ``line[i]``; the middle dimension of
    size 1 is the batch dimension.
    """
    n_chars = len(ASCII_NAME_CHARS)
    encoded = torch.zeros(len(line), 1, n_chars)
    for position, char in enumerate(line):
        encoded[position, 0, letterToIndex(char)] = 1
    return encoded
118
+
119
+
120
+ print(encode_one_hot_vec('A'))
121
+
122
+ print(encode_one_hot_seq('Abe').size())
123
+
124
+
125
+ ######################################################################
126
+ # Creating the Network
127
+ # ====================
128
+ #
129
+ # Before autograd, creating a recurrent neural network in Torch involved
130
+ # cloning the parameters of a layer over several timesteps. The layers
131
+ # held hidden state and gradients which are now entirely handled by the
132
+ # graph itself. This means you can implement a RNN in a very "pure" way,
133
+ # as regular feed-forward layers.
134
+ #
135
+ # This RNN module (mostly copied from `the PyTorch for Torch users
136
+ # tutorial <https://pytorch.org/tutorials/beginner/former_torchies/
137
+ # nn_tutorial.html#example-2-recurrent-net>`__)
138
+ # is just 2 linear layers which operate on an input and hidden state, with
139
+ # a LogSoftmax layer after the output.
140
+ #
141
+ # .. figure:: https://i.imgur.com/Z2xbySO.png
142
+ # :alt:
143
+ #
144
+ #
145
+
146
+
147
class RNN(nn.Module):
    """Single-step recurrent cell built from two linear layers.

    Each step concatenates the current input with the previous hidden state
    and feeds that vector to ``i2h`` (next hidden state) and to ``i2o``
    followed by a log-softmax over dimension 1 (output scores).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Both layers read the concatenated [input, hidden] vector.
        combined_size = input_size + hidden_size
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, char_tens, hidden):
        """Run one step; return ``(log-softmax output, next hidden state)``."""
        combined = torch.cat([char_tens, hidden], dim=1)
        next_hidden = self.i2h(combined)
        log_probs = self.softmax(self.i2o(combined))
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero ``<1 x hidden_size>`` initial hidden state."""
        return torch.zeros(1, self.hidden_size)
166
+
167
+
168
+ n_hidden = 128
169
+ rnn = RNN(len(ASCII_NAME_CHARS), n_hidden, n_categories)
170
+
171
+
172
+ input = encode_one_hot_vec('A')
173
+ hidden = torch.zeros(1, n_hidden)
174
+
175
+ output, next_hidden = rnn(input, hidden)
176
+
177
+
178
def categoryFromOutput(output):
    """Map a network output to its highest-scoring category.

    Returns a ``(category, index)`` tuple where ``index`` is the position of
    the largest entry reported by ``output.topk(1)`` and ``category`` is
    ``all_categories[index]``.
    """
    _, best_index = output.topk(1)
    category_i = best_index[0].item()
    return all_categories[category_i], category_i
182
+
183
+
184
def output_from_str(s):
    """Run one RNN step on the first character of ``s`` and classify the result.

    The string is one-hot encoded, but only its first character is fed to the
    module-level ``rnn`` together with an all-zero hidden state. The raw
    output is printed, then handed to ``categoryFromOutput``.
    """
    one_hot_seq = encode_one_hot_seq(s)
    initial_hidden = torch.zeros(1, n_hidden)

    # Single step only: the remaining characters of the sequence are not used.
    output, _ = rnn(one_hot_seq[0], initial_hidden)
    print(output)

    return categoryFromOutput(output)
194
+
195
+
196
+ ########################################
197
+ # load/save test for use on the huggingface spaces server
198
+ # torch.save(rnn.state_dict(), 'rnn_from_scratch_name_nationality.state_dict.pickle')
199
+
200
# Map every stored tensor onto the CPU while unpickling, so a checkpoint that
# was saved from a GPU session still loads on a host without CUDA.
state_dict = torch.load(
    'rnn_from_scratch_name_nationality.state_dict.pickle',
    map_location=torch.device('cpu'),
)
rnn.load_state_dict(state_dict)
202
+
203
+
204
def evaluate(line_tensor):
    """Feed ``line_tensor`` through the module-level ``rnn`` one step at a time.

    The hidden state starts from ``rnn.initHidden()`` and is carried from
    step to step along dimension 0 of ``line_tensor``. The output of the last
    step is returned. A tensor with zero steps never assigns ``output`` and
    therefore fails with ``UnboundLocalError``.
    """
    hidden = rnn.initHidden()

    n_steps = line_tensor.size(0)
    for step in range(n_steps):
        output, hidden = rnn(line_tensor[step], hidden)

    return output
211
+
212
+
213
def predict(input_line, n_predictions=3):
    """Print and return the top category predictions for ``input_line``.

    Args:
        input_line: the name to classify; it is one-hot encoded with
            ``encode_one_hot_seq`` and run through ``evaluate``.
        n_predictions: maximum number of categories to report. It is capped
            at the number of scores the network outputs, because ``topk``
            raises when asked for more entries than exist.

    Returns:
        A list of ``[value, category]`` pairs ordered from highest to lowest
        score, where ``value`` is the network's output score for that
        category.
    """
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(encode_one_hot_seq(input_line))

        # Get top N categories, never asking for more than are available.
        k = min(n_predictions, output.size(1))
        topv, topi = output.topk(k, 1, True)

    predictions = []
    for value, category_index in zip(topv[0].tolist(), topi[0].tolist()):
        category = all_categories[category_index]
        print('(%.2f) %s' % (value, category))
        predictions.append([value, category])

    # Previously the list was built and then dropped, so callers received None.
    return predictions
227
+
228
+
229
+ predict('Dovesky')
230
+ predict('Jackson')
231
+ predict('Satoshi')
232
+
233
+ # load/save test for use on the huggingface spaces server
234
+ ########################################
235
+
236
 
237
  def greet_nationality(name):
238
+ nationality = predict(name)
239
  return f"Hello {name}!!\n Your name seems to be from {nationality}. Am I right?"
240
 
241