"""One-click image segmentation demo.

Click a point on the input image and the SlimSAM model predicts a mask for the
object under that point, which is then overlaid on the image.
"""

import gradio as gr
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
import numpy as np
import matplotlib.pyplot as plt


# Load the pruned SlimSAM checkpoint and its matching processor from the Hugging Face Hub.
model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
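
# Optional (not part of the original script): move the model to a GPU when one
# is available. perform_prediction() would then also need to move its inputs to
# the same device, e.g. inputs = inputs.to(model.device).
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)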

# Holds the most recent click coordinates, used as the point prompt for SAM.
input_points = []


def show_mask(mask, ax, random_color=False):
    """Overlay a single binary mask on the given matplotlib axes."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def get_pixel_coordinates(image, evt: gr.SelectData):
    """Record the clicked pixel as a point prompt and run segmentation."""
    global input_points
    x, y = evt.index[0], evt.index[1]
    # SamProcessor expects nested lists: one list per image, one list per
    # prompt, and one [x, y] pair per point.
    input_points = [[[x, y]]]
    return perform_prediction(image)


def perform_prediction(image):
    """Run SlimSAM with the current point prompt and return the best mask overlaid on the image."""
    global input_points

    inputs = processor(images=image, input_points=input_points, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # SAM predicts several candidate masks per prompt; keep the one with the
    # highest predicted IoU score.
    iou = outputs.iou_scores
    max_iou_index = torch.argmax(iou)

    # Upscale the low-resolution masks back to the original image size.
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs['original_sizes'],
        inputs['reshaped_input_sizes']
    )
    predicted_mask = predicted_masks[0]

    mask_image = show_mask_on_image(image, predicted_mask[:, max_iou_index], return_image=True)
    return mask_image


def show_mask_on_image(raw_image, mask, return_image=False):
    """Draw the mask over the raw image and optionally return the result as a PIL image."""
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)

    if len(mask.shape) == 4:
        mask = mask.squeeze()

    fig, axes = plt.subplots(1, 1, figsize=(15, 15))

    mask = mask.cpu().detach()
    axes.imshow(np.array(raw_image))
    show_mask(mask, axes)
    axes.axis("off")
    plt.show()

    if return_image:
        fig.canvas.draw()

        # Convert the rendered figure to a PIL image.
        # Note: tostring_rgb() was removed in Matplotlib 3.10; on newer versions
        # use np.asarray(fig.canvas.buffer_rgba())[..., :3] instead.
        img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        img = Image.fromarray(img)
        plt.close(fig)
        return img


# Build the Gradio interface: click anywhere on the input image to segment the
# object under the cursor.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        <div style='text-align: center; font-family: "Times New Roman";'>
        <h1 style='color: #FF6347;'>One Click Image Segmentation App</h1>
        <h3 style='color: #4682B4;'>Model: SlimSAM-uniform-77</h3>
        <h3 style='color: #32CD32;'>Made By: Md. Mahmudun Nabi</h3>
        </div>
        """
    )
    with gr.Row():
        img = gr.Image(type="pil", label="Input Image", height=400, width=600)
        output_image = gr.Image(label="Masked Image")

    # Clicking on the input image triggers segmentation at the clicked point.
    img.select(get_pixel_coordinates, inputs=[img], outputs=[output_image])


if __name__ == "__main__":
    demo.launch(share=False)
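    # Note: share=True would instead create a temporary public Gradio link,
    # which can be handy when the app runs on a remote machine.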