mask2former-demo / predict.py
merve's picture
merve HF staff
Duplicate from shivi/mask2former-demo
7bc0bed
import functools
import os
import random
from collections import defaultdict

import numpy as np
import torch
from PIL import Image

# Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
# Hence, installing detectron2 this way when using Gradio HF spaces.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
@functools.lru_cache(maxsize=3)
def load_model_and_processor(model_ckpt: str):
    """Load (and memoize) a Mask2Former model and its image processor.

    Args:
        model_ckpt: Hugging Face Hub id or local path of a Mask2Former checkpoint.

    Returns:
        ``(model, image_processor)``: the model on CUDA when available (else CPU),
        switched to eval mode, and the matching ``Mask2FormerImageProcessor``.

    Results are cached per checkpoint (at most three, one per task this demo
    serves) so repeated predictions do not reload the weights from disk and move
    them to the device again on every call. The returned objects are shared
    between callers, so do not mutate them (e.g. by calling ``model.train()``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(device)
    model.eval()  # inference only: disable training-time behavior such as dropout
    image_processor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_processor
def load_default_ckpt(segmentation_task: str):
    """Return the Hub checkpoint id used for a segmentation task.

    ``"semantic"`` maps to an ADE20k model and ``"instance"`` to a COCO instance
    model. Any other value falls back to the COCO panoptic checkpoint.
    """
    if segmentation_task == "semantic":
        return "facebook/mask2former-swin-tiny-ade-semantic"
    if segmentation_task == "instance":
        return "facebook/mask2former-swin-small-coco-instance"
    # Panoptic is the catch-all default.
    return "facebook/mask2former-swin-tiny-coco-panoptic"
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Render a panoptic prediction using detectron2's COCO panoptic metadata.

    Args:
        predicted_segmentation_map: Segment-id map (a torch tensor; ``.cpu()`` is
            called on it), e.g. ``result["segmentation"]`` from
            ``post_process_panoptic_segmentation``.
        seg_info: List of per-segment dicts, each carrying a ``"label_id"`` key.
            The list and its dicts are NOT modified.
        image: Input image as a PIL image (converted to RGB) or an RGB array.

    Returns:
        PIL.Image.Image with the segments drawn at 50% opacity.

    Raises:
        KeyError: If a segment dict has no ``"label_id"`` key.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # Build the set of thing ids once so each per-segment membership test is O(1).
    thing_ids = set(metadata.thing_dataset_id_to_contiguous_id.values())

    # detectron2 reads "category_id" and "isthing". Work on copies so the caller's
    # seg_info survives intact; the earlier in-place pop made any second call with
    # the same list fail with KeyError.
    segments_info = []
    for segment in seg_info:
        converted = dict(segment)
        category_id = converted.pop("label_id")
        converted["category_id"] = category_id
        converted["isthing"] = category_id in thing_ids
        segments_info.append(converted)

    # Visualizer requires an RGB array. PIL images are already RGB, so the channels
    # must not be flipped (the old [:, :, ::-1] handed it BGR and swapped red/blue).
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    visualizer = Visualizer(np.array(image), metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), segments_info, alpha=0.5
    )
    return Image.fromarray(out.get_image())
def draw_semantic_segmentation(segmentation_map, image, palette):
    """Blend a color-coded semantic segmentation map over an image.

    Args:
        segmentation_map: 2-D array of integer class ids, shape (H, W).
        image: Image to draw on (PIL image or array; assumed H x W x 3).
        palette: Sequence of 3-value color triples. A pixel with value ``v`` is
            painted with ``palette[v - 1]``. Pixels whose ``v - 1`` falls outside
            the palette (including ``v == 0``) are left uncolored.

    Returns:
        uint8 array of shape (H, W, 3): a 50/50 blend of ``image`` and the color map.
    """
    segmentation_map = np.asarray(segmentation_map)
    height, width = segmentation_map.shape[0], segmentation_map.shape[1]
    color_segmentation_map = np.zeros((height, width, 3), dtype=np.uint8)

    # One vectorized palette lookup instead of one full-image comparison per label.
    # reshape keeps an empty palette as shape (0, 3), so the assignment below still works.
    palette_array = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    # NOTE(review): the "- 1" offset means class id 0 is never drawn. If the ADE
    # checkpoint numbers its classes from 0, that class is dropped. Confirm intent.
    palette_index = segmentation_map.astype(np.int64) - 1
    valid = (palette_index >= 0) & (palette_index < len(palette_array))
    color_segmentation_map[valid] = palette_array[palette_index[valid]]

    # Reverse the overlay's channel order (originally commented "Convert to BGR").
    # `image` itself is blended unchanged.
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
    return img.astype(np.uint8)
def visualize_instance_seg_mask(mask, input_image, seed=None):
    """Overlay randomly colored instance masks on an image.

    Args:
        mask: 2-D instance-id map, e.g. ``result["segmentation"]`` from
            ``post_process_instance_segmentation`` converted to numpy.
        input_image: Image of the same height/width (PIL image or array;
            assumed H x W x 3). It is blended unchanged.
        seed: Optional seed for reproducible colors. ``None`` (the default) draws
            from the module-level ``random`` generator, as before.

    Returns:
        uint8 array of shape (H, W, 3): a 50/50 blend of the image and the color map.

    A pixel with value ``v`` is painted with the color drawn for id ``v - 1``. The
    smallest value in ``mask`` therefore never has a matching color and stays
    black. NOTE(review): presumably that is the -1 background id from the
    post-processor; confirm.
    """
    rng = random if seed is None else random.Random(seed)
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    # Every channel spans 0..255. The red channel was previously randint(0, 1),
    # a typo that made red always 0 or 1.
    label2color = {
        int(label): (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))
        for label in labels
    }
    for label, color in label2color.items():
        color_segmentation_map[mask - 1 == label, :] = color
    # Reverse the overlay's channel order, matching the semantic visualizer.
    ground_truth_color_seg = color_segmentation_map[..., ::-1]
    img = np.array(input_image) * 0.5 + ground_truth_color_seg * 0.5
    return img.astype(np.uint8)
def predict_masks(input_img_path: str, segmentation_task: str):
    """Run Mask2Former on an image file and render the prediction.

    Args:
        input_img_path: Path of the image to segment.
        segmentation_task: ``"semantic"``, ``"instance"``, or anything else
            (treated as panoptic, matching ``load_default_ckpt``).

    Returns:
        ``(output_result, output_heading)``. ``output_result`` is the rendered
        visualization: a uint8 numpy array for semantic/instance, a PIL image for
        panoptic. ``output_heading`` is a title string.
    """
    # Load the default checkpoint for the task and its matching image processor.
    default_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_ckpt)

    # Force 3-channel RGB so grayscale, RGBA and palette images work with the
    # processor and the RGB blending helpers. The `with` closes the file handle.
    with Image.open(input_img_path) as raw_image:
        image = raw_image.convert("RGB")

    # The processor returns CPU tensors. Move them to the model's device, which may
    # be CUDA; otherwise the forward pass fails with a device mismatch on GPU.
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing resizes predictions back to the input's (height, width).
    target_sizes = [image.size[::-1]]
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
        output_heading = "Semantic Segmentation Output"
    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_instance_map = result["segmentation"].cpu().detach().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
        output_heading = "Instance Segmentation Output"
    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result['segments_info']
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
        output_heading = "Panoptic Segmentation Output"
    return output_result, output_heading