import torch
import torch.nn.functional as F
from mmcv.runner import BaseModule

from .models import build_model
from .models.util.misc import NestedTensor, inverse_sigmoid


class HDetrWrapper(BaseModule):
    def __init__(self,
                 args=None,
                 init_cfg=None):
        super(HDetrWrapper, self).__init__(init_cfg)
        model, box_postprocessor = build_model(args)
        self.model = model
        self.box_postprocessor = box_postprocessor

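        # Use only the one-to-one query branch at inference; H-DETR's
        # one-to-many queries serve only the auxiliary training matching.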
        self.model.num_queries = self.model.num_queries_one2one
        self.model.transformer.two_stage_num_proposals = self.model.num_queries
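        # Indices of the 80 annotated COCO categories (original category ids)
        # within the model's padded classification output.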
        self.cls_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28,
                          31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54,
                          55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
                          82, 84, 85, 86, 87, 88, 89, 90]

    def forward(self,
                img,
                img_metas):
        """Forward function for training mode.
        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
        """
        input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
        batch_size = img.size(0)
        img_masks = img.new_ones((batch_size, input_img_h, input_img_w),
                                 dtype=torch.bool)
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]["img_shape"]
            img_masks[img_id, :img_h, :img_w] = False
        samples = NestedTensor(tensors=img, mask=img_masks)
        features, pos = self.model.backbone(samples)

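        # Project each backbone feature level to the transformer embedding dim.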
        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.model.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
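        # If the transformer expects more feature levels than the backbone
        # provides, generate extra levels via additional input projections and
        # recompute their padding masks and positional encodings.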
        if self.model.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.model.num_feature_levels):
                if l == _len_srcs:
                    src = self.model.input_proj[l](features[-1].tensors)
                else:
                    src = self.model.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
                    torch.bool
                )[0]
                pos_l = self.model.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

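        # Learned query embeddings are only used when the model is not purely
        # two-stage or when mixed selection is enabled; keep only the
        # one-to-one queries (num_queries was truncated in __init__).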
        query_embeds = None
        if not self.model.two_stage or self.model.mixed_selection:
            query_embeds = self.model.query_embed.weight[0: self.model.num_queries, :]

        # Self-attention mask preventing information leakage between the
        # one-to-one and one-to-many query groups.
        self_attn_mask = (
            torch.zeros([self.model.num_queries, self.model.num_queries, ]).bool().to(src.device)
        )
        self_attn_mask[self.model.num_queries_one2one:, 0: self.model.num_queries_one2one, ] = True
        self_attn_mask[0: self.model.num_queries_one2one, self.model.num_queries_one2one:, ] = True

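        # Deformable transformer forward pass over the multi-scale features,
        # returning per-layer decoder outputs and box reference points.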
        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
        ) = self.model.transformer(srcs, masks, pos, query_embeds, self_attn_mask)

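        # Per decoder layer: predict class logits and refine boxes relative to
        # the (inverse-sigmoid) reference points. The one-to-many outputs are
        # collected but unused at inference.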
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_classes_one2many = []
        outputs_coords_one2many = []
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.model.class_embed[lvl](hs[lvl])
            tmp = self.model.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()

            outputs_classes_one2one.append(
                outputs_class[:, 0: self.model.num_queries_one2one]
            )
            outputs_classes_one2many.append(
                outputs_class[:, self.model.num_queries_one2one:]
            )
            outputs_coords_one2one.append(
                outputs_coord[:, 0: self.model.num_queries_one2one]
            )
            outputs_coords_one2many.append(outputs_coord[:, self.model.num_queries_one2one:])
        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)

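        # Keep the last decoder layer's one-to-one predictions and select only
        # the logits of the 80 annotated COCO categories.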
        sampled_logits = outputs_classes_one2one[-1][:, :, self.cls_index]
        out = {
            "pred_logits": sampled_logits,
            "pred_boxes": outputs_coords_one2one[-1],
        }
        return out

    def simple_test(self, img, img_metas, rescale=False):
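        """Test without augmentation.

        When ``rescale`` is True, boxes are mapped back to the original image
        size (``ori_shape``); otherwise they follow the resized ``img_shape``.
        """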
        # out: dict
        out = self(img, img_metas)
        if rescale:
            ori_target_sizes = [meta_info['ori_shape'][:2] for meta_info in img_metas]
        else:
            ori_target_sizes = [meta_info['img_shape'][:2] for meta_info in img_metas]
        ori_target_sizes = out['pred_logits'].new_tensor(ori_target_sizes, dtype=torch.int64)
        # results: List[dict(scores, labels, boxes)]
        results = self.box_postprocessor(out, ori_target_sizes)
        return results