Spaces:
Runtime error
Runtime error
Prompt-Segment-Anything-Demo/projects/instance_segment_anything/models/det_wrapper_instance_sam.py
import cv2 | |
import torch | |
import torch.nn as nn | |
from mmcv import Config | |
from mmcv.runner import load_checkpoint | |
from mmdet.core import bbox2result | |
from mmdet.models import DETECTORS, BaseDetector | |
from projects.instance_segment_anything.models.segment_anything import sam_model_registry, SamPredictor | |
from .focalnet_dino.focalnet_dino_wrapper import FocalNetDINOWrapper | |
from .hdetr.hdetr_wrapper import HDetrWrapper | |
class DetWrapperInstanceSAM(BaseDetector):
    """Instance segmentation by prompting Segment Anything (SAM) with detector boxes.

    A box detector wrapper (``HDetrWrapper`` or ``FocalNetDINOWrapper``,
    selected by ``det_wrapper_type``) produces boxes, labels and scores. Each
    box is then used as a prompt for SAM, which predicts one mask per box.
    Results are returned in the mmdet ``(bbox_results, mask_results)`` format.
    Only single-image, non-augmented inference (``simple_test``) is implemented.
    """

    # Maps the ``det_wrapper_type`` constructor argument to a detector wrapper class.
    wrapper_dict = {'hdetr': HDetrWrapper,
                    'focalnet_dino': FocalNetDINOWrapper}

    def __init__(self,
                 det_wrapper_type='hdetr',
                 det_wrapper_cfg=None,
                 det_model_ckpt=None,
                 num_classes=80,
                 model_type='vit_b',
                 sam_checkpoint=None,
                 use_sam_iou=True,
                 init_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        """Build the detector wrapper and the SAM predictor.

        Args:
            det_wrapper_type (str): Key into ``wrapper_dict`` that chooses
                the detector wrapper class.
            det_wrapper_cfg (dict, optional): Wrapped in an mmcv ``Config``
                and passed as ``args`` to the detector wrapper.
            det_model_ckpt (str, optional): Checkpoint loaded (onto CPU) into
                ``self.det_model.model``. Skipped when ``None``.
            num_classes (int): Number of classes used to group the results.
            model_type (str): Key into ``sam_model_registry`` that selects
                the SAM variant.
            sam_checkpoint (str, optional): Passed as ``checkpoint`` to the
                SAM registry builder.
            use_sam_iou (bool): If True, detection scores are multiplied by
                SAM's predicted score. Otherwise they are multiplied by the
                mean mask logit inside each predicted mask.
            init_cfg (dict, optional): Forwarded to ``BaseDetector``.
            train_cfg, test_cfg: Not used by this class. They are presumably
                accepted so that mmdet configs can pass them.
        """
        super().__init__(init_cfg)
        # Registered parameter whose device is read below when placing SAM.
        self.learnable_placeholder = nn.Embedding(1, 1)

        det_wrapper_cfg = Config(det_wrapper_cfg)
        assert det_wrapper_type in self.wrapper_dict, (
            f'unknown det_wrapper_type {det_wrapper_type!r}; '
            f'expected one of {sorted(self.wrapper_dict)}')
        self.det_model = self.wrapper_dict[det_wrapper_type](args=det_wrapper_cfg)
        if det_model_ckpt is not None:
            load_checkpoint(self.det_model.model,
                            filename=det_model_ckpt,
                            map_location='cpu')

        self.num_classes = num_classes

        # Segment Anything
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        # NOTE(review): the placeholder was just created, so this is
        # presumably the CPU device at construction time. SAM sits inside a
        # plain SamPredictor object (not an nn.Module attribute), so a later
        # ``.to()`` on this detector may not move it. Confirm how callers
        # place SAM on the GPU.
        _ = sam.to(device=self.learnable_placeholder.weight.device)
        self.predictor = SamPredictor(sam)
        self.use_sam_iou = use_sam_iou

    def init_weights(self):
        """Intentionally a no-op.

        Presumably overridden so that the base class's weight initialisation
        does not overwrite the checkpoints loaded in ``__init__``. Confirm
        before changing.
        """
        pass

    def simple_test(self, img, img_metas, ori_img, rescale=True):
        """Detect boxes, then segment each box with SAM (no augmentation).

        Args:
            img (Tensor): Preprocessed input batch passed to the detector.
            img_metas (list[dict]): Image meta information. It must contain
                exactly one entry (single-image inference only).
            ori_img: Original image passed to ``SamPredictor.set_image``. Its
                ``shape[:2]`` is used as the size when transforming boxes.
                It is presumably an HxWxC uint8 array; confirm against callers.
            rescale (bool): Must be True. Detector boxes are expected in
                original-image scale.

        Returns:
            list[tuple]: A one-element list holding
            ``(bbox_results, mask_results)``.

            - ``bbox_results`` is the ``bbox2result`` output built from
              ``[x1, y1, x2, y2, score]`` rows.
            - ``mask_results[c]`` is a list of boolean mask arrays for the
              detections labelled ``c``, in detection order.
        """
        assert rescale, 'simple_test only supports rescale=True'
        assert len(img_metas) == 1, 'simple_test only supports a single image'

        # results: List[dict(scores, labels, boxes)], one dict per image
        results = self.det_model.simple_test(img,
                                             img_metas,
                                             rescale)
        det_result = results[0]
        # Tensor(n, 4), xyxy, original image scale
        output_boxes = det_result['boxes']
        # Tensor(n,)
        label_pred = det_result['labels']
        score_pred = det_result['scores']

        if output_boxes.shape[0] == 0:
            # No boxes to prompt SAM with. Skip the costly image encoding and
            # the empty-batch mask decoding, and return correctly shaped
            # empty results.
            empty_bboxes = output_boxes.new_zeros((0, 5))
            bbox_results = bbox2result(empty_bboxes, label_pred, self.num_classes)
            mask_results = [[] for _ in range(self.num_classes)]
            return [(bbox_results, mask_results)]

        self.predictor.set_image(ori_img)
        # Map the boxes from original-image coordinates into SAM's input frame.
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            output_boxes, ori_img.shape[:2])

        # mask_pred: (n, 1, h, w) raw logits (return_logits=True)
        # sam_score: (n, 1)
        mask_pred, sam_score, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
            return_logits=True,
        )
        # Drop the single-mask dimension: mask_pred (n, h, w), sam_score (n,)
        mask_pred = mask_pred.squeeze(1)
        sam_score = sam_score.squeeze(-1)

        # Binarise the logits with SAM's own threshold.
        mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
        if self.use_sam_iou:
            det_scores = score_pred * sam_score
        else:
            # Mean logit over each mask's foreground pixels. The epsilon
            # guards against division by zero for empty masks.
            mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
                mask_pred_binary.flatten(1).sum(1) + 1e-6)
            det_scores = score_pred * mask_scores_per_image
        mask_pred_binary = mask_pred_binary.bool()

        bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
        bbox_results = bbox2result(bboxes, label_pred, self.num_classes)

        # Do one device-to-host transfer for all masks and labels, rather
        # than one per detection.
        masks_np = mask_pred_binary.detach().cpu().numpy()
        labels = label_pred.detach().cpu().tolist()
        mask_results = [[] for _ in range(self.num_classes)]
        for mask, label in zip(masks_np, labels):
            mask_results[label].append(mask)
        return [(bbox_results, mask_results)]

    # Not implemented: this wrapper only supports single-image inference.
    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        raise NotImplementedError

    async def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def extract_feat(self, imgs):
        raise NotImplementedError