"""Gradio demo that classifies an uploaded image with Microsoft's ResNet-50.

The predicted label comes from the model's ``config.id2label`` mapping
(ImageNet classes for this checkpoint).
"""
import logging

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

logger = logging.getLogger(__name__)

MODEL_NAME = "microsoft/resnet-50"

# Load the model and its preprocessing pipeline once at import time so every
# request reuses them. AutoImageProcessor is the non-deprecated replacement
# for AutoFeatureExtractor on vision models; the original module-level name
# ``feature_extractor`` is kept so existing references keep working.
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
feature_extractor = AutoImageProcessor.from_pretrained(MODEL_NAME)


def predict(image):
    """Classify *image* and return a human-readable prediction string.

    Args:
        image: A ``PIL.Image.Image`` supplied by the Gradio ``Image`` input,
            or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted class: <label>"`` on success. Otherwise, a message
        describing why classification failed. Errors are returned as text
        rather than raised so that they appear in the UI output box.
    """
    if image is None:
        return "No image provided. Please upload an image."
    if not isinstance(image, Image.Image):
        return "Invalid image format. Please upload a valid image."

    try:
        # Gradio's Image component normally delivers RGB, but direct callers
        # may pass RGBA / palette / grayscale images. The processor's
        # per-channel normalization expects 3 channels, so convert defensively.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = feature_extractor(images=image, return_tensors="pt")

        # Inference only: disabling autograd avoids building a graph.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Single-image batch: pick the highest-scoring class index.
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class_label = model.config.id2label[predicted_class_idx]
    except Exception as e:
        # Top-level UI boundary: keep the app responsive, but record the
        # full traceback instead of silently discarding it.
        logger.exception("Prediction failed")
        return f"Error: {str(e)}"

    return f"Predicted class: {predicted_class_label}"


# Create the Gradio interface.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Text(label="Prediction"),
    title="ResNet-50 Image Classification",
    description="Upload an image to classify it into one of the ImageNet classes using the ResNet-50 model.",
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()