# Copyright (c) OpenMMLab. All rights reserved.
import importlib
import os.path as osp
import sys
import warnings
import mmcv
import numpy as np
import pycocotools.mask as mask_util
from mmcv.runner import HOOKS
from mmcv.runner.dist_utils import master_only
from mmcv.runner.hooks.checkpoint import CheckpointHook
from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
from mmcv.utils import digit_version
from mmdet.core import DistEvalHook, EvalHook
from mmdet.core.mask.structures import polygon_to_bitmap


@HOOKS.register_module()
class MMDetWandbHook(WandbLoggerHook):
"""Enhanced Wandb logger hook for MMDetection.

    Compared with the :class:`mmcv.runner.WandbLoggerHook`, this hook not
    only automatically logs all the metrics but also logs the following
    extra information: it saves model checkpoints as W&B Artifacts and
    logs model predictions as interactive W&B Tables.

    - Metrics: The MMDetWandbHook automatically logs training and
      validation metrics along with system metrics (CPU/GPU).

    - Checkpointing: If `log_checkpoint` is True, the checkpoint saved at
      every checkpoint interval will be saved as W&B Artifacts.
      This depends on the :class:`mmcv.runner.CheckpointHook`, whose
      priority is higher than this hook's. Please refer to
      https://docs.wandb.ai/guides/artifacts/model-versioning
      to learn more about model versioning with W&B Artifacts.

    - Checkpoint Metadata: If evaluation results are available for a given
      checkpoint artifact, metadata will be associated with that artifact.
      The metadata contains the evaluation metrics computed on validation
      data with that checkpoint, along with the current epoch. It depends
      on the `EvalHook`, whose priority is higher than MMDetWandbHook's.

    - Evaluation: At every evaluation interval, the `MMDetWandbHook` logs
      the model predictions as interactive W&B Tables. The number of
      samples logged is given by `num_eval_images`. Currently, the
      `MMDetWandbHook` logs the predicted bounding boxes along with the
      ground truth at every evaluation interval. This depends on the
      `EvalHook`, whose priority is higher than `MMDetWandbHook`'s. Also
      note that the data is logged just once, and subsequent evaluation
      tables use references to the logged data to save memory. Please
      refer to https://docs.wandb.ai/guides/data-vis to learn more about
      W&B Tables.

    For more details check out W&B's MMDetection docs:
    https://docs.wandb.ai/guides/integrations/mmdetection
```
Example:
log_config = dict(
...
hooks=[
...,
dict(type='MMDetWandbHook',
init_kwargs={
'entity': "YOUR_ENTITY",
'project': "YOUR_PROJECT_NAME"
},
interval=50,
log_checkpoint=True,
log_checkpoint_metadata=True,
num_eval_images=100,
bbox_score_thr=0.3)
])
```
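
    Note:
        When `log_checkpoint_metadata` is True, the interval of checkpoint
        saving must be divisible by the interval of evaluation; this is
        checked in `before_run`.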

    Args:
init_kwargs (dict): A dict passed to wandb.init to initialize
a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
for possible key-value pairs.
interval (int): Logging interval (every k iterations). Defaults to 50.
log_checkpoint (bool): Save the checkpoint at every checkpoint interval
as W&B Artifacts. Use this for model versioning where each version
is a checkpoint. Defaults to False.
        log_checkpoint_metadata (bool): Log the evaluation metrics computed
            on the validation data with the checkpoint, along with the
            current epoch, as metadata for that checkpoint.
            Defaults to False.
num_eval_images (int): The number of validation images to be logged.
If zero, the evaluation won't be logged. Defaults to 100.
bbox_score_thr (float): Threshold for bounding box scores.
Defaults to 0.3.
"""

    def __init__(self,
init_kwargs=None,
interval=50,
log_checkpoint=False,
log_checkpoint_metadata=False,
num_eval_images=100,
bbox_score_thr=0.3,
**kwargs):
super(MMDetWandbHook, self).__init__(init_kwargs, interval, **kwargs)
self.log_checkpoint = log_checkpoint
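        # Checkpoint metadata can only be logged when the checkpoints
        # themselves are logged.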
self.log_checkpoint_metadata = (
log_checkpoint and log_checkpoint_metadata)
self.num_eval_images = num_eval_images
self.bbox_score_thr = bbox_score_thr
self.log_evaluation = (num_eval_images > 0)
self.ckpt_hook: CheckpointHook = None
self.eval_hook: EvalHook = None

    def import_wandb(self):
try:
import wandb
from wandb import init # noqa
# Fix ResourceWarning when calling wandb.log in wandb v0.12.10.
# https://github.com/wandb/client/issues/2837
if digit_version(wandb.__version__) < digit_version('0.12.10'):
                warnings.warn(
                    f'The current wandb version {wandb.__version__} is '
                    'lower than v0.12.10 and will cause a ResourceWarning '
                    'when calling wandb.log. Please run '
                    '"pip install --upgrade wandb" to upgrade.')
        except ImportError:
            raise ImportError(
                'Please run \'pip install "wandb>=0.12.10"\' to install '
                'wandb.')
self.wandb = wandb

    @master_only
def before_run(self, runner):
super(MMDetWandbHook, self).before_run(runner)
# Save and Log config.
if runner.meta is not None and runner.meta.get('exp_name',
None) is not None:
src_cfg_path = osp.join(runner.work_dir,
runner.meta.get('exp_name', None))
if osp.exists(src_cfg_path):
self.wandb.save(src_cfg_path, base_path=runner.work_dir)
self._update_wandb_config(runner)
else:
runner.logger.warning('No meta information found in the runner. ')
# Inspect CheckpointHook and EvalHook
for hook in runner.hooks:
if isinstance(hook, CheckpointHook):
self.ckpt_hook = hook
if isinstance(hook, (EvalHook, DistEvalHook)):
self.eval_hook = hook
# Check conditions to log checkpoint
if self.log_checkpoint:
if self.ckpt_hook is None:
self.log_checkpoint = False
self.log_checkpoint_metadata = False
runner.logger.warning(
                    'To log checkpoint in MMDetWandbHook, `CheckpointHook` '
                    'is required, please check hooks in the runner.')
else:
self.ckpt_interval = self.ckpt_hook.interval
# Check conditions to log evaluation
if self.log_evaluation or self.log_checkpoint_metadata:
if self.eval_hook is None:
self.log_evaluation = False
self.log_checkpoint_metadata = False
runner.logger.warning(
'To log evaluation or checkpoint metadata in '
'MMDetWandbHook, `EvalHook` or `DistEvalHook` in mmdet '
'is required, please check whether the validation '
'is enabled.')
else:
self.eval_interval = self.eval_hook.interval
self.val_dataset = self.eval_hook.dataloader.dataset
# Determine the number of samples to be logged.
                if self.num_eval_images > len(self.val_dataset):
                    runner.logger.warning(
                        f'The num_eval_images ({self.num_eval_images}) is '
                        'greater than the total number of validation '
                        f'samples ({len(self.val_dataset)}). The complete '
                        'validation dataset will be logged.')
                    self.num_eval_images = len(self.val_dataset)
# Check conditions to log checkpoint metadata
if self.log_checkpoint_metadata:
assert self.ckpt_interval % self.eval_interval == 0, \
'To log checkpoint metadata in MMDetWandbHook, the interval ' \
f'of checkpoint saving ({self.ckpt_interval}) should be ' \
'divisible by the interval of evaluation ' \
f'({self.eval_interval}).'
# Initialize evaluation table
if self.log_evaluation:
# Initialize data table
self._init_data_table()
# Add data to the data table
self._add_ground_truth(runner)
# Log ground truth data
self._log_data_table()

    @master_only
def after_train_epoch(self, runner):
super(MMDetWandbHook, self).after_train_epoch(runner)
if not self.by_epoch:
return
# Log checkpoint and metadata.
        if self.log_checkpoint and (
                self.every_n_epochs(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
if self.log_checkpoint_metadata and self.eval_hook:
metadata = {
'epoch': runner.epoch + 1,
**self._get_eval_results()
}
else:
metadata = None
aliases = [f'epoch_{runner.epoch + 1}', 'latest']
model_path = osp.join(self.ckpt_hook.out_dir,
f'epoch_{runner.epoch + 1}.pth')
self._log_ckpt_as_artifact(model_path, aliases, metadata)
# Save prediction table
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
results = self.eval_hook.latest_results
# Initialize evaluation table
self._init_pred_table()
# Log predictions
self._log_predictions(results)
# Log the table
self._log_eval_table(runner.epoch + 1)

    # For the rationale behind this double-layered structure, refer to
    # https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076
def after_train_iter(self, runner):
if self.get_mode(runner) == 'train':
            # An ugly patch. The iter-based eval hook will call the
            # `after_train_iter` method of all logger hooks before
            # evaluation. Use this trick to skip that call.
            # Don't call the super method first; it would clear the
            # log_buffer.
return super(MMDetWandbHook, self).after_train_iter(runner)
else:
super(MMDetWandbHook, self).after_train_iter(runner)
self._after_train_iter(runner)

    @master_only
def _after_train_iter(self, runner):
if self.by_epoch:
return
# Save checkpoint and metadata
        if self.log_checkpoint and (
                self.every_n_iters(runner, self.ckpt_interval)
                or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
if self.log_checkpoint_metadata and self.eval_hook:
metadata = {
'iter': runner.iter + 1,
**self._get_eval_results()
}
else:
metadata = None
aliases = [f'iter_{runner.iter + 1}', 'latest']
model_path = osp.join(self.ckpt_hook.out_dir,
f'iter_{runner.iter + 1}.pth')
self._log_ckpt_as_artifact(model_path, aliases, metadata)
# Save prediction table
if self.log_evaluation and self.eval_hook._should_evaluate(runner):
results = self.eval_hook.latest_results
# Initialize evaluation table
self._init_pred_table()
# Log predictions
self._log_predictions(results)
# Log the table
self._log_eval_table(runner.iter + 1)

    @master_only
def after_run(self, runner):
self.wandb.finish()

    def _update_wandb_config(self, runner):
"""Update wandb config."""
# Import the config file.
sys.path.append(runner.work_dir)
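        # `exp_name` is the name of the config file (e.g. a hypothetical
        # 'my_config.py'); strip the '.py' suffix so the config can be
        # imported as a module from work_dir.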
config_filename = runner.meta['exp_name'][:-3]
configs = importlib.import_module(config_filename)
# Prepare a nested dict of config variables.
config_keys = [key for key in dir(configs) if not key.startswith('__')]
config_dict = {key: getattr(configs, key) for key in config_keys}
# Update the W&B config.
self.wandb.config.update(config_dict)

    def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
"""Log model checkpoint as W&B Artifact.

        Args:
model_path (str): Path of the checkpoint to log.
aliases (list): List of the aliases associated with this artifact.
metadata (dict, optional): Metadata associated with this artifact.
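
        Example:
            A typical call made by this hook (hypothetical path and
            metric values)::

                self._log_ckpt_as_artifact(
                    'work_dir/epoch_12.pth',
                    aliases=['epoch_12', 'latest'],
                    metadata={'epoch': 12, 'bbox_mAP': 0.382})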
"""
model_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
model_artifact.add_file(model_path)
self.wandb.log_artifact(model_artifact, aliases=aliases)

    def _get_eval_results(self):
"""Get model evaluation results."""
results = self.eval_hook.latest_results
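        # logger='silent' suppresses the printing that `evaluate` would
        # otherwise do; the results are only returned as a dict here.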
eval_results = self.val_dataset.evaluate(
results, logger='silent', **self.eval_hook.eval_kwargs)
return eval_results

    def _init_data_table(self):
"""Initialize the W&B Tables for validation data."""
columns = ['image_name', 'image']
self.data_table = self.wandb.Table(columns=columns)

    def _init_pred_table(self):
"""Initialize the W&B Tables for model evaluation."""
columns = ['image_name', 'ground_truth', 'prediction']
self.eval_table = self.wandb.Table(columns=columns)

    def _add_ground_truth(self, runner):
# Get image loading pipeline
from mmdet.datasets.pipelines import LoadImageFromFile
img_loader = None
for t in self.val_dataset.pipeline.transforms:
if isinstance(t, LoadImageFromFile):
img_loader = t
if img_loader is None:
self.log_evaluation = False
runner.logger.warning(
'LoadImageFromFile is required to add images '
'to W&B Tables.')
return
# Select the images to be logged.
self.eval_image_indexs = np.arange(len(self.val_dataset))
# Set seed so that same validation set is logged each time.
np.random.seed(42)
np.random.shuffle(self.eval_image_indexs)
self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]
CLASSES = self.val_dataset.CLASSES
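        # Use 1-based class ids; pixel value 0 is never assigned a class
        # in the masks built below, so W&B should render it as background.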
self.class_id_to_label = {
id + 1: name
for id, name in enumerate(CLASSES)
}
self.class_set = self.wandb.Classes([{
'id': id,
'name': name
} for id, name in self.class_id_to_label.items()])
img_prefix = self.val_dataset.img_prefix
for idx in self.eval_image_indexs:
img_info = self.val_dataset.data_infos[idx]
image_name = img_info.get('filename', f'img_{idx}')
img_height, img_width = img_info['height'], img_info['width']
img_meta = img_loader(
dict(img_info=img_info, img_prefix=img_prefix))
# Get image and convert from BGR to RGB
image = mmcv.bgr2rgb(img_meta['img'])
data_ann = self.val_dataset.get_ann_info(idx)
bboxes = data_ann['bboxes']
labels = data_ann['labels']
masks = data_ann.get('masks', None)
# Get dict of bounding boxes to be logged.
assert len(bboxes) == len(labels)
wandb_boxes = self._get_wandb_bboxes(bboxes, labels)
# Get dict of masks to be logged.
if masks is not None:
wandb_masks = self._get_wandb_masks(
masks,
labels,
is_poly_mask=True,
height=img_height,
width=img_width)
else:
wandb_masks = None
            # TODO: Panoptic segmentation visualization.
# Log a row to the data table.
self.data_table.add_data(
image_name,
self.wandb.Image(
image,
boxes=wandb_boxes,
masks=wandb_masks,
classes=self.class_set))

    def _log_predictions(self, results):
table_idxs = self.data_table_ref.get_index()
assert len(table_idxs) == len(self.eval_image_indexs)
for ndx, eval_image_index in enumerate(self.eval_image_indexs):
# Get the result
result = results[eval_image_index]
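            # `result` is a per-class list of (n, 5) bbox arrays in
            # [x1, y1, x2, y2, score] format, or a (bbox_results,
            # segm_results) tuple when the model also predicts masks.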
if isinstance(result, tuple):
bbox_result, segm_result = result
if isinstance(segm_result, tuple):
segm_result = segm_result[0] # ms rcnn
else:
bbox_result, segm_result = result, None
assert len(bbox_result) == len(self.class_id_to_label)
# Get labels
bboxes = np.vstack(bbox_result)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
# Get segmentation mask if available.
segms = None
if segm_result is not None and len(labels) > 0:
segms = mmcv.concat_list(segm_result)
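            # Decode the RLE-encoded masks into an (H, W, N) array, then
            # reorder it to (N, H, W) so it can be indexed per instance.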
segms = mask_util.decode(segms)
segms = segms.transpose(2, 0, 1)
assert len(segms) == len(labels)
            # TODO: Panoptic segmentation visualization.
# Remove bounding boxes and masks with score lower than threshold.
if self.bbox_score_thr > 0:
assert bboxes is not None and bboxes.shape[1] == 5
scores = bboxes[:, -1]
inds = scores > self.bbox_score_thr
bboxes = bboxes[inds, :]
labels = labels[inds]
if segms is not None:
segms = segms[inds, ...]
# Get dict of bounding boxes to be logged.
wandb_boxes = self._get_wandb_bboxes(bboxes, labels, log_gt=False)
# Get dict of masks to be logged.
if segms is not None:
wandb_masks = self._get_wandb_masks(segms, labels)
else:
wandb_masks = None
# Log a row to the eval table.
self.eval_table.add_data(
self.data_table_ref.data[ndx][0],
self.data_table_ref.data[ndx][1],
self.wandb.Image(
self.data_table_ref.data[ndx][1],
boxes=wandb_boxes,
masks=wandb_masks,
classes=self.class_set))

    def _get_wandb_bboxes(self, bboxes, labels, log_gt=True):
"""Get list of structured dict for logging bounding boxes to W&B.

        Args:
            bboxes (list): List of bounding box coordinates in
                (minX, minY, maxX, maxY) format.
            labels (list): List of label ids.
            log_gt (bool): Whether to log ground truth or prediction boxes.

        Returns:
            Dictionary of bounding boxes to be logged.
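
        Example:
            For a single predicted box (hypothetical values), the returned
            dict looks like::

                {
                    'predictions': {
                        'box_data': [{
                            'position': {
                                'minX': 10, 'minY': 20,
                                'maxX': 60, 'maxY': 80
                            },
                            'class_id': 1,
                            'box_caption': 'person 0.92',
                            'domain': 'pixel'
                        }],
                        'class_labels': {1: 'person'}
                    }
                }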
"""
wandb_boxes = {}
box_data = []
for bbox, label in zip(bboxes, labels):
if not isinstance(label, int):
label = int(label)
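            # Shift to the 1-based class ids used by class_id_to_label.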
label = label + 1
if len(bbox) == 5:
confidence = float(bbox[4])
class_name = self.class_id_to_label[label]
box_caption = f'{class_name} {confidence:.2f}'
else:
box_caption = str(self.class_id_to_label[label])
position = dict(
minX=int(bbox[0]),
minY=int(bbox[1]),
maxX=int(bbox[2]),
maxY=int(bbox[3]))
box_data.append({
'position': position,
'class_id': label,
'box_caption': box_caption,
'domain': 'pixel'
})
wandb_bbox_dict = {
'box_data': box_data,
'class_labels': self.class_id_to_label
}
if log_gt:
wandb_boxes['ground_truth'] = wandb_bbox_dict
else:
wandb_boxes['predictions'] = wandb_bbox_dict
return wandb_boxes

    def _get_wandb_masks(self,
masks,
labels,
is_poly_mask=False,
height=None,
width=None):
"""Get list of structured dict for logging masks to W&B.

        Args:
            masks (list): List of masks.
            labels (list): List of label ids.
            is_poly_mask (bool): Whether the masks are polygonal.
                This is true for the CocoDataset.
            height (int): Height of the image.
            width (int): Width of the image.

        Returns:
            Dictionary of masks to be logged.
"""
mask_label_dict = dict()
for mask, label in zip(masks, labels):
label = label + 1
# Get bitmap mask from polygon.
if is_poly_mask:
if height is not None and width is not None:
mask = polygon_to_bitmap(mask, height, width)
# Create composite masks for each class.
            if label not in mask_label_dict:
mask_label_dict[label] = mask
else:
mask_label_dict[label] = np.logical_or(mask_label_dict[label],
mask)
wandb_masks = dict()
for key, value in mask_label_dict.items():
# Create mask for that class.
value = value.astype(np.uint8)
value[value > 0] = key
# Create dict of masks for logging.
class_name = self.class_id_to_label[key]
wandb_masks[class_name] = {
'mask_data': value,
'class_labels': self.class_id_to_label
}
return wandb_masks

    def _log_data_table(self):
"""Log the W&B Tables for validation data as artifact and calls
`use_artifact` on it so that the evaluation table can use the reference
of already uploaded images.
This allows the data to be uploaded just once.
"""
data_artifact = self.wandb.Artifact('val', type='dataset')
data_artifact.add(self.data_table, 'val_data')
if not self.wandb.run.offline:
self.wandb.run.use_artifact(data_artifact)
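            # Wait until the artifact is uploaded and committed so that
            # its table entries can be referenced below.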
data_artifact.wait()
self.data_table_ref = data_artifact.get('val_data')
else:
self.data_table_ref = self.data_table

    def _log_eval_table(self, idx):
"""Log the W&B Tables for model evaluation.

        The table will be logged multiple times, creating a new version
        each time. Use it to compare models at different intervals
        interactively.
"""
pred_artifact = self.wandb.Artifact(
f'run_{self.wandb.run.id}_pred', type='evaluation')
pred_artifact.add(self.eval_table, 'eval_data')
if self.by_epoch:
aliases = ['latest', f'epoch_{idx}']
else:
aliases = ['latest', f'iter_{idx}']
self.wandb.run.log_artifact(pred_artifact, aliases=aliases)