|
from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING |
|
from collections import defaultdict |
|
from ding.data.buffer import BufferedData |
|
if TYPE_CHECKING: |
|
from ding.data.buffer.buffer import Buffer |
|
|
|
|
|
def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable: |
|
""" |
|
Overview: |
|
This middleware aims to check the usage times of data in buffer. If the usage times of a data is |
|
greater than or equal to max_use, this data will be removed from buffer as soon as possible. |
|
Arguments: |
|
- max_use (:obj:`int`): The max reused (resampled) count for any individual object. |
|
""" |
|
|
|
use_count = defaultdict(int) |
|
|
|
def _need_delete(item: BufferedData) -> bool: |
|
nonlocal use_count |
|
idx = item.index |
|
use_count[idx] += 1 |
|
item.meta['use_count'] = use_count[idx] |
|
if use_count[idx] >= max_use: |
|
return True |
|
else: |
|
return False |
|
|
|
def _check_use_count(sampled_data: List[BufferedData]): |
|
delete_indices = [item.index for item in filter(_need_delete, sampled_data)] |
|
buffer_.delete(delete_indices) |
|
for index in delete_indices: |
|
del use_count[index] |
|
|
|
def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: |
|
sampled_data = chain(*args, **kwargs) |
|
if len(sampled_data) == 0: |
|
return sampled_data |
|
|
|
if isinstance(sampled_data[0], BufferedData): |
|
_check_use_count(sampled_data) |
|
else: |
|
for grouped_data in sampled_data: |
|
_check_use_count(grouped_data) |
|
return sampled_data |
|
|
|
def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any: |
|
if action == "sample": |
|
return sample(chain, *args, **kwargs) |
|
return chain(*args, **kwargs) |
|
|
|
return _use_time_check |
|
|