# lion_cheetah / app.py
# Load a pretrained ResNet-18 model in evaluation mode
import torch
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
# Define a function that takes the user input (here, an image) and returns the prediction.
'''The prediction is returned as a dictionary whose keys are class names and whose values are confidence probabilities.
The class names are loaded from the text file downloaded below.
'''
import requests
from PIL import Image
from torchvision import transforms
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
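# Illustrative sketch (not part of the original app): a more defensive way to
# turn the downloaded text into a label list. It drops blank lines and fails
# early if fewer names than expected were downloaded, since predict() indexes
# 1,000 labels. The helper name `parse_labels_sketch` and its `expected`
# parameter are hypothetical; the app itself keeps using the simple split above.
def parse_labels_sketch(text, expected=1000):
    # Keep only non-empty lines, stripped of surrounding whitespace.
    names = [line.strip() for line in text.splitlines() if line.strip()]
    if len(names) < expected:
        raise ValueError(f"expected at least {expected} labels, got {len(names)}")
    return names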
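def predict(inp):
    # Convert the PIL image to a tensor and add a batch dimension.
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        # Softmax over the 1,000 logits of the single image in the batch.
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
    return confidences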
'''The predict function receives the input as a PIL Image (Gradio handles that conversion) and converts it into a PyTorch tensor.
After forwarding the tensor through the pretrained model, it returns the predictions as a dictionary named confidences.
The dictionary's keys are the class labels, and its values are the corresponding confidence probabilities.
The final step applies the softmax function to compute the probability of each class.
Softmax matters because it converts the model's raw output logits, which can be any real numbers, into probabilities that sum to 1.
This makes it easier to interpret the model's outputs as confidence levels for each class.'''
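# Illustrative sketch (not part of the original app): a plain-Python version of
# what softmax does to a list of logits. It is only meant to show why the
# outputs are positive and sum to 1; the app itself keeps using
# torch.nn.functional.softmax. The helper name `softmax_sketch` is hypothetical.
import math

def softmax_sketch(logits):
    # Subtract the largest logit first so exp() cannot overflow;
    # this shift does not change the resulting probabilities.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Example: softmax_sketch([2.0, 1.0, 0.1]) gives three probabilities that sum
# to 1, with the largest logit receiving the largest share.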
# Creating a Gradio interface
import gradio as gr
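# Illustrative sketch (not part of the original app): roughly what limiting the
# Label output to the top 3 classes amounts to, namely sorting the confidences
# dictionary by value and keeping the k largest entries. The helper name
# `top_k_sketch` is hypothetical; Gradio performs its own selection internally,
# so this is an approximation of the idea, not Gradio's implementation.
def top_k_sketch(confidences, k=3):
    # Sort (label, probability) pairs from most to least confident.
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])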
gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),  # creates the image component and converts the upload to a PIL image
             outputs=gr.Label(num_top_classes=3),  # a Label displays the top labels in a readable form; since we don't want to show all 1,000 class labels, num_top_classes=3 limits it to the top 3 classes
             examples=["lion.jpg", "cheetah.jpg"]).launch()