Image_Dissector / app.py
wendys-llc's picture
Update app.py
de37a0c verified
raw
history blame
2.49 kB
import os
from functools import lru_cache

import gradio as gr
import numpy as np
from PIL import Image, ImageFilter
from transformers import pipeline
# Hugging Face model ids offered in the model dropdown. The dropdown uses
# type="index", so callbacks receive a position in this list, not the id.
models = [
    "facebook/detr-resnet-50-panoptic",
    "CIDAS/clipseg-rd64-refined",
    "facebook/maskformer-swin-large-ade",
    "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
]
# Model pre-selected in the dropdown when the app starts.
current_model = models[0]
# Most recent raw segmentation output, stored here by image_objects().
pred = []
def img_resize(image, width=1280):
    """Scale *image* to *width* pixels wide, preserving its aspect ratio.

    Images narrower than *width* are scaled up; wider ones are scaled down.

    Args:
        image: object exposing a ``size`` ``(width, height)`` tuple and a
            ``resize((w, h))`` method, e.g. a PIL image.
        width: target width in pixels (default 1280).

    Returns:
        Whatever ``image.resize((width, height))`` returns.
    """
    src_width, src_height = image.size
    # Exact integer floor of src_height * width / src_width avoids float
    # truncation error; clamp to 1 so very wide images never collapse to a
    # zero-height (empty) image.
    height = max(1, src_height * width // src_width)
    return image.resize((width, height))
def image_objects(image, model_choice=0):
    """Segment *image* and return a dropdown update listing the found objects.

    The raw pipeline output is kept in the module-level ``pred`` so other
    callbacks can look up a segment by the index encoded in its choice.

    Args:
        image: PIL image to segment, or None (e.g. a cleared input).
        model_choice: index into ``models``; defaults to 0, which is
            ``current_model``.

    Returns:
        A Gradio update for a dropdown whose choices are
        ``"<index>_<label>"`` strings, one per segment.
    """
    global pred
    if image is None:
        pred = []
        return gr.update(choices=[], interactive=True)
    image = img_resize(image)
    # Build the pipeline from the selected model id; there is no
    # module-level ``model`` object to call.
    segmenter = pipeline("image-segmentation", model=models[model_choice])
    pred = segmenter(image)
    pred_object_list = [f"{i}_{x['label']}" for i, x in enumerate(pred)]
    # gr.update works on both Gradio 3 and 4 (gr.Dropdown.update was removed in 4).
    return gr.update(choices=pred_object_list, interactive=True)
@lru_cache(maxsize=None)
def _get_segmenter(model_name):
    """Build once, then return, the image-segmentation pipeline for *model_name*."""
    # Loading a pipeline reloads model weights, so reuse it across image
    # changes. At most len(models) pipelines are ever cached.
    return pipeline("image-segmentation", model=model_name)


def _apply_mask(image_array, mask):
    """Return a uint8 copy of *image_array* with pixels outside *mask* zeroed.

    *mask* is a single-channel image/array with the same height and width as
    *image_array*. A value of 255 keeps a pixel and 0 blacks it out
    (presumably the pipeline's masks are 0/255 — intermediate values scale
    the pixel rather than dropping it).
    """
    weights = np.asarray(mask, dtype=np.float64) / 255
    # A trailing channel axis lets the 2-D mask broadcast across RGB.
    return (image_array * weights[..., np.newaxis]).astype(np.uint8)


def get_seg(image, model_choice):
    """Segment *image* with the selected model and return per-object cut-outs.

    Args:
        image: PIL image from the input component, or None (Gradio may pass
            None when the input is cleared).
        model_choice: index into ``models``.

    Returns:
        A ``(gallery_images, dropdown_update)`` pair. ``gallery_images`` has
        one uint8 RGB array per segment, with everything outside that
        segment blacked out. ``dropdown_update`` sets the dropdown choices to
        ``"<index>_<label>"`` strings.
    """
    if image is None:
        return [], gr.update(choices=[], interactive=True)
    # Force 3 channels so RGBA/greyscale input doesn't break the masking.
    image = img_resize(image).convert("RGB")
    segments = _get_segmenter(models[model_choice])(image)
    image_array = np.asarray(image)
    seg_box = [_apply_mask(image_array, seg["mask"]) for seg in segments]
    pred_object_list = [f"{i}_{seg['label']}" for i, seg in enumerate(segments)]
    # gr.update works on both Gradio 3 and 4 (gr.Dropdown.update was removed in 4).
    return seg_box, gr.update(choices=pred_object_list, interactive=True)
# UI layout: an input image plus model picker, a gallery of per-object
# cut-outs, and a dropdown listing the detected objects.
app = gr.Blocks()
with app:
    gr.Markdown(
        """
## Image Dissector
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="pil")
            model_name = gr.Dropdown(
                show_label=False,
                choices=list(models),
                type="index",
                value=current_model,
                interactive=True,
            )
        with gr.Column():
            # Component.style() was removed in Gradio 4; the grid width is
            # now the `columns` constructor argument.
            gal1 = gr.Gallery(type="filepath", columns=6)
    with gr.Row():
        with gr.Column():
            object_output = gr.Dropdown(label="Objects")
    seg_inputs = [image_input, model_name]
    seg_outputs = [gal1, object_output]
    image_input.change(get_seg, inputs=seg_inputs, outputs=seg_outputs)
    # Re-run segmentation when the model changes too; otherwise picking a
    # new model has no effect until the image is replaced.
    model_name.change(get_seg, inputs=seg_inputs, outputs=seg_outputs)
app.launch()