TroglodyteDerivations's picture
Update app.py
95f05af verified
raw
history blame
2.37 kB
import json
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
# Load the pretrained ViT Pokémon classifier and its matching image processor
# from the Hugging Face Hub (downloaded/cached by `from_pretrained`).
model_name = "imjeffhi/pokemon_classifier"
model = ViTForImageClassification.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)

# Load the class-index -> label-name mapping from the local config.json.
# JSON text is UTF-8 by spec, while open()'s default encoding is
# locale-dependent, so the encoding is stated explicitly to keep non-ASCII
# label names intact on every platform.
# NOTE(review): presumably this mirrors model.config.id2label -- confirm the
# local file and the Hub config stay in sync.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
labels = config["id2label"]
def preprocess_image(img_pil):
    """Convert an uploaded image into model-ready tensors.

    Args:
        img_pil: A PIL image, or any array ``Image.fromarray`` accepts
            (e.g. a NumPy array), which is converted to PIL first.

    Returns:
        The ``image_processor`` output built with ``return_tensors="pt"``;
        callers read its ``"pixel_values"`` entry.
    """
    pil_image = img_pil if isinstance(img_pil, Image.Image) else Image.fromarray(img_pil)
    return image_processor(images=pil_image, return_tensors="pt")
def predict_classification(img_pil):
    """Classify an image and report the predicted label with its confidence.

    Args:
        img_pil: A PIL image or an array accepted by ``preprocess_image``.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is the
        name looked up in ``labels`` and ``confidence`` is that class's softmax
        probability as a plain Python float. If anything fails, the traceback
        is logged and ``("Error occurred during prediction", 0.0)`` is returned
        so the UI stays responsive.
    """
    try:
        inputs = preprocess_image(img_pil)
        with torch.no_grad():
            # pixel_values already has a leading batch dimension (size 1 for
            # the single image processed here), so it is fed to the model as-is.
            logits = model(pixel_values=inputs["pixel_values"]).logits
        predicted_class_idx = logits.argmax(-1).item()
        # JSON object keys are always strings, hence the str() lookup.
        predicted_class = labels[str(predicted_class_idx)]
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        # Return a plain Python float rather than a NumPy/tensor scalar.
        confidence = float(probabilities[0, predicted_class_idx])
        return predicted_class, confidence
    except Exception:
        # UI boundary: keep the fallback result, but record why it happened
        # instead of silently discarding the error.
        logging.getLogger(__name__).exception("Prediction failed")
        return "Error occurred during prediction", 0.0
def pokémon_predict(img_pil):
    """Gradio callback: classify *img_pil* and render the result as one line.

    The text has the form ``"Predicted class: <label>, Confidence: <p>"``,
    with the confidence shown to four decimal places.
    """
    label, score = predict_classification(img_pil)
    return f"Predicted class: {label}, Confidence: {score:.4f}"
# Assemble the Gradio UI: one image input, one text output, plus sample
# images referenced by relative path (presumably shipped next to this script).
example_images = [
    "Farfetch'd.jpeg",
    "Mamoswine.jpeg",
    "Primeape.jpeg",
    "Oranguru.jpeg",
    "Slurrpuff.jpeg",
    "Vullaby.jpeg",
]
input_image = gr.Image(label="Upload an image of a Pokémon")
output_text = gr.Textbox(label="Predicted Class and Confidence")
iface = gr.Interface(
    fn=pokémon_predict,
    inputs=input_image,
    outputs=output_text,
    examples=example_images,
    # NOTE(review): the title keeps its original leading space -- possibly a
    # stripped emoji; confirm the intended text.
    title=" Pokémon Classifier",
    description=(
        "Upload any image deriving from at least 898 species of Pokémon "
        "and the classifier will tell you which one it is and the confidence "
        "level of the prediction."
    ),
    allow_flagging="never",
)
iface.launch()