iamomtiwari committed on
Commit
7d859e4
1 Parent(s): b28200c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -1,44 +1,47 @@
1
  import gradio as gr
2
- from transformers import ViTImageProcessor, ViTForImageClassification
3
  from PIL import Image
4
  import torch
5
 
6
- # Load the pre-trained ViT model and processor
7
- processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k') # Using the in21k pre-trained model
8
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
9
 
10
- # Inference function for predicting with ViT model
11
  def predict(image):
12
  try:
13
  # Ensure the image is in PIL format
14
- if isinstance(image, str):
15
- image = Image.open(image)
16
-
17
- # Preprocess the input image using the processor, with padding enabled
18
- inputs = processor(images=image, return_tensors="pt", padding=True)
19
 
20
- # Get the model's predictions
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
  logits = outputs.logits
24
-
25
- # Get the predicted class index (class with the highest logit value)
26
  predicted_class_idx = logits.argmax(-1).item()
27
 
28
- # Get the human-readable label for the predicted class
29
  predicted_class_label = model.config.id2label[predicted_class_idx]
30
-
31
  return f"Predicted class: {predicted_class_label}"
32
 
33
  except Exception as e:
34
  return f"Error: {str(e)}"
35
 
36
- # Create Gradio Interface (Note the change here: `gr.Image` and `gr.Text`)
37
- interface = gr.Interface(fn=predict,
38
- inputs=gr.Image(type="pil", label="Upload Image"),
39
- outputs=gr.Text(),
40
- title="ViT Image Classification (ImageNet)",
41
- description="Upload an image to classify it into one of the 1000 ImageNet classes.")
 
 
42
 
43
- # Launch the interface
44
- interface.launch()
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
  from PIL import Image
4
  import torch
5
 
6
+ # Load the ResNet-50 model and feature extractor
7
+ model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50")
8
+ feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
9
 
10
+ # Define the prediction function
11
  def predict(image):
12
  try:
13
  # Ensure the image is in PIL format
14
+ if not isinstance(image, Image.Image):
15
+ return "Invalid image format. Please upload a valid image."
16
+
17
+ # Preprocess the image using the feature extractor
18
+ inputs = feature_extractor(images=image, return_tensors="pt")
19
 
20
+ # Perform inference
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
  logits = outputs.logits
24
+
25
+ # Get the class with the highest score
26
  predicted_class_idx = logits.argmax(-1).item()
27
 
28
+ # Map the predicted index to its human-readable label
29
  predicted_class_label = model.config.id2label[predicted_class_idx]
30
+
31
  return f"Predicted class: {predicted_class_label}"
32
 
33
  except Exception as e:
34
  return f"Error: {str(e)}"
35
 
36
+ # Create the Gradio interface
37
+ interface = gr.Interface(
38
+ fn=predict,
39
+ inputs=gr.Image(type="pil", label="Upload Image"),
40
+ outputs=gr.Text(label="Prediction"),
41
+ title="ResNet-50 Image Classification",
42
+ description="Upload an image to classify it into one of the ImageNet classes using the ResNet-50 model."
43
+ )
44
 
45
+ # Launch the app
46
+ if __name__ == "__main__":
47
+ interface.launch()