# Standard library
import os
from pathlib import Path
from timeit import default_timer as timer
from typing import Tuple

# Third-party
import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

# Local
from model import create_effnetb3_model

# Food-101 class labels. Order matters: model outputs are mapped to these
# names by index.
class_names = [
    "apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio",
    "beef_tartare", "beet_salad", "beignets", "bibimbap",
    "bread_pudding", "breakfast_burrito", "bruschetta", "caesar_salad",
    "cannoli", "caprese_salad", "carrot_cake", "ceviche",
    "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla",
    "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros",
    "clam_chowder", "club_sandwich", "crab_cakes", "creme_brulee",
    "croque_madame", "cup_cakes", "deviled_eggs", "donuts",
    "dumplings", "edamame", "eggs_benedict", "escargots",
    "falafel", "filet_mignon", "fish_and_chips", "foie_gras",
    "french_fries", "french_onion_soup", "french_toast", "fried_calamari",
    "fried_rice", "frozen_yogurt", "garlic_bread", "gnocchi",
    "greek_salad", "grilled_cheese_sandwich", "grilled_salmon", "guacamole",
    "gyoza", "hamburger", "hot_and_sour_soup", "hot_dog",
    "huevos_rancheros", "hummus", "ice_cream", "lasagna",
    "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons",
    "miso_soup", "mussels", "nachos", "omelette",
    "onion_rings", "oysters", "pad_thai", "paella",
    "pancakes", "panna_cotta", "peking_duck", "pho",
    "pizza", "pork_chop", "poutine", "prime_rib",
    "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake",
    "risotto", "samosa", "sashimi", "scallops",
    "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese", "spaghetti_carbonara",
    "spring_rolls", "steak", "strawberry_shortcake", "sushi",
    "tacos", "takoyaki", "tiramisu", "tuna_tartare",
    "waffles",
]

# Inference runs on CPU only.
device = "cpu"

# Build the model (one output per class) together with its preprocessing
# transforms.
effnetb3, effnetb3_transforms = create_effnetb3_model(
    num_classes=len(class_names),
)

# Read the saved weights from disk, mapping every tensor onto `device`.
# NOTE: torch.load unpickles the file, so the checkpoint must come from a
# trusted source.
effnetb3_state_dict = torch.load(
    "effnetb3_full_food101.pth",
    map_location=torch.device(device),
)
def _rename_classifier_keys(state_dict: dict) -> None:
    """Rename `classifier.{weight,bias}` to `classifier.1.{weight,bias}` in place.

    The rename is applied only to keys that are actually present, so a
    checkpoint that already uses the `classifier.1.*` names is left untouched
    instead of raising `KeyError`.

    Args:
        state_dict: Checkpoint mapping of parameter names to tensors; mutated
            in place.
    """
    for suffix in ("weight", "bias"):
        legacy_key = f"classifier.{suffix}"
        if legacy_key in state_dict:
            state_dict[f"classifier.1.{suffix}"] = state_dict.pop(legacy_key)


# The checkpoint names the classifier parameters `classifier.*`, while the
# model is loaded with `classifier.1.*` keys.
# NOTE(review): presumably the head built by create_effnetb3_model is a
# Sequential whose second entry is the linear layer - confirm against model.py.
_rename_classifier_keys(effnetb3_state_dict)
effnetb3.load_state_dict(effnetb3_state_dict)
effnetb3.to(device)


# Define predict function
def predict(img: Image.Image) -> Tuple[dict, float]:
    """Use the EffNetB3 model to transform and predict on `img`.

    Args:
        img: PIL image to predict on.

    Returns:
        A tuple `(pred_labels_and_probs, pred_time)`, where
        `pred_labels_and_probs` is a dict mapping each class name to the
        probability the model assigns to it, and `pred_time` is the time
        taken (in seconds, rounded to 4 decimal places). The timing covers
        preprocessing as well as the forward pass.
    """
    start_time = timer()

    # Preprocess and add a leading batch dimension of size 1.
    batch = effnetb3_transforms(img).unsqueeze(0)

    # Make sure the model is in eval mode before running inference.
    effnetb3.eval()
    with torch.inference_mode():
        # Softmax over dim=1 turns the per-class scores into probabilities.
        pred_probs = torch.softmax(effnetb3(batch), dim=1)

    # Convert the single row of probabilities to Python floats in one call
    # and pair each with its class name.
    pred_labels_and_probs = dict(zip(class_names, pred_probs[0].tolist()))

    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time


# Initialize Gradio app
title = "FoodVision"
description = "EfficientNetB3 feature extractor to classify images of food. Upload an image or click on one of the examples to try it out!"
article = """ From the [Zero to Mastery PyTorch tutorial](https://www.learnpytorch.io/09_pytorch_model_deployment/), using the [Food-101 dataset](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/). """

# One single-element row per example image. Sorted so the gallery order does
# not depend on filesystem iteration order; passed as plain string paths.
examples = [[str(path)] for path in sorted(Path("examples").glob("*.jpg"))]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

demo.launch()