File size: 9,452 Bytes
a5f8a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
Fine-tune a pre-trained model on one of the downstream tasks available in
Detectron2.
Supported downstream tasks:
  - LVIS Instance Segmentation
  - COCO Instance Segmentation
  - Pascal VOC 2007+12 Object Detection

Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
Thanks to the developers of Detectron2!
"""
import argparse
import os
import re
from typing import Any, Dict, Union

import torch
from torch.utils.tensorboard import SummaryWriter

import detectron2 as d2
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, default_setup
from detectron2.evaluation import (
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    COCOEvaluator,
)
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads

from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
import virtex.utils.distributed as dist

# fmt: off
parser = common_parser(
    description="Train object detectors from pretrained visual backbone."
)
parser.add_argument(
    "--d2-config", required=True,
    help="Path to a detectron2 config for downstream task finetuning."
)
# These keys are assigned in `build_detectron2_config` *after* the overrides
# are merged, so any value given for them here has no effect.
parser.add_argument(
    "--d2-config-override", nargs="*", default=[],
    help="""Key-value pairs from Detectron2 config to override from file.
    Some keys will be ignored because they are set from other args or from
    the pretraining config: [DATALOADER.NUM_WORKERS, SOLVER.CHECKPOINT_PERIOD,
    OUTPUT_DIR, MODEL.RESNETS.DEPTH]""",
)

# `add_argument_group` returns the group; arguments only show up under its
# heading in `--help` when they are added to the group, not to `parser`.
# Parsed values still land on the same namespace either way.
ckpt_group = parser.add_argument_group("Checkpointing and Logging")
ckpt_group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
    default="virtex", help="""How to initialize weights:
        1. 'random' initializes all weights randomly
        2. 'imagenet' initializes backbone weights from torchvision model zoo
        3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
            - with 'torchvision', state dict would be from PyTorch's training
              script.
            - with 'virtex' it should be for our full pretrained model."""
)
ckpt_group.add_argument(
    "--checkpoint-path",
    help="Path to load checkpoint and run downstream task evaluation."
)
ckpt_group.add_argument(
    "--resume", action="store_true", help="""Specify this flag when resuming
    training from a checkpoint saved by Detectron2."""
)
ckpt_group.add_argument(
    "--eval-only", action="store_true",
    help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
)
ckpt_group.add_argument(
    "--checkpoint-every", type=int, default=5000,
    help="Serialize model to a checkpoint after every these many iterations.",
)
# fmt: on


@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    r"""
    ``Res5ROIHeads`` variant that appends one extra normalization layer (its
    type taken from ``cfg.MODEL.RESNETS.NORM``) after the ``res5`` stage.
    Intended for Faster R-CNN C4/DC5 backbones on VOC detection.
    """

    def _build_res5_block(self, cfg):
        # Build the stock res5 stage, then tack a norm layer onto its end,
        # sized to the stage's output channels.
        res5_stage, channels = super()._build_res5_block(cfg)
        extra_norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, channels)
        res5_stage.add_module("norm", extra_norm)
        return res5_stage, channels


def _resnet_depth_from_name(visual_name: str) -> int:
    r"""
    Infer the ResNet depth encoded in our visual backbone name.

    Parameters
    ----------
    visual_name: str
        Value of ``MODEL.VISUAL.NAME`` from the pretraining config.

    Returns
    -------
    int
        Depth parsed from ``resnet<depth>`` for torchvision backbones or from
        ``_R_<depth>`` for detectron2 backbones; ``0`` for any other name.

    Raises
    ------
    ValueError
        If the name mentions torchvision/detectron2 but carries no depth in
        the expected pattern.
    """
    # Checked in this order: a name containing both substrings is parsed as
    # a torchvision name.
    if "torchvision" in visual_name:
        pattern = r"resnet(\d+)"
    elif "detectron2" in visual_name:
        pattern = r"_R_(\d+)"
    else:
        return 0

    match = re.search(pattern, visual_name)
    if match is None:
        raise ValueError(
            f"Could not infer ResNet depth from MODEL.VISUAL.NAME="
            f"{visual_name!r}; expected it to match {pattern!r}."
        )
    return int(match.group(1))


def build_detectron2_config(_C: Config, _A: argparse.Namespace):
    r"""
    Build detectron2 config based on our pre-training config and args.

    Parameters
    ----------
    _C: Config
        Pretraining config; only ``MODEL.VISUAL.NAME`` is read here.
    _A: argparse.Namespace
        Parsed command-line args: ``d2_config``, ``d2_config_override``,
        ``cpu_workers``, ``checkpoint_every`` and ``serialization_dir``.

    Returns
    -------
    detectron2.config.CfgNode
        Detectron2 config with the file and overrides merged, then a few keys
        forced from ``_A`` / ``_C``.
    """
    _D2C = d2.config.get_cfg()

    # Override some default values based on our config file.
    _D2C.merge_from_file(_A.d2_config)
    _D2C.merge_from_list(_A.d2_config_override)

    # Set some config parameters from args. These run after the merge above,
    # so they win over anything given in `--d2-config-override`.
    _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers
    _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
    _D2C.OUTPUT_DIR = _A.serialization_dir

    # Set ResNet depth to override in Detectron2's config.
    _D2C.MODEL.RESNETS.DEPTH = _resnet_depth_from_name(_C.MODEL.VISUAL.NAME)
    return _D2C


class DownstreamTrainer(DefaultTrainer):
    r"""
    Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.

    Parameters
    ----------
    cfg: detectron2.config.CfgNode
        Detectron2 config object containing all config params.
    weights: Union[str, Dict[str, Any], None]
        Weights to load in the initialized model.
        - A ``str`` is a checkpoint path. It is loaded with
          ``resume_or_load(..., resume=True)``, and ``start_iter`` is set from
          its ``"iteration"`` entry. In this file a ``str`` is passed only when
          resuming training from a Detectron2 checkpoint.
        - A ``dict`` is handed to ``DetectionCheckpointer._load_model``.
        - Anything else (e.g. ``None``) loads nothing.
    """

    def __init__(self, cfg, weights: Union[str, Dict[str, Any], None]):

        super().__init__(cfg)

        # Load pre-trained weights here with a fresh `DetectionCheckpointer`.
        # Original author's note: `ApexDDP` has some weird issue with
        # `DetectionCheckpointer`.
        # NOTE(review): `super().__init__` has already run at this point, so
        # `self._trainer.model` may already be DDP-wrapped. Confirm that the
        # checkpointer unwraps it in the installed detectron2/fvcore.
        # fmt: off
        if isinstance(weights, str):
            # A path: resume training (model, optimizer, scheduler) from a
            # Detectron2 checkpoint and continue after its saved iteration.
            self.start_iter = (
                DetectionCheckpointer(
                    self._trainer.model,
                    optimizer=self._trainer.optimizer,
                    scheduler=self.scheduler
                ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
            )
        elif isinstance(weights, dict):
            # A dict: our pretrain init. Presumably shaped like a detectron2
            # checkpoint (e.g. {"model": state_dict, ...}) as `_load_model`
            # expects; verify against `detectron2_backbone_state_dict`.
            DetectionCheckpointer(self._trainer.model)._load_model(weights)
        # fmt: on

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        r"""
        Build the evaluator for ``dataset_name``, chosen by the
        ``evaluator_type`` recorded for it in detectron2's ``MetadataCatalog``.

        Parameters
        ----------
        cfg: detectron2.config.CfgNode
            Detectron2 config; ``OUTPUT_DIR`` provides the default output folder.
        dataset_name: str
            Name of a registered detectron2 dataset.
        output_folder: str, optional
            Where COCO/LVIS evaluators write their outputs. Defaults to
            ``<cfg.OUTPUT_DIR>/inference``. The VOC evaluator does not use it.

        Raises
        ------
        NotImplementedError
            If the evaluator type is not ``pascal_voc``, ``coco`` or ``lvis``.
            Previously this silently returned ``None``. Detectron2's own
            ``DefaultTrainer.build_evaluator`` raises the same exception type
            (confirm how ``DefaultTrainer.test`` treats it in your version).
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        if evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        raise NotImplementedError(
            f"No evaluator for dataset {dataset_name!r} with evaluator_type "
            f"{evaluator_type!r}; supported: pascal_voc, coco, lvis."
        )

    def test(self, cfg=None, model=None, evaluators=None):
        r"""
        Evaluate the model, log results to stdout and tensorboard, and return
        them.

        Parameters
        ----------
        cfg: detectron2.config.CfgNode, optional
            Config to evaluate with; defaults to ``self.cfg``.
        model: torch.nn.Module, optional
            Model to evaluate; defaults to ``self.model``.
        evaluators: optional
            Forwarded unchanged to ``DefaultTrainer.test``. When ``None``,
            evaluators presumably come from :meth:`build_evaluator`.

        Returns
        -------
        Dict
            Whatever ``DefaultTrainer.test`` returns, so that callers such as
            evaluation hooks can use the results.
        """
        if cfg is None:
            cfg = self.cfg
        if model is None:
            model = self.model

        results = super().test(cfg, model, evaluators)
        flat_results = d2.evaluation.testing.flatten_results_dict(results)

        # `start_iter` is fixed for the whole run, so periodic evaluations
        # during training would all land on one tensorboard step. Use the
        # current training iteration when the training loop has set one ahead
        # of `start_iter`. For eval-only runs this falls back to `start_iter`.
        step = max(getattr(self, "iter", self.start_iter), self.start_iter)

        # Context manager closes the writer, so event files are flushed and no
        # file handle is leaked per evaluation.
        with SummaryWriter(log_dir=cfg.OUTPUT_DIR) as tensorboard_writer:
            for name, value in flat_results.items():
                tensorboard_writer.add_scalar(name, value, step)
        return results


def _setup_local_process_group(num_gpus_per_machine: int, machine_rank: int):
    r"""
    Set ``d2.utils.comm._LOCAL_PROCESS_GROUP`` to a group holding only the
    ranks that run on this machine.

    ``torch.distributed.new_group`` has to be entered by every process for
    every group, so all per-machine groups are created here and only the one
    for ``machine_rank`` is kept. Assumes global ranks are laid out as
    ``machine_rank * num_gpus_per_machine + local_rank``. Verify this against
    ``virtex.utils.distributed.launch``.
    """
    world_size = dist.get_world_size()

    if num_gpus_per_machine <= 0 or world_size % num_gpus_per_machine != 0:
        # Ranks cannot be split evenly by machine: fall back to a single
        # group with every rank (the previous behavior).
        d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(
            list(range(world_size))
        )
        return

    for machine in range(world_size // num_gpus_per_machine):
        first_rank = machine * num_gpus_per_machine
        ranks = list(range(first_rank, first_rank + num_gpus_per_machine))
        group = torch.distributed.new_group(ranks)
        if machine == machine_rank:
            d2.utils.comm._LOCAL_PROCESS_GROUP = group


def _prepare_weights(_C: Config, _A: argparse.Namespace):
    r"""
    Produce ``(model, weights)`` for :class:`DownstreamTrainer`.

    When resuming (``--resume``) with ``virtex``/``torchvision`` init,
    ``model`` is ``None`` and ``weights`` is ``_A.checkpoint_path``.
    Otherwise, ``model`` is the pretraining model built from ``_C`` (with
    weights loaded for ``virtex``/``torchvision``), and ``weights`` is
    ``model.visual.detectron2_backbone_state_dict()``.
    """
    if _A.weight_init in {"virtex", "torchvision"} and _A.resume:
        # If resuming training, let detectron2 load weights by providing path.
        return None, _A.checkpoint_path

    # For "random" and "imagenet", the freshly built model's weights are used
    # as-is. For "imagenet", `main` sets MODEL.VISUAL.PRETRAINED in `_C`.
    model = PretrainingModelFactory.from_config(_C)
    if _A.weight_init == "virtex":
        CheckpointManager(model=model).load(_A.checkpoint_path)
    elif _A.weight_init == "torchvision":
        # NOTE(review): with strict=False, keys that do not match (e.g. a
        # "module." prefix from DataParallel training) are skipped silently.
        # Check the checkpoint's key names.
        model.visual.cnn.load_state_dict(
            torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
            strict=False,
        )
    return model, model.visual.detectron2_backbone_state_dict()


def main(_A: argparse.Namespace):
    r"""
    Fine-tune (or only evaluate, with ``--eval-only``) a detectron2 model whose
    backbone is initialized according to ``--weight-init``.

    Runs once per distributed process, as started by ``dist.launch``.
    """
    # Local process group is needed for detectron2.
    _setup_local_process_group(_A.num_gpus_per_machine, _A.machine_rank)

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory. Work on a copy
    # of the overrides so the parsed args are not mutated.
    config_override = list(_A.config_override)
    if _A.weight_init == "imagenet":
        config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
    _C = Config(_A.config, config_override)

    # We use `default_setup` from detectron2 to do some common setup, such as
    # logging, setting up serialization etc. For more info, look into source.
    _D2C = build_detectron2_config(_C, _A)
    default_setup(_D2C, _A)

    # Prepare weights to pass in instantiation call of trainer.
    model, weights = _prepare_weights(_C, _A)

    # Back up pretrain config and model checkpoint (if provided).
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
    if _A.weight_init == "virtex" and not _A.resume:
        torch.save(
            model.state_dict(),
            os.path.join(_A.serialization_dir, "pretrain_model.pth"),
        )

    # The pretraining model is no longer needed once its backbone weights
    # have been extracted.
    del model
    trainer = DownstreamTrainer(_D2C, weights)
    if _A.eval_only:
        trainer.test()
    else:
        trainer.train()


if __name__ == "__main__":
    _A = parser.parse_args()

    # `dist.launch` runs `main(_A)` in each process it starts and is expected
    # to select the appropriate CUDA device (GPU ID) for that process.
    launch_kwargs = dict(
        num_machines=_A.num_machines,
        num_gpus_per_machine=_A.num_gpus_per_machine,
        machine_rank=_A.machine_rank,
        dist_url=_A.dist_url,
        args=(_A,),
    )
    dist.launch(main, **launch_kwargs)