# ViT_Team-A / app.py — Hugging Face Space by iamomtiwari
# Last commit: 7d859e4 "Update app.py" (verified), file size 1.6 kB
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)
# Hugging Face Hub checkpoint served by this app; the UI describes it as
# classifying into ImageNet classes.
MODEL_NAME = "microsoft/resnet-50"

# Load the classifier and its matching preprocessing pipeline once, at import
# time, so every request reuses the same objects.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
# AutoImageProcessor is the supported replacement for the deprecated
# AutoFeatureExtractor on vision checkpoints. The module-level name
# `feature_extractor` is kept so existing code (e.g. predict) keeps working.
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# Module-level logger so failures reported to the UI also leave a full
# traceback in the server log.
logger = logging.getLogger(__name__)


# Define the prediction function
def predict(image) -> str:
    """Classify a single image with the module-level model.

    Args:
        image: The uploaded image. ``gr.Image(type="pil")`` supplies a
            ``PIL.Image.Image``; anything else (including ``None`` when
            nothing was uploaded) is rejected with a message.

    Returns:
        ``"Predicted class: <label>"`` on success, a validation message for
        non-PIL input, or ``"Error: <message>"`` if conversion,
        preprocessing, inference or label lookup raises.
    """
    # Guard clause: rejects None (no upload) and any non-PIL object.
    if not isinstance(image, Image.Image):
        return "Invalid image format. Please upload a valid image."

    try:
        # The model expects 3-channel RGB input; convert RGBA / grayscale /
        # palette images so they are not rejected or mis-normalized.
        rgb_image = image.convert("RGB")
        inputs = feature_extractor(images=rgb_image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = model(**inputs).logits
        # A single image was passed, so argmax over the last (class)
        # dimension yields exactly one index.
        predicted_class_idx = logits.argmax(-1).item()
        # Map the predicted index to its human-readable label.
        predicted_class_label = model.config.id2label[predicted_class_idx]
    except Exception as e:  # UI boundary: report the failure, don't crash.
        logger.exception("Prediction failed")
        return f"Error: {e}"

    return f"Predicted class: {predicted_class_label}"
# Input widget: an upload box that hands predict() a PIL image.
image_input = gr.Image(type="pil", label="Upload Image")
# Output widget: a text box that displays predict()'s result string.
prediction_output = gr.Text(label="Prediction")

# Wire the widgets to predict() as a single-function Gradio app.
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_output,
    title="ResNet-50 Image Classification",
    description="Upload an image to classify it into one of the ImageNet classes using the ResNet-50 model.",
)
def _launch():
    """Start the Gradio server for the module-level ``interface``."""
    interface.launch()


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    _launch()