|
from abc import ABC, abstractmethod |
|
import functools |
|
import torch.multiprocessing as mp |
|
from multiprocessing.context import BaseContext |
|
import threading |
|
import queue |
|
import platform |
|
import traceback |
|
import uuid |
|
import time |
|
from ditk import logging |
|
from dataclasses import dataclass, field |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
from enum import Enum |
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def get_mp_ctx() -> BaseContext:
    """
    Overview:
        Return the default multiprocessing context for child processes: ``spawn`` on Windows and ``fork`` \
        on every other platform. The result is cached, so repeated calls return the same context object.
    Returns:
        - mp_ctx (:obj:`BaseContext`): The selected multiprocessing context.
    """
    on_windows = platform.system().lower() == 'windows'
    start_method = 'spawn' if on_windows else 'fork'
    return mp.get_context(start_method)
|
|
|
|
|
@dataclass
class SendPayload:
    """
    Overview:
        Request sent from the supervisor to one child worker.
    """
    # Index of the target child.
    proc_id: int
    # Unique id (uuid1 hex) used to match the child's ``RecvPayload`` to this request.
    req_id: str = field(default_factory=lambda: uuid.uuid1().hex)
    # Name of the method to call on the child instance, or a ``ReserveMethod`` member for built-in commands.
    method: Optional[Union[str, Enum]] = None
    # Positional arguments for the call (for ``ReserveMethod.GETATTR``, ``args[0]`` is the attribute name).
    args: List = field(default_factory=list)
    # Keyword arguments for the call.
    kwargs: Dict = field(default_factory=dict)
|
|
|
|
|
@dataclass
class RecvPayload:
    """
    Overview:
        Response produced by a child worker for one ``SendPayload``.
    """
    # Index of the child that produced this response.
    proc_id: int
    # ``req_id`` of the request being answered.
    req_id: Optional[str] = None
    # ``method`` of the request being answered (a string or a ``ReserveMethod`` member).
    method: Optional[Union[str, Enum]] = None
    # Return value of the call (or the attribute value for ``ReserveMethod.GETATTR``).
    data: Any = None
    # Exception raised while serving the request, if any.
    err: Optional[Exception] = None
    # Free slot for additional information; NOTE(review): not set anywhere in this module — confirm usage.
    extra: Any = None
|
|
|
|
|
class ReserveMethod(Enum):
    """
    Overview:
        Built-in request methods handled by the child worker loop itself instead of being dispatched to \
        the child instance: ``SHUTDOWN`` ends the loop, ``GETATTR`` reads the attribute named by ``args[0]``.
    """

    SHUTDOWN = "_shutdown"
    GETATTR = "_getattr"
|
|
|
|
|
class ChildType(Enum):
    """
    Overview:
        Kind of worker a supervisor creates for each registered child: a separate process or a thread.
    """

    PROCESS = "process"
    THREAD = "thread"
|
|
|
|
|
class Child(ABC):
    """
    Overview:
        Abstract class of child process/thread. A concrete child runs ``_target`` as its worker loop: it reads \
        ``SendPayload`` requests from its own send queue, serves them on the child instance and answers with \
        ``RecvPayload`` objects on the shared recv queue.
    """

    def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Index of this child; echoed back in every ``RecvPayload``.
            - init (:obj:`Union[Callable, object]`): A callable that builds the child instance inside the worker, \
                or the instance itself.
            - kwargs: Accepted for signature compatibility with subclasses; not used here.
        """
        self._proc_id = proc_id
        self._init = init
        # Assigned by ``start`` in the concrete subclasses.
        self._recv_queue = None
        self._send_queue = None

    @abstractmethod
    def start(self, recv_queue: Union[mp.Queue, queue.Queue]):
        raise NotImplementedError

    def restart(self):
        """
        Overview:
            Stop the worker via ``shutdown`` and start it again on the recv queue given to the previous ``start``.
        """
        self.shutdown()
        self.start(self._recv_queue)

    @abstractmethod
    def shutdown(self, timeout: Optional[float] = None):
        raise NotImplementedError

    @abstractmethod
    def send(self, payload: SendPayload):
        raise NotImplementedError

    def _target(
            self,
            proc_id: int,
            init: Union[Callable, object],
            send_queue: Union[mp.Queue, queue.Queue],
            recv_queue: Union[mp.Queue, queue.Queue],
            shm_buffer: Optional[Any] = None,
            shm_callback: Optional[Callable] = None
    ):
        """
        Overview:
            Worker loop. Serves requests until a ``ReserveMethod.SHUTDOWN`` request arrives. Any exception while \
            serving a request is logged and returned to the supervisor as ``RecvPayload.err``.
        Arguments:
            - proc_id (:obj:`int`): Child index put into every response.
            - init (:obj:`Union[Callable, object]`): Factory (called once here) or ready-made child instance.
            - send_queue: Queue the requests are read from.
            - recv_queue: Queue the responses are put on.
            - shm_buffer (:obj:`Optional[Any]`): Passed to ``shm_callback`` with each successful response.
            - shm_callback (:obj:`Optional[Callable]`): Called as ``shm_callback(recv_payload, shm_buffer)`` \
                before a successful response is queued; only used when both it and ``shm_buffer`` are given.
        """
        # Callables (classes, factories) are invoked to build the instance inside the worker; anything else
        # is used as the instance directly.
        if isinstance(init, Callable):
            child_ins = init()
        else:
            child_ins = init
        while True:
            # Reset on every iteration so that an error raised by ``send_queue.get`` itself is not reported
            # under the previous, already answered request.
            send_payload: Optional[SendPayload] = None
            try:
                send_payload = send_queue.get()
                if send_payload.method == ReserveMethod.SHUTDOWN:
                    break
                if send_payload.method == ReserveMethod.GETATTR:
                    data = getattr(child_ins, send_payload.args[0])
                else:
                    data = getattr(child_ins, send_payload.method)(*send_payload.args, **send_payload.kwargs)
                recv_payload = RecvPayload(
                    proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data
                )
                if shm_callback is not None and shm_buffer is not None:
                    shm_callback(recv_payload, shm_buffer)
                recv_queue.put(recv_payload)
            except Exception as e:
                logging.warning(traceback.format_exc())
                logging.warning("Error in child process! id: {}, error: {}".format(proc_id, e))
                recv_payload = RecvPayload(
                    proc_id=proc_id,
                    req_id=send_payload.req_id if send_payload is not None else None,
                    method=send_payload.method if send_payload is not None else None,
                    err=e
                )
                recv_queue.put(recv_payload)

    def __del__(self):
        self.shutdown()
|
|
|
|
|
class ChildProcess(Child):
    """
    Overview:
        Child worker that runs ``Child._target`` in a daemon process created from a multiprocessing context.
    """

    def __init__(
            self,
            proc_id: int,
            init: Union[Callable, object],
            shm_buffer: Optional[Any] = None,
            shm_callback: Optional[Callable] = None,
            mp_ctx: Optional[BaseContext] = None,
            **kwargs
    ) -> None:
        """
        Arguments:
            - proc_id (:obj:`int`): Index of this child.
            - init (:obj:`Union[Callable, object]`): Factory or instance served by the worker.
            - shm_buffer (:obj:`Optional[Any]`): Forwarded to the worker loop for ``shm_callback``.
            - shm_callback (:obj:`Optional[Callable]`): Forwarded to the worker loop; called with each \
                successful response and ``shm_buffer``.
            - mp_ctx (:obj:`Optional[BaseContext]`): Context used to create the queue and process; \
                ``get_mp_ctx()`` is used when omitted.
        """
        super().__init__(proc_id, init, **kwargs)
        self._proc = None
        self._mp_ctx = mp_ctx
        self._shm_buffer = shm_buffer
        self._shm_callback = shm_callback

    def start(self, recv_queue: mp.Queue):
        """
        Overview:
            Create the send queue and start the worker process. Does nothing if already started.
        Arguments:
            - recv_queue (:obj:`mp.Queue`): Shared queue the worker puts responses on.
        """
        if self._proc is None:
            self._recv_queue = recv_queue
            ctx = self._mp_ctx or get_mp_ctx()
            self._send_queue = ctx.Queue()
            proc = ctx.Process(
                target=self._target,
                args=(
                    self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback
                ),
                name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
                daemon=True
            )
            proc.start()
            self._proc = proc

    def shutdown(self, timeout: Optional[float] = None):
        """
        Overview:
            Ask the worker to stop, terminate it, and release the process object and the send queue.
        Arguments:
            - timeout (:obj:`Optional[float]`): Maximum wait for each ``join`` of the process.
        """
        if self._proc:
            proc = self._proc
            self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
            proc.terminate()
            proc.join(timeout=timeout)
            # ``join`` returns silently on timeout, and ``close`` raises ValueError on a live process, which would
            # leave ``_proc`` set and the send queue open. Escalate to kill before closing.
            if proc.is_alive() and hasattr(proc, "kill"):
                proc.kill()
                proc.join(timeout=timeout)
            if hasattr(proc, "close") and not proc.is_alive():  # ``close`` is absent on Python 3.6
                proc.close()
            self._proc = None
            self._send_queue.close()
            self._send_queue.join_thread()
            self._send_queue = None

    def send(self, payload: SendPayload):
        """
        Overview:
            Queue a request for the worker; logs a warning and drops it if the worker is not started.
        """
        if self._send_queue is None:
            logging.warning("Child worker has been terminated or not started.")
            return
        self._send_queue.put(payload)
|
|
|
|
|
class ChildThread(Child):
    """
    Overview:
        Child worker that runs ``Child._target`` in a daemon thread of the current process. Shared-memory \
        options are not forwarded to the worker loop.
    """

    def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None:
        super().__init__(proc_id, init, *args, **kwargs)
        self._thread = None

    def start(self, recv_queue: queue.Queue):
        """
        Overview:
            Create the send queue and start the worker thread; a no-op when the thread already exists.
        Arguments:
            - recv_queue (:obj:`queue.Queue`): Shared queue the worker puts responses on.
        """
        if self._thread is not None:
            return
        self._recv_queue = recv_queue
        self._send_queue = queue.Queue()
        worker_args = (self._proc_id, self._init, self._send_queue, self._recv_queue)
        worker_name = "supervisor_child_{}_{}".format(self._proc_id, time.time())
        worker = threading.Thread(target=self._target, args=worker_args, name=worker_name, daemon=True)
        worker.start()
        self._thread = worker

    def shutdown(self, timeout: Optional[float] = None):
        """
        Overview:
            Send the shutdown request, wait up to ``timeout`` for the thread, then drop the thread and queue.
        """
        if self._thread is None:
            return
        stop_request = SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN)
        self._send_queue.put(stop_request)
        self._thread.join(timeout=timeout)
        self._thread = None
        self._send_queue = None

    def send(self, payload: SendPayload):
        """
        Overview:
            Queue a request for the worker; logs a warning and drops it if the worker is not started.
        """
        outbox = self._send_queue
        if outbox is None:
            logging.warning("Child worker has been terminated or not started.")
            return
        outbox.put(payload)
|
|
|
|
|
class Supervisor:
    """
    Overview:
        Manage a group of child workers (processes or threads): dispatch ``SendPayload`` requests to them and \
        collect their ``RecvPayload`` responses from one shared recv queue. Unknown public attribute lookups \
        are forwarded to every child (see ``__getattr__``).
    """

    TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread}

    def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None:
        """
        Arguments:
            - type_ (:obj:`ChildType`): Whether children run as processes or threads.
            - mp_ctx (:obj:`Optional[BaseContext]`): Multiprocessing context; ``get_mp_ctx()`` when omitted.
        """
        self._children: List[Child] = []
        self._type = type_
        self._child_class = self.TYPE_MAPPING[self._type]
        self._running = False
        self.__queue = None
        self._mp_ctx = mp_ctx or get_mp_ctx()

    def register(
            self,
            init: Union[Callable, object],
            shm_buffer: Optional[Any] = None,
            shm_callback: Optional[Callable] = None
    ) -> None:
        """
        Overview:
            Add a child; its ``proc_id`` is its registration index.
        Arguments:
            - init (:obj:`Union[Callable, object]`): Factory or instance served by the child.
            - shm_buffer (:obj:`Optional[Any]`): Shared buffer passed to the child.
            - shm_callback (:obj:`Optional[Callable]`): Callback passed to the child.
        """
        proc_id = len(self._children)
        self._children.append(
            self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx)
        )

    @property
    def _recv_queue(self) -> Union[queue.Queue, mp.Queue]:
        # Created lazily so the queue type matches the child type.
        if self.__queue is None:
            if self._type is ChildType.PROCESS:
                self.__queue = self._mp_ctx.Queue()
            elif self._type is ChildType.THREAD:
                self.__queue = queue.Queue()
        return self.__queue

    @_recv_queue.setter
    def _recv_queue(self, queue: Union[queue.Queue, mp.Queue]):
        self.__queue = queue

    def start_link(self) -> None:
        """
        Overview:
            Start every registered child on the shared recv queue. Does nothing if already running.
        """
        if not self._running:
            for child in self._children:
                child.start(recv_queue=self._recv_queue)
            self._running = True

    def send(self, payload: SendPayload) -> None:
        """
        Overview:
            Send message to child process.
        Arguments:
            - payload (:obj:`SendPayload`): Send payload.
        """
        if not self._running:
            logging.warning("Please call start_link before sending any payload to child process.")
            return
        self._children[payload.proc_id].send(payload)

    def recv(self, ignore_err: bool = False, timeout: Optional[float] = None) -> RecvPayload:
        """
        Overview:
            Wait for message from child process
        Arguments:
            - ignore_err (:obj:`bool`): If ignore_err is True, put the err in the property of recv_payload. \
                Otherwise, an exception will be raised.
            - timeout (:obj:`Optional[float]`): Timeout for queue.get, will raise an Empty exception if timeout.
        Returns:
            - recv_payload (:obj:`RecvPayload`): Recv payload.
        """
        recv_payload: RecvPayload = self._recv_queue.get(timeout=timeout)
        if recv_payload.err is not None and not ignore_err:
            raise recv_payload.err
        return recv_payload

    def recv_all(
            self,
            send_payloads: List[SendPayload],
            ignore_err: bool = False,
            callback: Optional[Callable] = None,
            timeout: Optional[float] = None
    ) -> List[RecvPayload]:
        """
        Overview:
            Wait for messages with specific req ids until all ids are fulfilled. Responses for other requests \
            received meanwhile are put back on the recv queue before returning.
        Arguments:
            - send_payloads (:obj:`List[SendPayload]`): Request payloads.
            - ignore_err (:obj:`bool`): If ignore_err is True, \
                put the err in the property of recv_payload. Otherwise, an exception will be raised. \
                This option will also ignore timeout error.
            - callback (:obj:`Optional[Callable]`): Called as ``callback(recv_payload, remain_payloads)`` for \
                each matching (or timed-out) payload.
            - timeout (:obj:`Optional[float]`): Timeout for each wait on the recv queue.
        Returns:
            - recv_payload (:obj:`List[RecvPayload]`): Recv payloads in the order of ``send_payloads``, \
                may contain timeout error.
        Raises:
            - TimeoutError: On timeout when ``ignore_err`` is False.
            - Exception: The error reported by a child when ``ignore_err`` is False.
        """
        assert send_payloads, "Req payload is empty!"
        recv_payloads = {}
        remain_payloads = {payload.req_id: payload for payload in send_payloads}
        unrelated_payloads = []
        try:
            while remain_payloads:
                # Only the queue read is guarded: a child error that happens to be ``queue.Empty`` must be
                # re-raised as itself, not mistaken for a timeout.
                try:
                    recv_payload: RecvPayload = self._recv_queue.get(block=True, timeout=timeout)
                except queue.Empty:
                    if not ignore_err:
                        raise TimeoutError("Timeout ({}s) when receving payloads!".format(timeout))
                    req_ids = list(remain_payloads.keys())
                    logging.warning("Timeout ({}s) when receving payloads! Req ids: {}".format(timeout, req_ids))
                    for req_id in req_ids:
                        send_payload = remain_payloads.pop(req_id)
                        recv_payload = RecvPayload(
                            proc_id=send_payload.proc_id,
                            req_id=send_payload.req_id,
                            method=send_payload.method,
                            err=TimeoutError("Timeout on req_id ({})".format(req_id))
                        )
                        recv_payloads[req_id] = recv_payload
                        if callback:
                            callback(recv_payload, remain_payloads)
                    continue
                if recv_payload.req_id not in remain_payloads:
                    unrelated_payloads.append(recv_payload)
                    continue
                del remain_payloads[recv_payload.req_id]
                recv_payloads[recv_payload.req_id] = recv_payload
                if recv_payload.err is not None and not ignore_err:
                    raise recv_payload.err
                if callback:
                    callback(recv_payload, remain_payloads)
        finally:
            # Give responses that belong to other waiters back to the queue.
            for payload in unrelated_payloads:
                self._recv_queue.put(payload)
        return [recv_payloads[p.req_id] for p in send_payloads]

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            Shut down every child, drain and release the recv queue. Does nothing if not running.
        Arguments:
            - timeout (:obj:`Optional[float]`): Passed to each child's ``shutdown``.
        """
        if self._running:
            for child in self._children:
                child.shutdown(timeout=timeout)
            self._cleanup_queue()
            self._running = False

    def _cleanup_queue(self):
        # Drain until the queue stays empty across a short pause, then close it (multiprocessing queues only).
        while True:
            while not self._recv_queue.empty():
                self._recv_queue.get()
            time.sleep(0.1)
            if self._recv_queue.empty():
                break
        if hasattr(self._recv_queue, "close"):
            self._recv_queue.close()
            self._recv_queue.join_thread()
        self._recv_queue = None

    def __getattr__(self, key: str) -> List[Any]:
        """
        Overview:
            Fetch attribute ``key`` from every child; returns the values ordered by proc id.
        """
        # Dunder/protocol probes (copy, pickle, introspection) must behave like a normal missing attribute
        # instead of being forwarded to the children.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        # Read through __dict__: if __init__ did not finish, ``self._running`` would re-enter __getattr__
        # and recurse forever.
        assert self.__dict__.get("_running", False), "Supervisor is not running, please call start_link first!"
        send_payloads = []
        for i, child in enumerate(self._children):
            payload = SendPayload(proc_id=i, method=ReserveMethod.GETATTR, args=[key])
            send_payloads.append(payload)
            child.send(payload)
        return [payload.data for payload in self.recv_all(send_payloads)]

    def get_child_attr(self, proc_id: int, key: str) -> Any:
        """
        Overview:
            Get attr of one child process instance.
        Arguments:
            - proc_id (:obj:`int`): Proc id (index of the child).
            - key (:obj:`str`): Attribute key.
        Returns:
            - attr (:obj:`Any`): Attribute of child.
        """
        assert self._running, "Supervisor is not running, please call start_link first!"
        payload = SendPayload(proc_id=proc_id, method=ReserveMethod.GETATTR, args=[key])
        self._children[proc_id].send(payload)
        payloads = self.recv_all([payload])
        return payloads[0].data

    def __del__(self) -> None:
        # If __init__ raised part-way, nothing was started; skip cleanup rather than touch missing state.
        if "_running" not in self.__dict__:
            return
        self.shutdown(timeout=5)
        self._children.clear()
|
|