gomoku / DI-engine /ding /data /model_loader.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
5.5 kB
from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp
import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload
class ModelWorker:
    """
    Overview:
        Small wrapper around a torch module whose only job is to dump the module's
        ``state_dict`` into a storage object.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Arguments:
            - model (:obj:`torch.nn.Module`): The module whose state will be saved.
        """
        self._model = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Write the current ``state_dict`` of the wrapped module into ``storage``.
        Arguments:
            - storage (:obj:`Storage`): Destination object; only its ``save`` method is used here.
        Returns:
            - storage (:obj:`Storage`): The very same object that was passed in.
        """
        state_dict = self._model.state_dict()
        storage.save(state_dict)
        return storage
class ModelLoader(Supervisor, ABC):
    """
    Overview:
        Abstract supervisor that registers a ``ModelWorker`` as its child, sends save
        requests to it, and dispatches the results to per-request callbacks from a
        background thread. Loading is done synchronously in the calling thread.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # ``next(it, None)`` keeps a module without parameters from raising StopIteration;
        # such a module is handled like a non-CUDA model.
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.is_cuda:
            # CUDA models get the "spawn" start method (PyTorch documents that using CUDA
            # in subprocesses requires spawn/forkserver rather than fork).
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        # Background thread running ``_loop_send_callback``; created in ``start``.
        self._send_callback_loop = None
        # Maps a request id to the callback that should receive the request's result.
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Share the model's memory, register the model worker, start the children and
            launch the callback dispatch thread. Does nothing if the loader is already running.
        """
        if not self._running:
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and forget the dispatch thread and all pending callbacks.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded unchanged to the parent ``shutdown``.
        """
        super().shutdown(timeout)
        # NOTE(review): this only drops the reference; the loop itself has no exit condition.
        # Whether ``recv`` unblocks or raises after shutdown depends on Supervisor - confirm.
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        """
        Overview:
            Receive result payloads forever and hand each successful one to the callback
            registered under its ``req_id``. Payloads carrying an error are logged and their
            callback is discarded without being called.
        """
        while True:
            payload = self.recv(ignore_err=True)
            # Each callback is one-shot, so it is removed whether the request failed or not.
            callback = self._send_callbacks.pop(payload.req_id, None)
            if payload.err:
                logging.warning("Got error when saving model: {}".format(payload.err))
                continue
            if callback is None:
                continue
            try:
                callback(payload.data)
            except Exception:
                # An exception escaping here would end this thread, and every later
                # callback would then be silently skipped.
                logging.exception("Got error in model save callback (req_id: {})".format(payload.req_id))

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:`object`): Whatever ``storage.load()`` returns.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError
class FileModelLoader(ModelLoader):
    """
    Overview:
        Model loader that stores every saved model in its own file under one directory and
        removes each file once it has been on disk for longer than ``ttl`` seconds.
    """

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl seconds. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # ``[timestamp, file_path]`` entries, appended in the order the saves complete.
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start the cleanup thread, unless one has already been started, to remove files
            that have stayed on disk for longer than ``ttl`` seconds.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the loader and drop the reference to the cleanup thread.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded unchanged to the parent ``shutdown``.
        """
        super().shutdown(timeout)
        # Only the reference is cleared; the cleanup loop has no exit condition of its own.
        self._cleanup_thread = None

    def _loop_cleanup(self):
        """
        Overview:
            Loop forever: while the oldest recorded file is younger than ``ttl`` seconds (or
            nothing is recorded) sleep one second, otherwise remove that file from disk.
        """
        while True:
            if not self._files or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already gone (e.g. removed by someone else); nothing left to do.
                pass
            except OSError as e:
                # Keep the thread alive: it is started only once, so letting it die here
                # would stop all future cleanup.
                logging.warning("Failed to remove expired model file {}: {}".format(file_path, e))

    def save(self, callback: Callable) -> Optional[FileModelStorage]:
        """
        Overview:
            Ask the model worker to save the model into a new, uniquely named file and
            register ``callback`` to be called with the result once the save has finished.
        Arguments:
            - callback (:obj:`Callable`): Called with the data returned for the save request.
        Returns:
            - storage (:obj:`Optional[FileModelStorage]`): Storage pointing at the new file, \
                created synchronously; ``None`` if the loader has not been started.
        Raises:
            - FileExistsError: If ``dirname`` exists but is not a directory.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return None
        # ``exist_ok`` tolerates a concurrent creation of the directory, and ``makedirs``
        # also creates missing parent directories.
        os.makedirs(self._dirname, exist_ok=True)
        file_path = path.join(self._dirname, "model_{}.pth.tar".format(uuid.uuid1()))
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # The ttl clock starts when the save has completed, not when it was requested.
            self._files.append([time(), file_path])
            callback(storage)

        # NOTE(review): the callback is registered after ``send``; a reply that arrives before
        # this line would find no callback. Registering first looks safe if ``req_id`` is
        # assigned when the payload is constructed - confirm against SendPayload.
        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage