from abc import ABC, abstractmethod
import logging
from os import path
import os
from threading import Thread
from time import sleep, time
from typing import Callable, Optional
import uuid
import torch.multiprocessing as mp

import torch
from ding.data.storage.file import FileModelStorage
from ding.data.storage.storage import Storage
from ding.framework import Supervisor
from ding.framework.supervisor import ChildType, SendPayload


class ModelWorker:
    """
    Overview:
        Worker that runs in a child process and saves the shared model's state dict into a storage object.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        self._model = model

    def save(self, storage: Storage) -> Storage:
        storage.save(self._model.state_dict())
        return storage


class ModelLoader(Supervisor, ABC):

    def __init__(self, model: torch.nn.Module) -> None:
        """
        Overview:
            Save and send models asynchronously and load them synchronously.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
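        Examples:
            >>> # Illustrative sketch only; FileModelLoader is one concrete subclass.
            >>> loader = FileModelLoader(model, dirname="./models")
            >>> loader.start()
            >>> loader.save(lambda storage: print("model saved to", storage))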
        """
        # CUDA tensors can only be shared across processes with the "spawn" start method.
        if next(model.parameters()).is_cuda:
            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
        else:
            super().__init__(type_=ChildType.PROCESS)
        self._model = model
        self._send_callback_loop = None
        self._send_callbacks = {}
        self._model_worker = ModelWorker(self._model)

    def start(self):
        if not self._running:
            # Move the model parameters into shared memory so the worker process can read them without copying.
            self._model.share_memory()
            self.register(self._model_worker)
            self.start_link()
            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
            self._send_callback_loop.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._send_callback_loop = None
        self._send_callbacks = {}

    def _loop_send_callback(self):
        while True:
            payload = self.recv(ignore_err=True)
            if payload.err:
                logging.warning("Got error when saving model: {}".format(payload.err))
                if payload.req_id in self._send_callbacks:
                    del self._send_callbacks[payload.req_id]
            else:
                if payload.req_id in self._send_callbacks:
                    callback = self._send_callbacks.pop(payload.req_id)
                    callback(payload.data)

    def load(self, storage: Storage) -> object:
        """
        Overview:
            Load model synchronously.
        Arguments:
            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
        Returns:
            - object (:obj:`object`): The loaded object, e.g. a model state dict.
        """
        return storage.load()

    @abstractmethod
    def save(self, callback: Callable) -> Storage:
        """
        Overview:
            Save model asynchronously.
        Arguments:
            - callback (:obj:`Callable`): The callback function invoked after the model has been saved.
        Returns:
            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
        """
        raise NotImplementedError


class FileModelLoader(ModelLoader):

    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
        """
        Overview:
            Model loader using files as storage media.
        Arguments:
            - model (:obj:`torch.nn.Module`): Torch module.
            - dirname (:obj:`str`): The directory for saving files.
            - ttl (:obj:`int`): Files will be automatically cleaned up after ``ttl`` seconds. Note that \
                files which have not yet expired when the process stops are left on disk \
                (to avoid errors when other processes may still be reading them), so you may need to \
                clean up the remaining files manually.
        """
        super().__init__(model)
        self._dirname = dirname
        self._ttl = ttl
        self._files = []
        self._cleanup_thread = None

    def _start_cleanup(self):
        """
        Overview:
            Start a cleanup thread that removes files which have exceeded their ttl on disk.
        """
        if self._cleanup_thread is None:
            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
            self._cleanup_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        super().shutdown(timeout)
        self._cleanup_thread = None

    def _loop_cleanup(self):
        while True:
            # Files are appended in chronological order, so the head of the list is always the oldest.
            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
                sleep(1)
                continue
            _, file_path = self._files.pop(0)
            if path.exists(file_path):
                os.remove(file_path)

    def save(self, callback: Callable) -> Optional[FileModelStorage]:
        if not self._running:
            logging.warning("Please start model loader before saving model.")
            return None
        if not path.exists(self._dirname):
            os.makedirs(self._dirname, exist_ok=True)
        file_path = "model_{}.pth.tar".format(uuid.uuid1())
        file_path = path.join(self._dirname, file_path)
        model_storage = FileModelStorage(file_path)
        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
        self.send(payload)

        def clean_callback(storage: Storage):
            # Record the save time so the cleanup loop can expire the file after ttl seconds.
            self._files.append([time(), file_path])
            callback(storage)

        self._send_callbacks[payload.req_id] = clean_callback
        self._start_cleanup()
        return model_storage
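

# A minimal usage sketch (illustrative, not part of the library API): save a model
# asynchronously with FileModelLoader, then reload its state dict synchronously in
# the callback. The Linear model, the "./models" directory and the sleep duration
# are arbitrary choices for demonstration.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    loader = FileModelLoader(model, dirname="./models", ttl=20)
    loader.start()

    def on_saved(storage: Storage):
        # load() is synchronous and returns whatever the worker saved,
        # i.e. the model's state dict.
        state_dict = loader.load(storage)
        model.load_state_dict(state_dict)
        print("Model reloaded from file storage.")

    # save() returns the storage object immediately; the actual write happens
    # in the child process, and on_saved fires when it completes.
    loader.save(on_saved)
    sleep(2)  # Give the worker time to finish (demo only).
    loader.shutdown()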