spaces-demo / app.py
Sagar Thacker
Update app.py
174e9bf
import pickle
import random
import pandas as pd
import gradio as gr
from fastai.vision.all import *
# Zone reference table: each row pairs a numeric "LocationID" with the
# human-readable "borough_zone" label offered in the UI dropdowns.
zone_lookup = pd.read_csv('./data/zone_lookup.csv')

# The pickle holds a 2-tuple: `dv` turns a feature dict into model input
# (dv.transform) and `model` scores that input (model.predict).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ship model files produced by a trusted training pipeline.
with open('./models/lin_reg.bin', mode='rb') as model_file:
    dv, model = pickle.load(model_file)
def _zone_to_location_id(zone, lookup):
    """Return the scalar LocationID whose "borough_zone" label equals *zone*.

    Raises:
        ValueError: if *zone* does not appear in ``lookup["borough_zone"]``.
    """
    matches = lookup.loc[lookup["borough_zone"] == zone, "LocationID"]
    if matches.empty:
        raise ValueError("Unknown zone: %r" % (zone,))
    # Labels are expected to be unique; if they are not, the first row wins.
    return matches.iloc[0]


def prepare_features(pickup, dropoff, trip_distance, lookup=None):
    """Build the feature dict that ``predict`` passes to the vectorizer.

    Args:
        pickup: "borough_zone" label of the pickup location.
        dropoff: "borough_zone" label of the dropoff location.
        trip_distance: trip length in miles; rounded to 4 decimal places.
        lookup: optional table with "LocationID" and "borough_zone" columns.
            Defaults to the module-level ``zone_lookup``.

    Returns:
        dict with ``"PU_DO"`` set to ``"<pickup id>_<dropoff id>"`` and
        ``"trip_distance"`` set to the rounded distance.

    Raises:
        ValueError: if either label is missing from the lookup table.
    """
    if lookup is None:
        lookup = zone_lookup
    # Format scalar IDs, not the Series a boolean mask returns. A Series is
    # formatted as its multi-line repr (index, value, Name/dtype footer)
    # instead of the bare ID, which corrupts the PU_DO key.
    pickup_id = _zone_to_location_id(pickup, lookup)
    dropoff_id = _zone_to_location_id(dropoff, lookup)
    return {
        'PU_DO': '%s_%s' % (pickup_id, dropoff_id),
        'trip_distance': round(trip_distance, 4),
    }
def predict(pickup, dropoff, trip_distance):
    """Estimate the trip duration and return it as a display sentence.

    The three arguments are forwarded unchanged to ``prepare_features``. The
    resulting feature dict is vectorised with ``dv`` and scored by ``model``.
    Only the first prediction is used, formatted with 4 decimal places.
    """
    feature_row = prepare_features(pickup, dropoff, trip_distance)
    minutes = float(model.predict(dv.transform(feature_row))[0])
    return "The predicted duration is %.4f minutes." % (minutes,)
# Materialise the zone labels once. The list serves both as the dropdown
# choices and as the pool for random initial values; random.choice indexes
# positionally, which a list always supports (a pandas Series indexes by label).
zone_choices = list(zone_lookup["borough_zone"])

with gr.Blocks() as demo:
    gr.Markdown("""
This demo is a simple example of how to use Gradio to create a web interface for your machine learning models.
Models used in this demo are very simple and are not meant to perform well. The goal is to show how to use Gradio with a simple model.
""")
    gr.Markdown("Predict Taxi Duration or Classify dog vs cat using this demo")

    with gr.Tab("Predict Taxi Duration"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    pickup = gr.Dropdown(
                        choices=zone_choices,
                        label="Pickup Location",
                        info="The location where the passenger(s) were picked up",
                        value=lambda: random.choice(zone_choices),
                    )
                    dropoff = gr.Dropdown(
                        choices=zone_choices,
                        label="Dropoff Location",
                        info="The location where the passenger(s) were dropped off",
                        value=lambda: random.choice(zone_choices),
                    )
                trip_distance = gr.Slider(
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    label="Trip Distance",
                    info="The trip distance in miles calculated by the taximeter",
                    value=lambda: random.uniform(0.0, 100.0),
                )
            with gr.Column():
                output = gr.Textbox(label="Output Box")
                predict_btn = gr.Button("Predict")
        # Example values must exactly match entries in zone_choices,
        # otherwise the zone lookup in prepare_features cannot resolve them.
        examples = gr.Examples(
            [
                ["Queens - Bellerose", "Bronx - Schuylerville/Edgewater Park", 25],
                ["Bronx - Norwood", "Brooklyn - Sunset Park West", 55],
            ],
            inputs=[pickup, dropoff, trip_distance],
        )

    with gr.Tab("Classify Dog vs Cat"):
        # Unused in this file; presumably referenced by the pickled learner,
        # so it must exist at module level before load_learner runs.
        # TODO(review): confirm before removing or renaming.
        def is_cat(x):
            return x[0].isupper()

        learn = load_learner('./models/model.pkl')
        # Order presumably mirrors the learner's vocab (is_cat False -> Dog,
        # True -> Cat); confirm against learn.dls.vocab.
        categories = ('Dog', 'Cat')

        def classify_image(img):
            """Return a {category: probability} mapping for a single image."""
            _, _, probs = learn.predict(img)
            return dict(zip(categories, map(float, probs)))

        with gr.Row():
            # gr.inputs / gr.outputs are deprecated; the top-level components
            # are the supported equivalents inside gr.Blocks.
            image = gr.Image(shape=(192, 192))
            label = gr.Label()
        examples = gr.Examples(
            ['./examples/dog.jpg', './examples/cat.jpg', './examples/dunno.jpg', './examples/basset.jpg'],
            inputs=[image],
        )
        classify_btn = gr.Button("Predict")

    # Event listeners are registered inside the Blocks context.
    predict_btn.click(fn=predict, inputs=[pickup, dropoff, trip_distance], outputs=output, api_name="predict-duration")
    classify_btn.click(fn=classify_image, inputs=image, outputs=label, api_name="classify-dog-breed")

demo.launch()