papasega commited on
Commit
7804fbe
1 Parent(s): 4a7b642

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -19,22 +19,27 @@
19
  # live=True
20
  # ).launch()
21
 
22
- def load_model_weights(model, checkpoint_path):
23
- model.load_weights(checkpoint_path)
24
-
25
  import gradio as gr
26
  import numpy as np
27
- from modelutil import create_model, load_model_weights
28
-
29
- checkpoint_path = './checkpoint'
30
 
31
- model = create_model()
32
- load_model_weights(model, checkpoint_path)
33
 
34
  def predict_digit(image):
35
- predictions = model.predict(image.reshape(1, 28, 28))
36
- return np.argmax(predictions)
 
 
 
 
37
 
 
 
 
 
 
 
38
  gr.Interface(
39
  title="MNIST Digit Classifier by Papa Sega",
40
  fn=predict_digit,
@@ -42,3 +47,4 @@ gr.Interface(
42
  outputs="number",
43
  live=True
44
  ).launch()
 
 
19
  # live=True
20
  # ).launch()
21
 
22
+ import tensorflow as tf
 
 
23
  import gradio as gr
24
  import numpy as np
 
 
 
25
 
26
+ # Load the CNN model from the .h5 file
27
+ model = tf.keras.models.load_model('mnist_cnn_model.h5')
28
 
29
  def predict_digit(image):
30
+ # Preprocess the input image
31
+ image = np.expand_dims(image, axis=0) # Add batch dimension
32
+ image = image / 255.0 # Normalize pixel values
33
+
34
+ # Make predictions
35
+ predictions = model.predict(image)
36
 
37
+ # Get the predicted digit
38
+ predicted_digit = np.argmax(predictions)
39
+
40
+ return predicted_digit
41
+
42
+ # Define Gradio interface
43
  gr.Interface(
44
  title="MNIST Digit Classifier by Papa Sega",
45
  fn=predict_digit,
 
47
  outputs="number",
48
  live=True
49
  ).launch()
50
+