|
import json
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
|
|
|
|
|
# Hugging Face Hub checkpoint used for Pokémon image classification.
model_name = "imjeffhi/pokemon_classifier"

# Load the classifier weights and the preprocessing settings that belong to
# the same checkpoint, so images are resized/normalized the way the model saw
# them during training.
model = ViTForImageClassification.from_pretrained(model_name)

image_processor = ViTImageProcessor.from_pretrained(model_name)

# Class-index -> label mapping, read from a local config.json.
# JSON object keys are always strings, so callers must look labels up with
# str(index). JSON text is UTF-8 (RFC 8259) and label names may contain
# non-ASCII characters, so the encoding is passed explicitly instead of
# relying on the platform default (e.g. cp1252 on Windows).
# NOTE(review): presumably this file mirrors the checkpoint's own config;
# model.config.id2label (int keys) would remove the separate-file dependency.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

labels = config["id2label"]
|
|
|
|
|
def preprocess_image(img_pil):
    """Turn an uploaded image into the tensors the classifier consumes.

    Args:
        img_pil: A ``PIL.Image.Image`` or any array ``PIL.Image.fromarray``
            accepts (presumably the NumPy array Gradio's ``Image`` component
            produces by default — confirm for the installed Gradio version).

    Returns:
        The ``image_processor`` output, a mapping whose ``"pixel_values"``
        entry is a PyTorch tensor (``return_tensors="pt"``).

    Raises:
        ValueError: if ``img_pil`` is ``None`` (no image was supplied).
    """
    if img_pil is None:
        # Presumably what the UI passes when nothing was uploaded; fail with a
        # clear message rather than an opaque error from Image.fromarray.
        raise ValueError("No image provided")

    if not isinstance(img_pil, Image.Image):
        img_pil = Image.fromarray(img_pil)

    # Grayscale ("L"), palette ("P") and RGBA images would otherwise reach the
    # processor with 1 or 4 channels; the ViT presumably expects 3-channel RGB
    # (confirm num_channels in the model config). For RGB input this is a
    # pixel-identical copy.
    img_pil = img_pil.convert("RGB")

    return image_processor(images=img_pil, return_tensors="pt")
|
|
|
|
|
def predict_classification(img_pil):
    """Classify one image and return the top label with its probability.

    Args:
        img_pil: A ``PIL.Image.Image`` or an array accepted by
            ``preprocess_image``.

    Returns:
        ``(label, confidence)``: the label from ``labels`` for the
        highest-scoring class, and that class's softmax probability as a
        Python ``float``. On any failure the exception is logged and
        ``("Error occurred during prediction", 0.0)`` is returned so the UI
        stays responsive.
    """
    try:
        inputs = preprocess_image(img_pil)

        # pixel_values already carries a leading batch dimension (of size one
        # for the single image passed in), so it goes to the model unchanged.
        with torch.no_grad():
            logits = model(pixel_values=inputs["pixel_values"]).logits

        predicted_class_idx = logits.argmax(-1).item()

        # id2label was loaded from JSON, so its keys are strings.
        predicted_class = labels[str(predicted_class_idx)]

        probabilities = logits.softmax(dim=-1)[0]
        confidence = probabilities[predicted_class_idx].item()

        return predicted_class, confidence
    except Exception:
        # UI boundary: keep the app alive, but record why prediction failed
        # instead of silently discarding the exception.
        logging.getLogger(__name__).exception("Prediction failed")
        return "Error occurred during prediction", 0.0
|
|
|
|
|
|
|
def pokémon_predict(img_pil):
    """Gradio callback: classify *img_pil* and format a one-line summary.

    Args:
        img_pil: The image supplied by the UI, passed straight through to
            ``predict_classification``.

    Returns:
        A string naming the predicted class and its confidence, with the
        confidence shown to four decimal places.
    """
    label, score = predict_classification(img_pil)
    summary = f"Predicted class: {label}, Confidence: {score:.4f}"
    return summary
|
|
|
|
|
# UI components: one image input and one text output.
input_image = gr.Image(label="Upload an image of a Pokémon")

output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=pokémon_predict,
    inputs=input_image,
    outputs=output_text,
    # Sample images shown as clickable examples; these relative paths are
    # resolved against the working directory the app is started from.
    examples=[
        "Farfetch'd.jpeg",
        "Mamoswine.jpeg",
        "Primeape.jpeg",
        "Oranguru.jpeg",
        "Slurrpuff.jpeg",
        "Vullaby.jpeg",
    ],
    title="Pokémon Classifier",
    description=(
        "Upload an image of any of the 898 Pokémon species and the classifier "
        "will tell you which one it is, along with its confidence in the "
        "prediction."
    ),
    # NOTE(review): newer Gradio releases rename this to flagging_mode;
    # confirm against the pinned Gradio version before changing it.
    allow_flagging="never",
)

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported without blocking on launch().
if __name__ == "__main__":
    iface.launch()