# Scraped from a Hugging Face Space (status shown: "Runtime error"), revision a5f8a35.
# File size: 9,452 bytes.
"""
Finetune a pre-trained model on a downstream task, one of those available in
Detectron2.
Supported downstream:
- LVIS Instance Segmentation
- COCO Instance Segmentation
- Pascal VOC 2007+12 Object Detection
Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
Thanks to the developers of Detectron2!
"""
import argparse
import os
import re
from typing import Any, Dict, Union
import torch
from torch.utils.tensorboard import SummaryWriter
import detectron2 as d2
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, default_setup
from detectron2.evaluation import (
LVISEvaluator,
PascalVOCDetectionEvaluator,
COCOEvaluator,
)
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
import virtex.utils.distributed as dist
# fmt: off
# Command-line interface. `common_parser` supplies the shared options; the
# arguments below are specific to Detectron2 fine-tuning.
parser = common_parser(
    description="Train object detectors from pretrained visual backbone."
)
parser.add_argument(
    "--d2-config", required=True,
    help="Path to a detectron2 config for downstream task finetuning."
)
# The help text lists exactly the keys that `build_detectron2_config`
# overwrites from other command-line arguments.
parser.add_argument(
    "--d2-config-override", nargs="*", default=[],
    help="""Key-value pairs from Detectron2 config to override from file.
Some keys will be ignored because they are set from other args:
[DATALOADER.NUM_WORKERS, SOLVER.CHECKPOINT_PERIOD, OUTPUT_DIR]""",
)

# Keep a handle on the argument group so the options below are actually listed
# under it in `--help`. Arguments added to a group are still parsed by `parser`
# itself, so the resulting namespace is unchanged.
group = parser.add_argument_group("Checkpointing and Logging")
group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
    default="virtex", help="""How to initialize weights:
1. 'random' initializes all weights randomly
2. 'imagenet' initializes backbone weights from torchvision model zoo
3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
- with 'torchvision', state dict would be from PyTorch's training
script.
- with 'virtex' it should be for our full pretrained model."""
)
group.add_argument(
    "--checkpoint-path",
    help="Path to load checkpoint and run downstream task evaluation."
)
group.add_argument(
    "--resume", action="store_true", help="""Specify this flag when resuming
training from a checkpoint saved by Detectron2."""
)
group.add_argument(
    "--eval-only", action="store_true",
    help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
)
group.add_argument(
    "--checkpoint-every", type=int, default=5000,
    help="Serialize model to a checkpoint after every these many iterations.",
)
# fmt: on
@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    r"""
    ``Res5ROIHeads`` variant whose ``res5`` stage is followed by one extra
    normalization layer. Used with Faster R-CNN C4/DC5 backbones for VOC
    detection.
    """

    def _build_res5_block(self, cfg):
        r"""
        Build the parent's ``res5`` block and append a norm layer to it.

        The layer type is taken from ``cfg.MODEL.RESNETS.NORM`` and it is
        registered on the block under the name ``"norm"``. Returns the same
        ``(block, out_channels)`` pair as the parent implementation.
        """
        # NOTE(review): this overrides the parent hook as an instance method;
        # some detectron2 releases may declare it as a classmethod -- confirm
        # against the installed version.
        res5_block, num_channels = super()._build_res5_block(cfg)
        res5_block.add_module(
            "norm", d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, num_channels)
        )
        return res5_block, num_channels
def build_detectron2_config(_C: Config, _A: argparse.Namespace):
    r"""
    Assemble the detectron2 config used for downstream fine-tuning.

    Starts from detectron2's defaults, merges the file given by ``--d2-config``
    and then the ``--d2-config-override`` key-value list. Afterwards a few
    entries are overwritten unconditionally from command-line arguments, and
    the ResNet depth is derived from the pre-training config.

    Parameters
    ----------
    _C: Config
        Pre-training config; only ``MODEL.VISUAL.NAME`` is read here.
    _A: argparse.Namespace
        Parsed arguments; reads ``d2_config``, ``d2_config_override``,
        ``cpu_workers``, ``checkpoint_every`` and ``serialization_dir``.

    Returns
    -------
    The populated detectron2 config object.
    """
    d2_cfg = d2.config.get_cfg()

    # File values first, then command-line overrides on top of them.
    d2_cfg.merge_from_file(_A.d2_config)
    d2_cfg.merge_from_list(_A.d2_config_override)

    # These always win over both the file and the override list.
    d2_cfg.DATALOADER.NUM_WORKERS = _A.cpu_workers
    d2_cfg.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
    d2_cfg.OUTPUT_DIR = _A.serialization_dir

    # Parse the ResNet depth out of the visual backbone name. Names containing
    # neither "torchvision" nor "detectron2" fall back to depth 0.
    visual_name = _C.MODEL.VISUAL.NAME
    if "torchvision" in visual_name:
        depth = re.search(r"resnet(\d+)", visual_name).group(1)
    elif "detectron2" in visual_name:
        depth = re.search(r"_R_(\d+)", visual_name).group(1)
    else:
        depth = 0
    d2_cfg.MODEL.RESNETS.DEPTH = int(depth)

    return d2_cfg
class DownstreamTrainer(DefaultTrainer):
    r"""
    Extension of detectron2's ``DefaultTrainer`` with custom weight loading,
    evaluator selection, and tensorboard logging of test results.

    Parameters
    ----------
    cfg: detectron2.config.CfgNode
        Detectron2 config object containing all config params.
    weights: Union[str, Dict[str, Any]]
        Weights to load in the initialized model. If ``str``, then we assume path
        to a checkpoint, or if a ``dict``, we assume a state dict. This will be
        an ``str`` only if we resume training from a Detectron2 checkpoint.
    """

    def __init__(self, cfg, weights: Union[str, Dict[str, Any]]):
        super().__init__(cfg)

        # Load pre-trained weights before wrapping to DDP because `ApexDDP` has
        # some weird issue with `DetectionCheckpointer`.
        # fmt: off
        if isinstance(weights, str):
            # A checkpoint path: restore model, optimizer and scheduler state,
            # then continue from the iteration after the one that was saved
            # (iteration 0 if the checkpoint records none).
            checkpointer = DetectionCheckpointer(
                self._trainer.model,
                optimizer=self._trainer.optimizer,
                scheduler=self.scheduler
            )
            checkpoint = checkpointer.resume_or_load(weights, resume=True)
            self.start_iter = checkpoint.get("iteration", -1) + 1
        elif isinstance(weights, dict):
            # A state dict: load it into the model only; optimizer and
            # scheduler keep their freshly initialized state.
            DetectionCheckpointer(self._trainer.model)._load_model(weights)
        # fmt: on

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        r"""
        Return the evaluator matching ``dataset_name``'s registered evaluator type.

        ``output_folder`` defaults to ``<cfg.OUTPUT_DIR>/inference``.

        Raises
        ------
        NotImplementedError
            If the evaluator type is not one of "pascal_voc", "coco" or "lvis".
            Previously this case silently returned ``None``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)

        raise NotImplementedError(
            f"No evaluator available for dataset '{dataset_name}' "
            f"with evaluator type '{evaluator_type}'."
        )

    def test(self, cfg=None, model=None, evaluators=None):
        r"""
        Evaluate the model and log results to stdout and tensorboard.

        ``cfg`` and ``model`` default to the trainer's own config and model.
        ``evaluators`` is forwarded to the parent implementation. The results
        returned by the parent ``test`` are flattened, written to tensorboard
        under ``cfg.OUTPUT_DIR``, and returned to the caller.
        """
        # Compare against None explicitly so that a falsy-but-valid object is
        # never silently replaced by the trainer's own.
        cfg = self.cfg if cfg is None else cfg
        model = self.model if model is None else model

        results = super().test(cfg, model, evaluators=evaluators)
        flat_results = d2.evaluation.testing.flatten_results_dict(results)

        tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
        try:
            # NOTE(review): every call logs at `self.start_iter`, so repeated
            # evaluations within one run share a step -- confirm this is intended.
            for k, v in flat_results.items():
                tensorboard_writer.add_scalar(k, v, self.start_iter)
        finally:
            # Flush and close so events are not lost if the process exits
            # right after evaluation (e.g. with --eval-only).
            tensorboard_writer.close()

        return results
def main(_A: argparse.Namespace):
    r"""
    Per-process entry point: build configs, prepare weights, then train or test.

    Steps, in order: create the process group detectron2 expects, build the
    pre-training and detectron2 configs, run detectron2's ``default_setup``,
    prepare backbone weights according to ``--weight-init`` / ``--resume``,
    back up the pre-training config (and model, when initialized from a VirTex
    checkpoint), and finally run ``DownstreamTrainer``.
    """
    # Query the device set for the current distributed process (see `launch`
    # in `virtex.utils.distributed`). The value itself is not used below.
    torch.cuda.current_device()

    # Local process group is needed for detectron2; it spans every rank.
    all_ranks = list(range(dist.get_world_size()))
    d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(all_ranks)

    # Build the pre-training config (immutable once created). ImageNet init is
    # requested by switching on the pretrained flag of the visual backbone.
    if _A.weight_init == "imagenet":
        _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
    pretrain_cfg = Config(_A.config, _A.config_override)

    # `default_setup` from detectron2 handles common setup such as logging and
    # the serialization directory.
    d2_cfg = build_detectron2_config(pretrain_cfg, _A)
    default_setup(d2_cfg, _A)

    # Prepare weights to pass in instantiation call of trainer.
    loads_checkpoint = _A.weight_init in {"virtex", "torchvision"}
    if loads_checkpoint and _A.resume:
        # Resuming: hand detectron2 the checkpoint path and let it load.
        model = None
        weights = _A.checkpoint_path
    else:
        # Otherwise build the pre-training model, optionally fill it from the
        # checkpoint, and export the backbone weights for detectron2.
        model = PretrainingModelFactory.from_config(pretrain_cfg)
        if _A.weight_init == "virtex":
            CheckpointManager(model=model).load(_A.checkpoint_path)
        elif _A.weight_init == "torchvision":
            model.visual.cnn.load_state_dict(
                torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
                strict=False,
            )
        weights = model.visual.detectron2_backbone_state_dict()

    # Back up pretrain config and model checkpoint (if provided).
    _A_dir = _A.serialization_dir
    pretrain_cfg.dump(os.path.join(_A_dir, "pretrain_config.yaml"))
    if _A.weight_init == "virtex" and not _A.resume:
        torch.save(
            model.state_dict(), os.path.join(_A_dir, "pretrain_model.pth")
        )

    # The pre-training model is no longer needed once weights are extracted.
    del model

    trainer = DownstreamTrainer(d2_cfg, weights)
    if _A.eval_only:
        trainer.test()
    else:
        trainer.train()
if __name__ == "__main__":
    _A = parser.parse_args()

    # Hand `main` to the distributed launcher together with the cluster layout
    # taken from the parsed arguments. Per the original note, the launcher sets
    # the appropriate CUDA device (GPU ID) for each process, which `main`
    # queries at its start.
    launch_options = dict(
        num_machines=_A.num_machines,
        num_gpus_per_machine=_A.num_gpus_per_machine,
        machine_rank=_A.machine_rank,
        dist_url=_A.dist_url,
        args=(_A, ),
    )
    dist.launch(main, **launch_options)
# (end of scraped listing)