File size: 2,059 Bytes
5f37d56
 
 
 
278b1ca
5f37d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123e402
5f37d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9486004
5f37d56
 
 
 
 
 
9486004
5f37d56
 
 
 
 
9486004
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from timeit import default_timer as timer
from typing import Tuple
from pathlib import Path
from PIL import Image

import gradio as gr
import torch
from torch import nn
from torchvision import transforms

from model import create_effnetb2_model

# Output labels; presumably listed in the same order as the model's output
# units, since predictions are reported by pairing them positionally.
class_names = ["pizza", "steak", "sushi"]
# Target device string; also used below as the checkpoint's map_location.
device = "cpu"

# Build the EfficientNetB2 model together with the input transforms that
# predictions are preprocessed with. num_classes is derived from the label
# list (presumably sizing the classifier head) so the two stay in sync.
effnetb2, effnetb2_transforms = create_effnetb2_model(
    num_classes=len(class_names)
)

# Restore the trained parameters from the local checkpoint file, remapping
# any stored tensors onto `device` while deserializing.
effnetb2.load_state_dict(
    torch.load("effnetb2.pth", map_location=torch.device(device))
)

# Define predict function
def predict(img: Image.Image) -> Tuple[dict, float]:
    """Classify one image with the EffNetB2 model and time the prediction.

    Args:
      img (PIL.Image.Image): Image to predict on.

    Returns:
      A tuple (pred_labels_and_probs, pred_time), where pred_labels_and_probs
      is a dict mapping each class name to the softmax probability the model
      assigns to it, and pred_time is the wall-clock time taken to predict,
      in seconds, rounded to 4 decimal places.
    """
    start_time = timer()

    # Preprocess, then add a leading batch dimension of size 1.
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Put the model in evaluation mode (disables training-only behaviour such
    # as dropout) before running inference without autograd tracking.
    effnetb2.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(effnetb2(batch), dim=1)

    # Pair each class name with the probability at the same position in the
    # single row of the batch.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time

# Initialize Gradio app
title = "FoodVision Mini"
description = "EfficientNetB2 feature extractor to classify images of food as pizza, steak, or sushi."
article = "From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/)"

# One example image per row. Sorted so the example order does not depend on
# the filesystem's directory-listing order; str() hands Gradio plain paths.
examples = [[str(example)] for example in sorted(Path("examples").glob("*.jpg"))]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class's probability, derived from the label list rather
        # than hard-coded so it stays in sync if classes are added.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()