import torch
import torch.nn.functional as F
from mmcv.runner import BaseModule

from .models import build_model
from .models.util.misc import NestedTensor, inverse_sigmoid


class HDetrWrapper(BaseModule):
    """Thin mmdet-style wrapper around an H-Deformable-DETR model.

    Only the one-to-one branch is kept: ``num_queries`` is clamped to
    ``num_queries_one2one``, so the one-to-many auxiliary queries used
    during hybrid-matching training are dropped at inference time.
    """

    def __init__(self, args=None, init_cfg=None):
        super(HDetrWrapper, self).__init__(init_cfg)
        model, box_postprocessor = build_model(args)
        self.model = model
        self.box_postprocessor = box_postprocessor
        # Keep only the one-to-one queries.
        self.model.num_queries = self.model.num_queries_one2one
        self.model.transformer.two_stage_num_proposals = self.model.num_queries
        # The 80 valid COCO category ids within the original 91-id label space.
        self.cls_index = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,
            20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38,
            39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
            56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75,
            76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
        ]

    def forward(self, img, img_metas):
        """Forward function.

        Args:
            img (Tensor): Input images of shape (N, C, H, W). Typically
                these should be mean centered and std scaled.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            dict: ``pred_logits`` of shape (N, num_queries, 80) and
            ``pred_boxes`` of shape (N, num_queries, 4), both taken from
            the last decoder layer.
        """
        # Build the padding mask: True marks padded pixels, False valid ones.
        input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
        batch_size = img.size(0)
        img_masks = img.new_ones(
            (batch_size, input_img_h, input_img_w), dtype=torch.bool
        )
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]["img_shape"]
            img_masks[img_id, :img_h, :img_w] = False
        samples = NestedTensor(tensors=img, mask=img_masks)

        features, pos = self.model.backbone(samples)

        # Project each backbone level to the transformer's hidden dimension.
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.model.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        # Synthesize any extra feature levels by striding down the last one.
        if self.model.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.model.num_feature_levels):
                if l == _len_srcs:
                    src = self.model.input_proj[l](features[-1].tensors)
                else:
                    src = self.model.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(
                    m[None].float(), size=src.shape[-2:]
                ).to(torch.bool)[0]
                pos_l = self.model.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        query_embeds = None
        if not self.model.two_stage or self.model.mixed_selection:
            query_embeds = self.model.query_embed.weight[: self.model.num_queries, :]

        # Attention mask that prevents information leakage between the
        # one-to-one and one-to-many query groups. With ``num_queries``
        # clamped in ``__init__`` the one-to-many block is empty, so the
        # mask is effectively all False here.
        self_attn_mask = (
            torch.zeros([self.model.num_queries, self.model.num_queries])
            .bool()
            .to(src.device)
        )
        self_attn_mask[
            self.model.num_queries_one2one:, : self.model.num_queries_one2one
        ] = True
        self_attn_mask[
            : self.model.num_queries_one2one, self.model.num_queries_one2one:
        ] = True

        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
        ) = self.model.transformer(srcs, masks, pos, query_embeds, self_attn_mask)

        # Decode class logits and boxes per decoder layer; box deltas are
        # added to the (inverse-sigmoid) reference points before sigmoid.
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_classes_one2many = []
        outputs_coords_one2many = []
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.model.class_embed[lvl](hs[lvl])
            tmp = self.model.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()

            outputs_classes_one2one.append(
                outputs_class[:, : self.model.num_queries_one2one]
            )
            outputs_classes_one2many.append(
                outputs_class[:, self.model.num_queries_one2one:]
            )
            outputs_coords_one2one.append(
                outputs_coord[:, : self.model.num_queries_one2one]
            )
            outputs_coords_one2many.append(
                outputs_coord[:, self.model.num_queries_one2one:]
            )

        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)
        # Keep only the logits of the 80 valid COCO categories, taken from
        # the last decoder layer.
        sampled_logits = outputs_classes_one2one[-1][:, :, self.cls_index]
        out = {
            "pred_logits": sampled_logits,
            "pred_boxes": outputs_coords_one2one[-1],
        }
        return out
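
    # The helper below is an illustrative sketch added for documentation; it
    # is not part of the upstream H-DETR code. It reproduces, standalone and
    # with toy query counts, the self-attention mask built in ``forward`` so
    # its block structure can be inspected in isolation.
    @staticmethod
    def _demo_self_attn_mask(num_one2one=3, num_one2many=2):
        """Rebuild the hybrid-matching attention mask with toy sizes.

        True blocks attention between the one-to-one and one-to-many query
        groups; attention within each group stays unmasked (False).

        >>> HDetrWrapper._demo_self_attn_mask(2, 1)
        tensor([[False, False,  True],
                [False, False,  True],
                [ True,  True, False]])
        """
        total = num_one2one + num_one2many
        mask = torch.zeros(total, total, dtype=torch.bool)
        mask[num_one2one:, :num_one2one] = True  # one2many can't see one2one
        mask[:num_one2one, num_one2one:] = True  # one2one can't see one2many
        return mask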
    def simple_test(self, img, img_metas, rescale=False):
        """Test without augmentation.

        Returns:
            list[dict]: One dict per image with ``scores``, ``labels``
            and ``boxes``.
        """
        out = self(img, img_metas)
        if rescale:
            ori_target_sizes = [meta_info["ori_shape"][:2] for meta_info in img_metas]
        else:
            ori_target_sizes = [meta_info["img_shape"][:2] for meta_info in img_metas]
        ori_target_sizes = out["pred_logits"].new_tensor(
            ori_target_sizes, dtype=torch.int64
        )
        results = self.box_postprocessor(out, ori_target_sizes)
        return results
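

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, kept commented out): the input contract the
# wrapper relies on. ``args`` must be whatever argparse namespace
# ``build_model`` expects; its construction is repo-specific, so it is not
# reproduced here. Shapes below are example values only.
# ---------------------------------------------------------------------------
# wrapper = HDetrWrapper(args=args).eval()
# img = torch.randn(1, 3, 800, 1216)        # padded, mean/std-normalized batch
# img_metas = [dict(
#     batch_input_shape=(800, 1216),        # padded (H, W) shared by the batch
#     img_shape=(800, 1199, 3),             # valid (H, W, C) before padding
#     ori_shape=(480, 640, 3),              # original (H, W, C) for rescaling
# )]
# with torch.no_grad():
#     results = wrapper.simple_test(img, img_metas, rescale=True)
# # results: list of per-image dicts with 'scores', 'labels' and 'boxes'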