import torch
import torch.nn as nn
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmdet.core import bbox2result
from mmdet.models import DETECTORS, BaseDetector

from projects.instance_segment_anything.models.segment_anything import sam_model_registry, SamPredictor
from .focalnet_dino.focalnet_dino_wrapper import FocalNetDINOWrapper
from .hdetr.hdetr_wrapper import HDetrWrapper


@DETECTORS.register_module()
class DetWrapperInstanceSAM(BaseDetector):
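    """Two-stage instance segmentation wrapper.

    A query-based detector (H-DETR or FocalNet-DINO) predicts labels, scores
    and boxes; the predicted boxes are then used as prompts for Segment
    Anything (SAM) to obtain instance masks. Only single-image testing is
    supported.
    """
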
    wrapper_dict = {'hdetr': HDetrWrapper,
                    'focalnet_dino': FocalNetDINOWrapper}

    def __init__(self,
                 det_wrapper_type='hdetr',
                 det_wrapper_cfg=None,
                 det_model_ckpt=None,
                 num_classes=80,
                 model_type='vit_b',
                 sam_checkpoint=None,
                 use_sam_iou=True,
                 init_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        super(DetWrapperInstanceSAM, self).__init__(init_cfg)
        # Dummy learnable parameter; its device is used below to place SAM.
        self.learnable_placeholder = nn.Embedding(1, 1)

        # Build the detection model and optionally load its checkpoint.
        det_wrapper_cfg = Config(det_wrapper_cfg)
        assert det_wrapper_type in self.wrapper_dict.keys()
        self.det_model = self.wrapper_dict[det_wrapper_type](args=det_wrapper_cfg)
        if det_model_ckpt is not None:
            load_checkpoint(self.det_model.model,
                            filename=det_model_ckpt,
                            map_location='cpu')
        self.num_classes = num_classes

        # Segment Anything
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        _ = sam.to(device=self.learnable_placeholder.weight.device)
        self.predictor = SamPredictor(sam)
        self.use_sam_iou = use_sam_iou

    def init_weights(self):
        # Weights are loaded from checkpoints in __init__; nothing to do here.
        pass

    def simple_test(self, img, img_metas, ori_img, rescale=True):
        """Test without augmentation.

        Args:
            img (Tensor): A batch of images (batch size must be 1).
            img_metas (list[dict]): List of image information.
            ori_img (np.ndarray): Original image (HWC), fed to the SAM predictor.
            rescale (bool): Must be True; boxes stay in the original image scale.

        Returns:
            list[tuple]: One (bbox_results, mask_results) pair in mmdet format.
        """
        assert rescale
        assert len(img_metas) == 1
        # results: List[dict(scores, labels, boxes)]
        results = self.det_model.simple_test(img,
                                             img_metas,
                                             rescale)
        # Tensor(n, 4), xyxy boxes in the original image scale
        output_boxes = results[0]['boxes']

        # Prompt SAM with the detected boxes.
        self.predictor.set_image(ori_img)
        transformed_boxes = self.predictor.transform.apply_boxes_torch(
            output_boxes, ori_img.shape[:2])
        # mask_pred: (n, 1, h, w) raw logits (return_logits=True)
        # sam_score: (n, 1) SAM's predicted mask IoU
        mask_pred, sam_score, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
            return_logits=True,
        )
        # Tensor(n, h, w), raw mask logits
        mask_pred = mask_pred.squeeze(1)
        sam_score = sam_score.squeeze(-1)

        # Tensor(n,)
        label_pred = results[0]['labels']
        score_pred = results[0]['scores']

        # mask_pred: Tensor(n, h, w)
        # label_pred: Tensor(n,)
        # score_pred: Tensor(n,)
        # sam_score: Tensor(n,)
        # Binarize the raw logits with SAM's mask threshold.
        mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
        if self.use_sam_iou:
            # Combine the detector confidence with SAM's predicted IoU.
            det_scores = score_pred * sam_score
        else:
            # Otherwise weight the detector confidence by the mean mask logit
            # inside the binary mask, shape (n,).
            mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
                    mask_pred_binary.flatten(1).sum(1) + 1e-6)
            det_scores = score_pred * mask_scores_per_image
            # det_scores = score_pred
        mask_pred_binary = mask_pred_binary.bool()

        # Convert to mmdet's per-class bbox arrays and per-class mask lists.
        bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
        bbox_results = bbox2result(bboxes, label_pred, self.num_classes)
        mask_results = [[] for _ in range(self.num_classes)]
        for j, label in enumerate(label_pred):
            mask = mask_pred_binary[j].detach().cpu().numpy()
            mask_results[label].append(mask)
        output_results = [(bbox_results, mask_results)]
        return output_results

    # Not implemented: this wrapper only supports simple_test.
    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        raise NotImplementedError

    async def async_simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError

    def extract_feat(self, imgs):
        raise NotImplementedError
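

# Usage sketch (illustrative, not part of the original project): because the
# class is registered in mmdet's DETECTORS registry, it is normally built from
# a config dict. The `det_wrapper_cfg` contents and checkpoint paths below are
# placeholders/assumptions; see the project's config files for real values.
#
#   from mmdet.models import build_detector
#
#   detector = build_detector(dict(
#       type='DetWrapperInstanceSAM',
#       det_wrapper_type='hdetr',
#       det_wrapper_cfg=dict(),                  # hypothetical detector args
#       det_model_ckpt='path/to/detector.pth',   # placeholder checkpoint path
#       num_classes=80,
#       model_type='vit_b',
#       sam_checkpoint='path/to/sam_vit_b.pth',  # placeholder checkpoint path
#       use_sam_iou=True))
#   detector.eval()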