"""Gradio app that classifies an uploaded image as one of the Pokémon species
known to the ``imjeffhi/pokemon_classifier`` ViT model."""

import json
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

# Load the model and image processor
model_name = "imjeffhi/pokemon_classifier"
model = ViTForImageClassification.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)

# Load the class-index -> label mapping from the local config.json.
# JSON object keys are always strings, so indices are looked up as str(idx).
# UTF-8 is given explicitly because the platform default encoding may differ
# and a label may contain non-ASCII characters.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
labels = config["id2label"]


def preprocess_image(img_pil):
    """Convert an uploaded image into model-ready tensors.

    Args:
        img_pil: A PIL image, or an array-like image accepted by
            ``Image.fromarray`` (``gr.Image`` supplies a NumPy array by default).

    Returns:
        The output of ``image_processor`` with PyTorch tensors; its
        ``"pixel_values"`` entry already includes the batch dimension.
    """
    if not isinstance(img_pil, Image.Image):
        img_pil = Image.fromarray(img_pil)  # Convert NumPy array to PIL Image
    # Normalise to 3-channel RGB so grayscale or RGBA (alpha) uploads reach
    # the processor in the same channel layout as ordinary colour photos.
    img_pil = img_pil.convert("RGB")
    return image_processor(images=img_pil, return_tensors="pt")


def predict_classification(img_pil):
    """Classify an image and return the top label with its probability.

    Args:
        img_pil: The image from the Gradio input component, or ``None`` when
            nothing was uploaded.

    Returns:
        ``(predicted_class, confidence)``, where ``confidence`` is the softmax
        probability of the predicted class as a Python float. When no image is
        supplied or inference fails, returns a message string and ``0.0``
        instead of raising, so the UI always gets displayable output.
    """
    if img_pil is None:
        return "No image provided", 0.0
    try:
        inputs = preprocess_image(img_pil)
        with torch.no_grad():
            # pixel_values from the processor is already batched (batch of 1).
            logits = model(pixel_values=inputs["pixel_values"]).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
        predicted_class_idx = int(probabilities.argmax().item())
        predicted_class = labels[str(predicted_class_idx)]
        confidence = probabilities[predicted_class_idx].item()
        return predicted_class, confidence
    except Exception:
        # Top-level UI boundary: keep the app responsive, but record the
        # traceback instead of silently discarding it.
        logger.exception("Prediction failed")
        return "Error occurred during prediction", 0.0


def pokémon_predict(img_pil):
    """Gradio callback: run the classifier and format the result as text."""
    predicted_class, confidence = predict_classification(img_pil)
    return f"Predicted class: {predicted_class}, Confidence: {confidence:.4f}"


# Create Gradio interface
input_image = gr.Image(label="Upload an image of a Pokémon")
output_text = gr.Textbox(label="Predicted Class and Confidence")

iface = gr.Interface(
    fn=pokémon_predict,
    inputs=input_image,
    outputs=output_text,
    examples=[
        "Farfetch'd.jpeg",
        "Mamoswine.jpeg",
        "Primeape.jpeg",
        "Oranguru.jpeg",
        "Slurrpuff.jpeg",
        "Vullaby.jpeg",
    ],
    title=" Pokémon Classifier",
    description="Upload any image deriving from at least 898 species of Pokémon and the classifier will tell you which one it is and the confidence level of the prediction.",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()