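# Repository file-view header captured with this source (not code):
# author: Megareyka · commit 580e36c, "Update app.py" · size: 735 bytes
# (page links: raw, history, blame)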
import gradio as gr
import numpy as np
import torch
from PIL import Image as img
from transformers import pipeline
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Pretrained ViT checkpoint shared by the preprocessor and the classifier, so
# the two can never drift apart.
MODEL_ID = 'google/vit-base-patch16-224'

# ViTImageProcessor replaces the deprecated ViTFeatureExtractor (the old class
# is a thin deprecated alias of it). The variable keeps its historical name so
# code that refers to `featureextractor` keeps working.
featureextractor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(MODEL_ID)
def classify(input_img):
    """Return the label the ViT model predicts for a single image.

    Args:
        input_img: Image handed over by the Gradio ``"image"`` input, or
            ``None`` when the user submits without an image. Presumably a
            NumPy array (Gradio's default for this component) — confirm if
            the input component is reconfigured.

    Returns:
        The string from ``model.config.id2label`` for the highest-scoring
        class, or an empty string when no image was supplied.
    """
    # Gradio passes None for an empty image box; the processor cannot
    # handle that, so answer with an empty label instead of crashing.
    if input_img is None:
        return ""
    inputs = featureextractor(images=input_img, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index of the best-scoring class along the last (class) dimension.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
# Web UI: one image in, the predicted label out as text. `demo` stays at
# module level so tooling that imports this file can still find it.
demo = gr.Interface(fn=classify, inputs="image", outputs="text")

# Start the server only when run as a script, not as a side effect of import.
if __name__ == "__main__":
    demo.launch()