RockeyCoss
add code files”
51f6859
raw
history blame
681 Bytes
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner.hooks import HOOKS, Hook
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):
    """Check invalid loss hook.

    This hook will regularly check whether the loss is valid
    during training.

    Args:
        interval (int): Checking interval (every k iterations).
            Default: 50.
    """

    def __init__(self, interval=50):
        self.interval = interval

    def after_train_iter(self, runner):
        """Check that the loss is finite on every ``interval``-th iteration.

        Args:
            runner: Object exposing ``outputs['loss']`` (passed to
                ``torch.isfinite``) and a ``logger`` with an ``info``
                method.

        Raises:
            AssertionError: If the loss is infinite or NaN.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        if torch.isfinite(runner.outputs['loss']):
            return
        # Raise explicitly instead of using ``assert``: an ``assert`` is
        # stripped under ``python -O``, and the previous form used the
        # return value of ``logger.info`` (``None``) as the message.
        msg = 'loss become infinite or NaN!'
        runner.logger.info(msg)
        raise AssertionError(msg)