Spaces:
Running
Running
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license | |
"""Callback utils.""" | |
import threading | |
class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks."""

    def __init__(self):
        """Create an empty action list for every supported hook and clear the stop flag."""
        # Define the available callbacks: each hook name maps to a list of
        # {"name": ..., "callback": ...} dicts appended by register_action().
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def _check_hook(self, hook):
        """Raise AssertionError if `hook` is not one of the hook names defined in __init__."""
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # AssertionError is kept so callers that already catch it keep working.
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {list(self._callbacks)}")

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to (must be a hook defined in __init__)
            name: The name of the action for later reference
            callback: The callable to fire when the hook is run

        Raises:
            AssertionError: If `hook` is unknown or `callback` is not callable.
        """
        self._check_hook(hook)
        if not callable(callback):
            raise AssertionError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook.

        Args:
            hook: The name of the hook to look up; when falsy (default None) every hook is returned

        Returns:
            The list of {"name", "callback"} dicts registered for `hook`, or the whole
            hook -> actions dict. These are the internal containers, not copies.
            An unknown (truthy) hook name raises KeyError.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Loop through the registered actions of one hook and fire every callback.

        Args:
            hook: The name of the hook whose actions should be fired (required)
            args: Positional arguments forwarded to every callback
            thread: (boolean) If True, start each callback in its own daemon thread instead
                of calling it inline; the threads are not joined, so run() may return first
            kwargs: Keyword arguments forwarded to every callback

        Raises:
            AssertionError: If `hook` is unknown.
        """
        self._check_hook(hook)
        # Iterate over a snapshot: a callback that registers another action on this same
        # hook must not extend (or endlessly grow) the list currently being dispatched.
        for action in list(self._callbacks[hook]):
            if thread:
                threading.Thread(target=action["callback"], args=args, kwargs=kwargs, daemon=True).start()
            else:
                action["callback"](*args, **kwargs)