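# NOTE: the lines below are file-viewer page residue kept for provenance;
# they are not part of the module itself.
# MMOCR/mmocr/apis/train.py
# tomofi's picture
# Add application file
# 2366e36
# raw
# history blame
# No virus
# 6.48 kB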
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner, get_dist_info)
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import build_dataloader, build_dataset
from mmocr import digit_version
from mmocr.apis.utils import (disable_text_recog_aug_test,
replace_image_to_tensor)
from mmocr.utils import get_root_logger
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build data loaders, wrap the model, set up a runner and start training.

    Args:
        model: Model to train. It is wrapped in ``MMDistributedDataParallel``
            (after ``model.cuda()``) or ``MMDataParallel`` before the
            optimizer is built.
        dataset: A training dataset, or a list/tuple of them. One data
            loader is built per dataset.
        cfg: Config object. ``cfg.runner`` is created from
            ``cfg.total_epochs`` (with a warning) when it is missing.
        distributed (bool): Selects the distributed model wrapper, the
            distributed optimizer/eval hooks and the sampler seed hook.
        validate (bool): When true, a validation data loader is built from
            ``cfg.data.val`` and an evaluation hook is registered.
        timestamp: Stored on ``runner.timestamp`` so that the ``.log`` and
            ``.log.json`` filenames match.
        meta: Forwarded to the runner as its ``meta`` default argument.
    """
    logger = get_root_logger(cfg.log_level)

    # Treat a single dataset and a sequence of datasets uniformly.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Data loader settings are layered; each layer overrides the previous:
    #   1. built-in defaults,
    #   2. extra defaults used only when running on parrots,
    #   3. whitelisted keys set directly on cfg.data.
    loader_cfg = dict(
        seed=cfg.get('seed'),
        drop_last=False,
        dist=distributed,
        num_gpus=len(cfg.gpu_ids))
    if torch.__version__ == 'parrots':
        loader_cfg.update(prefetch_num=2, pin_memory=False)
    overridable_keys = (
        'samples_per_gpu',
        'workers_per_gpu',
        'shuffle',
        'seed',
        'drop_last',
        'prefetch_num',
        'pin_memory',
        'persistent_workers',
    )
    loader_cfg.update(
        {key: cfg.data[key] for key in overridable_keys if key in cfg.data})

    # cfg.data.train_dataloader has the highest priority for training.
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
    data_loaders = [
        build_dataloader(train_set, **train_loader_cfg)
        for train_set in dataset
    ]

    # Put the model on the target device(s).
    if distributed:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        unused_params = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=unused_params)
    else:
        cpu_only = not torch.cuda.is_available()
        if cpu_only:
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)

    # Build the optimizer and the runner.
    optimizer = build_optimizer(model, cfg.optimizer)

    if 'runner' in cfg:
        # Both ways of specifying the epoch count must agree.
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    else:
        # Legacy configs only carry `total_epochs`; synthesise a runner.
        cfg.runner = dict(
            type='EpochBasedRunner', max_epochs=cfg.total_epochs)
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_defaults = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    # Workaround so that the .log and .log.json filenames are the same.
    runner.timestamp = timestamp

    # Choose the optimizer hook: fp16 takes precedence, then an explicit
    # OptimizerHook for distributed runs whose config names no hook type;
    # otherwise the raw config is handed to the runner unchanged.
    fp16_cfg = cfg.get('fp16', None)
    optimizer_config = cfg.optimizer_config
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)

    # Register the standard training hooks.
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    # Register the evaluation hook.
    if validate:
        default_batch = cfg.data.get('samples_per_gpu', 1)
        val_samples_per_gpu = cfg.data.get('val_dataloader', {}).get(
            'samples_per_gpu', default_batch)
        if val_samples_per_gpu > 1:
            # Support batch_size > 1 at test time for text recognition by
            # disabling MultiRotateAugOCR, which is useless in most cases.
            cfg = replace_image_to_tensor(disable_text_recog_aug_test(cfg))

        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))

        # Validation never shuffles or drops the last batch by default;
        # cfg.data.val_dataloader may override, and the resolved batch
        # size is applied last.
        val_loader_cfg = dict(loader_cfg)
        val_loader_cfg.update(shuffle=False, drop_last=False)
        val_loader_cfg.update(cfg.data.get('val_dataloader', {}))
        val_loader_cfg['samples_per_gpu'] = val_samples_per_gpu
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)

        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        if distributed:
            eval_hook = DistEvalHook
        else:
            eval_hook = EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # Resuming takes precedence over merely loading weights.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
def init_random_seed(seed=None, device='cuda'):
    """Return the seed to use, drawing a shared random one if none is given.

    When ``seed`` is ``None`` a random integer below ``2**31`` is drawn.
    If ``get_dist_info`` reports a world size other than 1, the value held
    by rank 0 is broadcast and every process returns that value.

    Args:
        seed (int, Optional): Seed supplied by the caller. Returned
            unchanged whenever it is not ``None``.
        device (str): Device on which the tensor used for the broadcast
            is created.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # All ranks must share the same random seed to prevent some potential
    # bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    # Only rank 0 contributes its draw; the other ranks send a placeholder
    # that the broadcast overwrites.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()