import numpy as np
import cv2
import gradio as gr
import torch
from ade20k_colors import colors
from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
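
# Load both BEiT checkpoints and their feature extractors up front so the UI
# can switch between Base and Large without reloading weights per request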
beit_models = ['microsoft/beit-base-finetuned-ade-640-640',
               'microsoft/beit-large-finetuned-ade-640-640']

models = [BeitForSemanticSegmentation.from_pretrained(m) for m in beit_models]
extractors = [BeitFeatureExtractor.from_pretrained(m) for m in beit_models]
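
# Map per-pixel class scores (H, W, num_classes) onto the ADE20K palette,
# producing an (H, W, 3) uint8 RGB image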
def apply_colors(img):
    # Index the palette with each pixel's argmax class; equivalent to a
    # per-pixel Python loop but vectorized with NumPy
    return np.array(colors, dtype=np.uint8)[np.argmax(img, axis=-1)]
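
# Run a single forward pass with the selected model and return a color-coded
# segmentation map at the input image's resolution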
def inference(image, chosen_model):
    feature_extractor = extractors[chosen_model]
    model = models[chosen_model]
    inputs = feature_extractor(images=image, return_tensors='pt')
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        logits = model(**inputs).logits
    output = torch.sigmoid(logits).numpy()[0]
    output = np.transpose(output, (1, 2, 0))
    output = apply_colors(output)
    # Nearest-neighbour resize avoids blending palette colors across class borders
    return cv2.resize(output, image.shape[1::-1], interpolation=cv2.INTER_NEAREST)
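
# Gradio UI: an image upload plus a radio button whose index selects the model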
inputs = [gr.Image(label='Input Image'),
          gr.Radio(['Base', 'Large'], label='BEiT Model', type='index')]

gr.Interface(
    inference,
    inputs,
    gr.Image(label='Output'),
    title='BEiT - Semantic Segmentation',
    description='BEiT: BERT Pre-Training of Image Transformers',
    examples=[['images/armchair.jpg', 'Base'],
              ['images/cat.jpg', 'Base'],
              ['images/plant.jpg', 'Large']]
).launch()
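
# Optional local smoke test, a sketch that bypasses the UI (assumes the example
# images above exist; index 0 selects the Base model):
#   img = cv2.cvtColor(cv2.imread('images/cat.jpg'), cv2.COLOR_BGR2RGB)
#   cv2.imwrite('segmented.png', cv2.cvtColor(inference(img, 0), cv2.COLOR_RGB2BGR))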