from abc import abstractmethod, ABC
from typing import Any, List, Optional, Union, Callable
import copy
from dataclasses import dataclass
from functools import wraps

from ding.utils import fastcopy


def apply_middleware(func_name: str):
    """
    Overview:
        Decorator factory that routes calls of a buffer method through the buffer's middleware chain.
        The decorated method is called as ``method(buffer, *args, **kwargs)``; each callable registered with \
        ``Buffer.use`` is then invoked, in registration order, as ``middleware(func_name, chain, *args, **kwargs)``.
    Arguments:
        - func_name (:obj:`str`): Passed as the first argument to every middleware, so a middleware can tell \
            which buffer operation it is wrapping.
    Returns:
        - wrap_func (:obj:`Callable`): A decorator for a buffer method.
    """

    def wrap_func(base_func: Callable):

        @wraps(base_func)
        def handler(buffer, *args, **kwargs):
            # The real processing starts here. Middleware is applied one by one; each middleware receives
            # ``chain``, a callable that executes the rest of the chain (the next middleware, or finally the
            # undecorated method). A middleware may change the arguments it forwards to ``chain``, inspect or
            # alter the value ``chain`` returns, or not call ``chain`` at all -- so it can act before, after,
            # or instead of the remaining processing.
            # (Kept as a comment: ``@wraps`` would overwrite a docstring here with ``base_func.__doc__``.)

            def wrap_handler(middleware, *args, **kwargs):
                # End of the chain: run the undecorated buffer method with the buffer bound as first argument.
                if not middleware:
                    return base_func(buffer, *args, **kwargs)

                def chain(*args, **kwargs):
                    # Continue with the middleware after the current one.
                    return wrap_handler(middleware[1:], *args, **kwargs)

                func = middleware[0]
                return func(func_name, chain, *args, **kwargs)

            return wrap_handler(buffer._middleware, *args, **kwargs)

        return handler

    return wrap_func


@dataclass
class BufferedData:
    """
    Overview:
        A record held by a buffer: the payload together with its index and meta information.
    """
    # The stored payload.
    data: Any
    # Index of the record in the buffer; ``update``/``delete`` look records up by this value.
    index: str
    # Meta information, e.g. priority, count, staleness.
    meta: dict


# Register new dispatcher on fastcopy to avoid circular references
def _copy_buffereddata(d: BufferedData) -> BufferedData:
    """
    Overview:
        Copy a ``BufferedData``: ``data`` and ``meta`` are copied with ``fastcopy.copy``, ``index`` is reused.
    """
    return BufferedData(data=fastcopy.copy(d.data), index=d.index, meta=fastcopy.copy(d.meta))


fastcopy.dispatch[BufferedData] = _copy_buffereddata


class Buffer(ABC):
    """
    Buffer is an abstraction of device storage, third-party services or data structures,
    for example a memory queue, sum-tree, redis, or di-store.
    """

    def __init__(self, size: int) -> None:
        """
        Overview:
            Initialize the buffer with an empty middleware chain.
        Arguments:
            - size (:obj:`int`): Size of the buffer (presumably its capacity), stored as ``self.size``.
        """
        # Middleware registered via ``use``, applied in order by methods decorated with ``apply_middleware``.
        self._middleware: List[Callable] = []
        self.size = size

    @abstractmethod
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Push data and its meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The data which will be pushed into buffer.
            - meta (:obj:`Optional[dict]`): Meta information, e.g. priority, count, staleness.
        Returns:
            - buffered_data (:obj:`BufferedData`): The pushed data.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`Optional[int]`): The number of the data that will be sampled.
            - indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
            - replace (:obj:`bool`): If ``replace`` is true, you may receive duplicated data from the buffer.
            - sample_range (:obj:`Optional[slice]`): Sample range slice.
            - ignore_insufficient (:obj:`bool`): If ``ignore_insufficient`` is true, sampling more than the buffer \
                size with no repetition will not cause an exception.
            - groupby (:obj:`Optional[str]`): Groupby key in meta, i.e. groupby="episode"
            - unroll_len (:obj:`Optional[int]`): Number of consecutive frames within a group.
        Returns:
            - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`):
                A list of data with length ``size``, may be nested if groupby is set.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            Update data and meta by index.
        Arguments:
            - index (:obj:`str`): Index of data.
            - data (:obj:`Optional[Any]`): Pure data.
            - meta (:obj:`Optional[dict]`): Meta information.
        Returns:
            - success (:obj:`bool`): Success or not; if no data with the index exists in the buffer, return False.
        """
        raise NotImplementedError

    @abstractmethod
    def delete(self, index: str):
        """
        Overview:
            Delete one data sample by index.
        Arguments:
            - index (:obj:`str`): Index
        """
        raise NotImplementedError

    @abstractmethod
    def save_data(self, file_name: str):
        """
        Overview:
            Save buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): file name of buffer data
        """
        raise NotImplementedError

    @abstractmethod
    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): file name of buffer data
        """
        raise NotImplementedError

    @abstractmethod
    def count(self) -> int:
        """
        Overview:
            Return the number of data items in the buffer; ``len(buffer)`` delegates to this method.
        Returns:
            - count (:obj:`int`): Number of stored items.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Overview:
            Subclasses should remove all data from the buffer.
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            Get item by subscript index; ``buffer[idx]`` delegates to this method.
        Arguments:
            - idx (:obj:`int`): Subscript index
        Returns:
            - buffered_data (:obj:`BufferedData`): Item from buffer
        """
        raise NotImplementedError

    def use(self, func: Callable) -> "Buffer":
        """
        Overview:
            Use algorithm middleware to modify the behavior of the buffer.
            Middleware only affects buffer methods decorated with ``apply_middleware(func_name)``. For such a \
            method, every registered middleware is called, in the order it was registered, as:

                ``func(func_name, chain, *args, **kwargs)``

            1. ``func_name`` is the name given to ``apply_middleware``, identifying which buffer operation \
                is being executed, so the middleware can decide which action to take.
            2. ``chain`` is a callable that runs the remaining middleware and finally the original buffer \
                method; call it with (possibly modified) arguments to continue processing.
            3. ``*args`` / ``**kwargs`` are the arguments the user passed to the buffer method (the buffer \
                instance itself is not passed; a middleware that needs it must keep its own reference).

            Whatever the middleware returns is returned to the previous middleware, or to the user for the \
            first one. Returning without calling ``chain`` stops the chain: no later middleware and not the \
            original method will run during this call.
        Arguments:
            - func (:obj:`Callable`): The middleware handler
        Returns:
            - buffer (:obj:`Buffer`): The instance self, so calls can be chained.
        """
        self._middleware.append(func)
        return self

    def view(self) -> "Buffer":
        r"""
        Overview:
            Create a view of this buffer via ``copy.copy(self)``, which delegates to the subclass's \
            ``__copy__``. The intended contract is a new buffer instance with its own copy of every property \
            except the storage, which is shared among all the views.
        Returns:
            - buffer (:obj:`Buffer`): A new buffer instance (not ``self``).
        Raises:
            - NotImplementedError: If the subclass does not override ``__copy__``.
        """
        return copy.copy(self)

    def __copy__(self) -> "Buffer":
        """
        Overview:
            Hook used by ``view`` through ``copy.copy``; subclasses must override it to support views.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return ``self.count()``."""
        return self.count()

    def __getitem__(self, idx: int) -> BufferedData:
        """Return ``self.get(idx)``."""
        return self.get(idx)