|
from abc import ABC, abstractmethod |
|
import logging |
|
from os import path |
|
import os |
|
from threading import Thread |
|
from time import sleep, time |
|
from typing import Callable, Optional |
|
import uuid |
|
import torch.multiprocessing as mp |
|
|
|
import torch |
|
from ding.data.storage.file import FileModelStorage |
|
from ding.data.storage.storage import Storage |
|
from ding.framework import Supervisor |
|
from ding.framework.supervisor import ChildType, SendPayload |
|
|
|
|
|
class ModelWorker:
    """
    Overview:
        Holds a torch module on the worker side and writes its weights into a storage object \
        when the ``save`` method is invoked (ModelLoader registers it with its child process).
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Arguments:
            - model (:obj:`torch.nn.Module`): The module whose ``state_dict`` will be saved.
        """
        self._model = model

    def save(self, storage: Storage) -> Storage:
        """
        Overview:
            Write the module's current ``state_dict`` into ``storage``.
        Arguments:
            - storage (:obj:`Storage`): Destination storage object, e.g. a ``FileModelStorage``.
        Returns:
            - storage (:obj:`Storage`): The very same storage object that was passed in.
        """
        weights = self._model.state_dict()
        storage.save(weights)
        return storage
|
|
|
|
|
class ModelLoader(Supervisor, ABC):
    """
    Overview:
        Base class of model loaders: the model is saved asynchronously by a ``ModelWorker`` \
        running in a child process, and loaded synchronously from a ``Storage`` object.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
        """
        # CUDA cannot be used in forked subprocesses, so a CUDA model needs the "spawn" start
        # method. A module without parameters is treated as a CPU module instead of letting
        # ``next`` raise StopIteration.
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.is_cuda:
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        self._send_callback_loop = None
        # req_id -> callback waiting for the worker's reply to that request.
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        """
        Overview:
            Start the worker process and the daemon thread that dispatches save callbacks. \
            Does nothing if the loader is already running.
        """
        if not self._running:
            # Presumably lets the worker process see in-place weight updates made here
            # (torch shared-memory tensors) -- NOTE(review): confirm against Supervisor.
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the supervisor and forget all pending callbacks.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded to ``Supervisor.shutdown``.
        """
        super().shutdown(timeout)
        # NOTE(review): the dispatch thread is only dereferenced here, not stopped.
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        """
        Overview:
            Forever receive replies from the worker and invoke the callback registered for \
            each request id. Replies carrying an error are logged and their callback dropped.
        """
        while True:
            payload = self.recv(ignore_err=True)
            # A single ``pop`` with a default (instead of ``in`` followed by ``pop``) cannot raise
            # KeyError if ``shutdown`` replaces the dict concurrently.
            callback = self._send_callbacks.pop(payload.req_id, None)
            if payload.err:
                logging.warning("Got error when loading data: {}".format(payload.err))
                continue
            if callback is None:
                continue
            try:
                callback(payload.data)
            except Exception:
                # An exception escaping here would kill this thread, and every later
                # callback would then be silently dropped.
                logging.exception("Error in model loader callback for request {}".format(payload.req_id))

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:`object`): Whatever ``storage.load()`` returns, i.e. the loaded model data.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function after saving model.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError
|
|
|
|
|
class FileModelLoader(ModelLoader):
    """
    Overview:
        Model loader that saves models to files under ``dirname`` and deletes each file \
        ``ttl`` seconds after it has been written.
    """

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \
                files that do not time out when the process is stopped are not cleaned up \
                (to avoid errors when other processes read the file), so you may need to \
                clean up the remaining files manually
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        # [write_time, file_path] entries, appended in the order the saves complete.
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start the thread that removes files older than ``ttl``, unless one is already alive.
        """
        # Only one thread may consume ``self._files``: two consumers could pop an entry that has
        # not expired yet and delete a file another process is still reading.
        if self._cleanup_thread is None or not self._cleanup_thread.is_alive():
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down the loader.
        Arguments:
            - timeout (:obj:`Optional[float]`): Forwarded to ``ModelLoader.shutdown``.
        """
        super().shutdown(timeout)
        # The cleanup thread is deliberately kept. It keeps removing already-tracked files after
        # shutdown, and reusing it on restart avoids a second consumer of ``self._files``.

    def _loop_cleanup(self):
        """
        Overview:
            Forever: delete the oldest tracked file once it is at least ``ttl`` seconds old, \
            otherwise sleep for one second.
        """
        while True:
            if not self._files or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already gone (e.g. removed by hand); nothing to clean.
                pass
            except OSError as err:
                # Log instead of letting the error end this thread, so later files are still cleaned.
                logging.warning("Failed to remove model file {}: {}".format(file_path, err))

    def save(self, callback: Callable) -> Optional[FileModelStorage]:
        """
        Overview:
            Ask the worker process to save the model into a new file under ``dirname``.
        Arguments:
            - callback (:obj:`Callable`): Called with the storage object once the model is saved.
        Returns:
            - storage (:obj:`Optional[FileModelStorage]`): Storage pointing at the file being written, \
                or ``None`` (with a warning) if the loader has not been started.
        """
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return
        # Also creates missing parents, and does not fail if the directory appeared meanwhile.
        os.makedirs(self._dirname, exist_ok=True)
        file_path = path.join(self._dirname, "model_{}.pth.tar".format(uuid.uuid1()))
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])

        def clean_callback(storage: Storage):
            # Track the file for ttl-based deletion only after it was written successfully.
            self._files.append([time(), file_path])
            callback(storage)

        # Register before sending. The reply is consumed on another thread and may arrive before
        # ``send`` returns, in which case a late registration would silently drop the callback.
        # NOTE(review): relies on ``req_id`` being assigned when SendPayload is constructed.
        self._send_callbacks[payload.req_id] = clean_callback
        try:
            self.send(payload)
        except Exception:
            self._send_callbacks.pop(payload.req_id, None)
            raise
        self._start_cleanup()
        return model_storage
|
|