import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import json

# Load the model and image processor
model_name = "imjeffhi/pokemon_classifier"
model = ViTForImageClassification.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)
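model.eval()  # inference only; from_pretrained normally returns the model in eval mode already, so this is just defensive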

# Load the labels from the config.json file
with open('config.json', 'r') as f:
    config = json.load(f)
labels = config['id2label']
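# (The same id-to-name mapping is also available as model.config.id2label with integer keys;
#  reading config.json here assumes the Space ships its own copy of that file next to app.py.)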

# Function to preprocess the image
def preprocess_image(img_pil):
    # Ensure img_pil is a PIL Image object
    if not isinstance(img_pil, Image.Image):
        img_pil = Image.fromarray(img_pil)  # Convert NumPy array to PIL Image
    inputs = image_processor(images=img_pil, return_tensors="pt")
    return inputs

# Function to predict the class of the image
def predict_classification(img_pil):
    try:
        inputs = preprocess_image(img_pil)
        with torch.no_grad():
            outputs = model(**inputs)  # pixel_values is already batched by the processor
        logits = outputs.logits
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class = labels[str(predicted_class_idx)]  # id2label keys are strings in config.json
        confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
        return predicted_class, confidence
    except Exception:
        return "Error occurred during prediction", 0.0

# Function to handle the prediction in the Gradio interface
def pokémon_predict(img_pil):
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"

# Create Gradio interface
input_image = gr.Image(label="Upload an image of a Pokémon")
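# gr.Image passes the upload as a NumPy array by default (type="numpy"),
# which is why preprocess_image converts it back to a PIL Image.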
output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=pokémon_predict,
    inputs=input_image,
    outputs=output_text,
    examples=["Farfetch'd.jpeg", "Mamoswine.jpeg", "Primeape.jpeg", "Oranguru.jpeg", "Slurrpuff.jpeg", "Vullaby.jpeg"],
    title=" Pokémon Classifier",
    description="Upload any image deriving from at least 898 species of Pokémon and the classifier will tell you which one it is and the confidence level of the prediction.",
    allow_flagging="never"
)

iface.launch()
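# When run locally (e.g. `python app.py`), Gradio serves the interface at http://127.0.0.1:7860
# by default; on Hugging Face Spaces the same script is executed automatically as app.py.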