Spaces:
Runtime error
Runtime error
File size: 2,059 Bytes
5f37d56 278b1ca 5f37d56 123e402 5f37d56 9486004 5f37d56 9486004 5f37d56 9486004 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
from timeit import default_timer as timer
from typing import Tuple
from pathlib import Path
from PIL import Image
import gradio as gr
import torch
from torch import nn
from torchvision import transforms
from model import create_effnetb2_model
class_names = ["pizza", "steak", "sushi"]
device = "cpu"

# Build the EffNetB2 model (and the transforms returned alongside it) with
# one output per entry in class_names.
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=len(class_names),
)

# Restore the saved weights, remapping stored tensors onto `device`.
effnetb2.load_state_dict(
    torch.load(
        "effnetb2.pth",
        map_location=torch.device(device),
    )
)
# Define predict function
def predict(img: Image.Image) -> Tuple[dict, float]:
    """Classify an image with the EffNetB2 model and time the prediction.

    The image is passed through ``effnetb2_transforms``, given a leading
    dimension of size 1, and run through ``effnetb2`` in eval mode under
    ``torch.inference_mode()``. The timed span covers the transform as well
    as the forward pass.

    Args:
        img: PIL image to predict on. May be ``None`` if the caller supplies
            no image (e.g. an empty image input in the UI).

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``, where
        ``pred_labels_and_probs`` is a dict mapping each class name to the
        softmax probability the model assigns to it, and ``pred_time`` is the
        time taken to predict in seconds, rounded to 4 decimal places.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    # Fail with a clear message instead of an obscure error from the
    # transform pipeline when no image was supplied.
    if img is None:
        raise ValueError("No image provided; upload an image to classify.")

    start_time = timer()

    # Add a leading dimension so the model receives a batch of one image.
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Eval mode + inference mode: no training-time layer behaviour and no
    # autograd bookkeeping during the forward pass.
    effnetb2.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb2(batch), dim=1)

    # Row 0 is the only image in the batch; tolist() yields plain Python
    # floats, paired positionally with the class names.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
# Initialize Gradio app
title = "FoodVision Mini"
description = "EfficientNetB2 feature extractor to classify images of food as pizza, steak, or sushi."
article = "From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/)"

# One single-input row per bundled .jpg example. Paths are handed over as
# plain strings rather than Path objects, and sorted so the example order
# does not depend on the order the filesystem lists the directory in.
examples = [[str(example)] for example in sorted(Path("examples").glob("*.jpg"))]

# predict() returns (label -> probability dict, elapsed seconds), which map
# onto the Label and Number outputs in that order.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

demo.launch()
|