papasega's picture
Update app.py
4a7b642 verified
raw
history blame
1.09 kB
# import gradio as gr
# import numpy as np
# from modelutil import create_model
# def predict_digit(image):
# try:
# if image == None: pass
# except:
# model = create_model()
# predictions = model.predict(image.reshape(1, 28, 28))
# return np.argmax(predictions)
# gr.Interface(
# title="MNIST Digit Classifier by Papa Sega",
# fn=predict_digit,
# inputs=gr.Sketchpad( label="Draw a digit"),
# outputs="number",
# live=True
# ).launch()
def load_model_weights(model, checkpoint_path):
    """Restore saved weights into ``model`` in place.

    Thin wrapper that forwards ``checkpoint_path`` to ``model.load_weights``
    and discards whatever that call returns, so this function returns None.

    NOTE(review): the ``from modelutil import create_model, load_model_weights``
    import that follows this definition rebinds the name, so the module-level
    weight-loading call uses the modelutil version rather than this one.
    Consider deleting one of the two.
    """
    restore = model.load_weights
    restore(checkpoint_path)
import os

import gradio as gr
import numpy as np
from modelutil import create_model, load_model_weights
# Resolve the checkpoint next to this script instead of against the current
# working directory, so the weights are found no matter where the app is
# launched from. (The previous './checkpoint' only worked when the CWD was
# the script's own directory.)
checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'checkpoint')

# Build the network once at import time and restore the trained weights;
# predict_digit reuses this single instance for every request.
model = create_model()
load_model_weights(model, checkpoint_path)
def predict_digit(image):
    """Classify a hand-drawn digit with the module-level ``model``.

    Args:
        image: The value Gradio passes from the Sketchpad input, or ``None``
            when there is nothing to classify. Otherwise it must be
            array-like with exactly 28 * 28 = 784 elements so it can be
            reshaped to ``(1, 28, 28)``.
            NOTE(review): this assumes a Gradio version whose Sketchpad
            yields a 28x28 array; confirm against the pinned Gradio version.

    Returns:
        The index of the highest model output as a plain ``int``, or
        ``None`` when ``image`` is ``None``.
    """
    # Without this guard an empty input crashes on ``None.reshape``.
    if image is None:
        return None
    # Prepend a batch axis: the model is called on a batch of one image.
    batch = np.asarray(image).reshape(1, 28, 28)
    predictions = model.predict(batch)
    # np.argmax with no axis flattens first, so for a single-row output this
    # is the predicted class index; int() turns the NumPy scalar into a
    # plain Python int.
    return int(np.argmax(predictions))
# Assemble the web UI: a sketchpad input wired to predict_digit, shown as a
# number. live=True makes Gradio rerun the prediction automatically whenever
# the input changes, so there is no submit button.
demo = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(label="Draw a digit"),
    outputs="number",
    live=True,
    title="MNIST Digit Classifier by Papa Sega",
)
demo.launch()