brendenc's picture
Update app.py
b17b4ee
raw
history blame contribute delete
No virus
1.03 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    AutoModelForSemanticSegmentation,
)
# Single checkpoint id shared by the preprocessor and the model so the two
# cannot drift apart.
_CHECKPOINT = "brendenc/my-segmentation-model"

# Both objects are created once at import time and reused by every call to
# classify().
extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
# classify() takes an argmax over a class axis and uses the result as a
# per-pixel color index, which requires a semantic-segmentation head. An
# image-classification head produces one score vector per image, so the
# "mask" would collapse to a single color.
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def classify(im):
    """Segment an image and render the predicted class map as an RGB image.

    Args:
        im: The input image as delivered by the Gradio ``"image"`` input
            (presumably an ``(H, W, 3)`` NumPy array -- confirm against the
            installed Gradio version).

    Returns:
        A ``uint8`` NumPy array holding one palette color per predicted
        class index, with a trailing axis of length 3 (RGB).

    Raises:
        IndexError: If the model predicts a class index outside the
            five-entry palette.
    """
    inputs = extractor(images=im, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # assumes logits is (batch, num_classes, height, width) -- TODO confirm.
    # Take the only image in the batch and pick the best class per position.
    # NOTE(review): the map has the logits' spatial size, which may be
    # smaller than the input image; it is not upsampled here.
    classes = logits[0].cpu().numpy().argmax(axis=0)

    # One RGB color per class index. uint8 makes the result a standard 8-bit
    # image; NumPy's default integer dtype (int64) is not an image dtype and
    # may be rescaled or rejected by the image output component.
    colors = np.array(
        [[128, 0, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 0, 0]],
        dtype=np.uint8,
    )
    # Fancy indexing maps every class index to its color, appending the RGB axis.
    return colors[classes]
# Example inputs handed to the interface: example_0.jpg .. example_2.jpg.
example_imgs = [f"example_{i}.jpg" for i in range(3)]

# Image-in / image-out UI around classify().
interface = gr.Interface(
    fn=classify,
    inputs="image",
    outputs="image",
    examples=example_imgs,
    title="Street Image Segmentation",
    description="""Below is a simple app for image segmentation.""",
)

# Start the app as soon as the script is executed.
interface.launch(debug=True)