papasega commited on
Commit
4a7b642
1 Parent(s): b79b2d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -1,21 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
- from modelutil import create_model
 
 
4
 
 
 
5
 
6
  def predict_digit(image):
7
- try:
8
- if image == None: pass
9
- except:
10
- model = create_model()
11
- predictions = model.predict(image.reshape(1, 28, 28))
12
- return np.argmax(predictions)
13
 
14
  gr.Interface(
15
  title="MNIST Digit Classifier by Papa Sega",
16
  fn=predict_digit,
17
- inputs=gr.Sketchpad( label="Draw a digit"),
18
  outputs="number",
19
  live=True
20
  ).launch()
21
-
 
1
+ # import gradio as gr
2
+ # import numpy as np
3
+ # from modelutil import create_model
4
+
5
+
6
+ # def predict_digit(image):
7
+ # try:
8
+ # if image == None: pass
9
+ # except:
10
+ # model = create_model()
11
+ # predictions = model.predict(image.reshape(1, 28, 28))
12
+ # return np.argmax(predictions)
13
+
14
+ # gr.Interface(
15
+ # title="MNIST Digit Classifier by Papa Sega",
16
+ # fn=predict_digit,
17
+ # inputs=gr.Sketchpad( label="Draw a digit"),
18
+ # outputs="number",
19
+ # live=True
20
+ # ).launch()
21
+
22
+ def load_model_weights(model, checkpoint_path):
23
+ model.load_weights(checkpoint_path)
24
+
25
  import gradio as gr
26
  import numpy as np
27
+ from modelutil import create_model, load_model_weights
28
+
29
+ checkpoint_path = './checkpoint'
30
 
31
+ model = create_model()
32
+ load_model_weights(model, checkpoint_path)
33
 
34
  def predict_digit(image):
35
+ predictions = model.predict(image.reshape(1, 28, 28))
36
+ return np.argmax(predictions)
 
 
 
 
37
 
38
  gr.Interface(
39
  title="MNIST Digit Classifier by Papa Sega",
40
  fn=predict_digit,
41
+ inputs=gr.Sketchpad(label="Draw a digit"),
42
  outputs="number",
43
  live=True
44
  ).launch()