Spaces:
Running
Running
import pickle | |
import random | |
import pandas as pd | |
import gradio as gr | |
from fastai.vision.all import * | |
# Zone table backing the pickup/dropoff dropdowns. prepare_features reads its
# "borough_zone" (display label) and "LocationID" columns.
zone_lookup = pd.read_csv('./data/zone_lookup.csv')

# The model file holds a 2-tuple: the feature vectorizer (presumably a
# scikit-learn DictVectorizer, since it is fed dicts) and the fitted regressor.
# pickle can execute arbitrary code on load -- only ever point this at a
# trusted, locally shipped file.
with open('./models/lin_reg.bin', 'rb') as model_file:
    dv, model = pickle.load(model_file)
def prepare_features(pickup, dropoff, trip_distance, lookup=None):
    """Build the single-trip feature dict fed to the vectorizer.

    Args:
        pickup: "Borough - Zone" label of the pickup location (a value of the
            lookup table's ``borough_zone`` column).
        dropoff: "Borough - Zone" label of the dropoff location.
        trip_distance: Trip distance in miles; rounded to 4 decimals.
        lookup: Optional zone table with ``borough_zone`` and ``LocationID``
            columns. Defaults to the module-level ``zone_lookup``.

    Returns:
        ``{"PU_DO": "<pickup LocationID>_<dropoff LocationID>",
        "trip_distance": <rounded distance>}``.

    Raises:
        ValueError: if either zone label is not present in the lookup table.
    """
    table = zone_lookup if lookup is None else lookup

    def _location_id(zone):
        """Return the LocationID for a "borough_zone" label (first match)."""
        matches = table.loc[table["borough_zone"] == zone, "LocationID"]
        if matches.empty:
            raise ValueError(f"Unknown zone: {zone!r}")
        # Extract a scalar: formatting the filtered Series itself would embed
        # its multi-line repr (index, name, dtype) into the PU_DO key.
        return matches.iloc[0]

    return {
        "PU_DO": f"{_location_id(pickup)}_{_location_id(dropoff)}",
        "trip_distance": round(trip_distance, 4),
    }
def predict(pickup, dropoff, trip_distance):
    """Predict the trip duration and return it as a display sentence.

    The three arguments are passed straight to ``prepare_features``. The
    resulting dict is vectorized with ``dv`` and scored with ``model``. Only
    the first prediction is used, formatted to 4 decimal places.
    """
    feature_row = prepare_features(pickup, dropoff, trip_distance)
    first_prediction = model.predict(dv.transform(feature_row))[0]
    return "The predicted duration is %.4f minutes." % float(first_prediction)
# Two-tab Gradio app: taxi trip-duration regression and a dog-vs-cat image
# classifier. `with` blocks do not open a new Python scope, so every name bound
# below (pickup, learn, is_cat, ...) is module-level.
# NOTE(review): the original layout nesting was lost in a whitespace-mangled
# copy; the Row/Column nesting below is a reconstruction -- confirm visually.
with gr.Blocks() as demo:
    gr.Markdown("""
This demo is a simple example of how to use Gradio to create a web interface for your machine learning models.
Models used in this demo are very simple and are not meant to perform well. The goal is to show how to use Gradio with a simple model.
""")
    gr.Markdown("Predict Taxi Duration or Classify dog vs cat using this demo")

    with gr.Tab("Predict Taxi Duration"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Initial values are callables, so each page load gets a
                    # fresh random pick.
                    pickup = gr.Dropdown(
                        choices=list(zone_lookup["borough_zone"]),
                        label="Pickup Location",
                        info="The location where the passenger(s) were picked up",
                        value=lambda: random.choice(zone_lookup["borough_zone"]),
                    )
                    dropoff = gr.Dropdown(
                        choices=list(zone_lookup["borough_zone"]),
                        label="Dropoff Location",
                        info="The location where the passenger(s) were dropped off",
                        value=lambda: random.choice(zone_lookup["borough_zone"]),
                    )
                trip_distance = gr.Slider(
                    minimum=0.0,
                    maximum=100.0,
                    step=0.1,
                    label="Trip Distance",
                    info="The trip distance in miles calculated by the taximeter",
                    value=lambda: random.uniform(0.0, 100.0),
                )
            with gr.Column():
                output = gr.Textbox(label="Output Box")
                predict_btn = gr.Button("Predict")
        # Example zone labels must match "borough_zone" values exactly, or the
        # LocationID lookup finds nothing.
        examples = gr.Examples(
            [
                ["Queens - Bellerose", "Bronx - Schuylerville/Edgewater Park", 25],
                ["Bronx - Norwood", "Brooklyn - Sunset Park West", 55],
            ],
            inputs=[pickup, dropoff, trip_distance],
        )

    with gr.Tab("Classify Dog vs Cat"):
        # Labelling function (True = cat when the first character is
        # uppercase). It is never called here; it presumably has to exist
        # before load_learner because the pickled learner references the
        # function it was trained with -- confirm against the training code.
        def is_cat(x): return x[0].isupper()

        learn = load_learner('./models/model.pkl')
        # Assumes the learner's vocab is ordered [False, True] -> (Dog, Cat).
        categories = ('Dog', 'Cat')

        def classify_image(img):
            """Return a {category: probability} dict for a single image."""
            _, _, probs = learn.predict(img)
            return dict(zip(categories, map(float, probs)))

        with gr.Row():
            # gr.inputs / gr.outputs are the deprecated pre-3.0 interface
            # classes; gr.Image / gr.Label are their current equivalents.
            image = gr.Image(shape=(192, 192))
            label = gr.Label()
        examples = gr.Examples(
            ['./examples/dog.jpg', './examples/cat.jpg', './examples/dunno.jpg', './examples/basset.jpg'],
            inputs=[image],
        )
        classify_btn = gr.Button("Predict")

    predict_btn.click(fn=predict, inputs=[pickup, dropoff, trip_distance], outputs=output, api_name="predict-duration")
    # NOTE: api_name says "dog-breed" although the model is dog-vs-cat; kept
    # unchanged so existing API clients keep working.
    classify_btn.click(fn=classify_image, inputs=image, outputs=label, api_name="classify-dog-breed")

demo.launch()