|
import importlib.util
import os
import random
import subprocess
import sys
from collections import defaultdict

import numpy as np
import torch
from PIL import Image
|
|
|
|
|
# detectron2 is installed from its GitHub repository at import time because the
# imports that follow depend on it. The install is skipped when the package is
# already importable, and pip is invoked through the running interpreter (no
# shell) so the package lands in the environment this process actually uses.
# check=True surfaces a failed install here instead of as a later ImportError.
if importlib.util.find_spec("detectron2") is None:
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ],
        check=True,
    )
    # Make the freshly installed package visible to the import system.
    importlib.invalidate_caches()
|
|
|
from detectron2.data import MetadataCatalog |
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
from color_palette import ade_palette |
|
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation |
|
|
|
def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former model and its image processor from one checkpoint.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise on the CPU, and is switched to evaluation mode.

    Args:
        model_ckpt: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, image_preprocessor)`` tuple.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt)
    model = model.to(target_device)
    model.eval()

    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor
|
|
|
def load_default_ckpt(segmentation_task: str):
    """Return the default checkpoint name for a segmentation task.

    Args:
        segmentation_task: ``"semantic"`` or ``"instance"``; every other value
            selects the panoptic checkpoint.

    Returns:
        The checkpoint identifier string for the requested task.
    """
    if segmentation_task == "semantic":
        return "facebook/mask2former-swin-tiny-ade-semantic"
    if segmentation_task == "instance":
        return "facebook/mask2former-swin-small-coco-instance"
    # Anything that is neither semantic nor instance is treated as panoptic.
    return "facebook/mask2former-swin-tiny-coco-panoptic"
|
|
|
def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Overlay a panoptic prediction on ``image`` using detectron2's Visualizer.

    Args:
        predicted_segmentation_map: Tensor of per-pixel segment ids; it is
            moved to the CPU before drawing.
        seg_info: Iterable of per-segment dicts. The class id is read from
            ``label_id`` (or from ``category_id`` if the entry was already
            renamed). The dicts are copied, so the caller's data is untouched
            and the function can be called repeatedly with the same list.
        image: PIL image (or RGB array-like) the prediction belongs to.

    Returns:
        PIL.Image.Image: The image with the panoptic segments drawn on it.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # Built once so the per-segment membership test is a set lookup.
    thing_ids = set(metadata.thing_dataset_id_to_contiguous_id.values())

    segments = []
    for res in seg_info:
        segment = dict(res)
        # The visualizer reads the class id from 'category_id'.
        if "label_id" in segment:
            segment["category_id"] = segment.pop("label_id")
        segment["isthing"] = segment["category_id"] in thing_ids
        segments.append(segment)

    # Visualizer expects an RGB array and get_image() returns the same channel
    # order, so the image is passed through without reversing its channels.
    if hasattr(image, "convert"):
        rgb_image = np.asarray(image.convert("RGB"))
    else:
        rgb_image = np.asarray(image)

    visualizer = Visualizer(rgb_image, metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), segments, alpha=0.5
    )
    return Image.fromarray(out.get_image())
|
|
|
def draw_semantic_segmentation(segmentation_map, image, palette):
    """Blend a colour-coded semantic map with ``image`` at equal weight.

    Args:
        segmentation_map: 2-D array of class ids, indexed as (row, column).
        image: Image (or array-like) with the same height/width and three
            channels.
        palette: Sequence of three-component colours.

    Returns:
        ``uint8`` array of shape (height, width, 3) holding the blend.
    """
    height, width = segmentation_map.shape[0], segmentation_map.shape[1]
    overlay = np.zeros((height, width, 3), dtype=np.uint8)

    # NOTE(review): because of this shift, class id k is painted with
    # palette[k - 1] and class id 0 is never painted -- confirm that this is
    # the intended alignment between class ids and palette entries.
    shifted_ids = segmentation_map - 1
    for palette_index, colour in enumerate(palette):
        overlay[shifted_ids == palette_index] = colour

    # The overlay's channel order is reversed before blending.
    reversed_overlay = overlay[..., ::-1]

    blended = 0.5 * np.array(image) + 0.5 * reversed_overlay
    return blended.astype(np.uint8)
|
|
|
def visualize_instance_seg_mask(mask, input_image):
    """Blend a randomly coloured instance map with ``input_image``.

    Every non-negative id in ``mask`` gets its own random colour. Negative ids
    are treated as "no instance" and receive no colour, so those pixels show
    only the half-intensity image.

    Args:
        mask: 2-D array of instance ids, indexed as (row, column).
        input_image: PIL image or array-like with the same height/width. PIL
            images are converted to RGB so the blend always has three channels.

    Returns:
        ``uint8`` array of shape (height, width, 3) holding the blend.
    """
    overlay = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)

    for instance_id in np.unique(mask):
        if instance_id < 0:
            # Background / unassigned pixels keep the black overlay.
            continue
        # Each channel draws from the full 0-255 range.
        overlay[mask == instance_id] = [random.randint(0, 255) for _ in range(3)]

    if hasattr(input_image, "convert"):
        base = np.asarray(input_image.convert("RGB"))
    else:
        base = np.asarray(input_image)

    blended = base * 0.5 + overlay * 0.5
    return blended.astype(np.uint8)
|
|
|
def predict_masks(input_img_path: str, segmentation_task: str):
    """Run Mask2Former on an image file and return a visualisation.

    Args:
        input_img_path: Path of the image to segment.
        segmentation_task: ``"semantic"`` or ``"instance"``; any other value
            runs panoptic segmentation.

    Returns:
        A ``(output_result, output_heading)`` tuple. ``output_result`` is a
        ``uint8`` NumPy array for the semantic and instance tasks and a PIL
        image for the panoptic task; ``output_heading`` is the matching title.
    """
    default_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_ckpt)

    # The context manager releases the file handle; convert() loads the pixel
    # data and guarantees the three channels the drawing helpers blend against.
    with Image.open(input_img_path) as source_image:
        image = source_image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt")
    # load_model_and_processor may have placed the model on CUDA; the inputs
    # must live on the same device as the weights.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing takes (height, width).
    target_sizes = [image.size[::-1]]

    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=target_sizes
        )[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
        output_heading = "Semantic Segmentation Output"

    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(
            outputs, target_sizes=target_sizes
        )[0]
        predicted_instance_map = result["segmentation"].cpu().detach().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
        output_heading = "Instance Segmentation Output"

    else:
        result = image_processor.post_process_panoptic_segmentation(
            outputs, target_sizes=target_sizes
        )[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result["segments_info"]
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
        output_heading = "Panoptic Segmentation Output"

    return output_result, output_heading
|
|