iamomtiwari commited on
Commit
c02f406
1 Parent(s): 5bfe97a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -0
app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
4
+ from PIL import Image
5
+
6
+ # Load ViT 221k model for image classification (pre-trained on ImageNet21k)
7
+ vit_221k_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
8
+ vit_221k_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
9
+
10
+ # Inference function for predicting with ViT 221k
11
+ def predict(image):
12
+ # Preprocess the input image
13
+ inputs = vit_221k_feature_extractor(images=image, return_tensors="pt")
14
+ with torch.no_grad():
15
+ outputs = vit_221k_model(**inputs)
16
+ predicted_class_idx = outputs.logits.argmax(-1).item()
17
+
18
+ # Get the label corresponding to the prediction
19
+ vit_221k_label = vit_221k_model.config.id2label[predicted_class_idx]
20
+ return f"Prediction from ViT 221k Model: {vit_221k_label}"
21
+
22
+ # Create Gradio Interface
23
+ interface = gr.Interface(fn=predict, inputs="image", outputs="text", title="ViT 221k Image Classification")
24
+
25
+ # Launch the interface
26
+ interface.launch()