File size: 3,837 Bytes
c310e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""

import torch
from torch import nn

from maskrcnn_benchmark.structures.image_list import to_image_list

from ..backbone import build_backbone
from ..rpn.rpn import build_rpn
from ..segmentation.segmentation import build_segmentation
from ..roi_heads.roi_heads import build_roi_heads
import time

class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN. Currently supports boxes and masks.
    It consists of three main parts:
    - backbone
    - proposal: a segmentation-based proposal module when
        ``cfg.MODEL.SEG_ON`` is set, otherwise an RPN
    - heads: takes the features + the proposals from the proposal module and
        computes detections / masks from it. Not built when
        ``cfg.MODEL.TRAIN_DETECTION_ONLY`` is set.
    """

    def __init__(self, cfg):
        """
        Build the sub-modules selected by the configuration.

        Arguments:
            cfg: configuration node; this class reads ``MODEL.SEG_ON``,
                ``MODEL.TRAIN_DETECTION_ONLY`` and (in ``forward``)
                ``MODEL.SEG.USE_FUSE_FEATURE``. It is also handed to every
                ``build_*`` factory.
        """
        super(GeneralizedRCNN, self).__init__()
        self.cfg = cfg
        self.backbone = build_backbone(cfg)
        if cfg.MODEL.SEG_ON:
            self.proposal = build_segmentation(cfg)
        else:
            self.proposal = build_rpn(cfg)
        if cfg.MODEL.TRAIN_DETECTION_ONLY:
            # Proposal-only model: forward() returns the proposals directly.
            self.roi_heads = None
        else:
            self.roi_heads = build_roi_heads(cfg)

    def forward(self, images, targets=None):
        """
        Arguments:
            images (list[Tensor] or ImageList): images to be processed
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            During training, a dict with the proposal losses merged with the
            ROI-head losses (when ROI heads exist).
            During testing without ``MODEL.SEG_ON``, ``result``: the ROI-head
            output, or the proposals themselves when there are no ROI heads.
            During testing with ``MODEL.SEG_ON``, the tuple
            ``(result, proposals, seg_results)``.

        Raises:
            ValueError: if the module is in training mode and ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        images = to_image_list(images)
        features = self.backbone(images.tensors)
        seg_on = self.cfg.MODEL.SEG_ON

        if seg_on:
            # The segmentation proposal module returns a pair plus a fused
            # feature; the second element of the pair is the losses while
            # training and the segmentation results at inference time.
            proposal_out, fuse_feature = self.proposal(images, features, targets)
            if self.training:
                proposals, proposal_losses = proposal_out
            else:
                proposals, seg_results = proposal_out
        else:
            proposals, proposal_losses = self.proposal(images, features, targets)

        if self.roi_heads is not None:
            use_fused = seg_on and self.cfg.MODEL.SEG.USE_FUSE_FEATURE
            head_features = fuse_feature if use_fused else features
            # The first returned value is not needed by this module.
            _, result, detector_losses = self.roi_heads(
                head_features, proposals, targets
            )
        else:
            # Proposal-only models don't have roi_heads
            result = proposals
            detector_losses = {}

        if self.training:
            # detector_losses is empty when there are no ROI heads, so the
            # merge is safe without an extra guard.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses

        if seg_on:
            return result, proposals, seg_results
        return result