zjowowen's picture
init space
079c32c
raw
history blame
1.83 kB
from typing import Callable, Any, List, TYPE_CHECKING
if TYPE_CHECKING:
from ding.data.buffer.buffer import Buffer
def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable:
"""
Overview:
This middleware aims to check staleness before each sample operation,
staleness = train_iter_sample_data - train_iter_data_collected, means how old/off-policy the data is,
If data's staleness is greater(>) than max_staleness, this data will be removed from buffer as soon as possible.
Arguments:
- max_staleness (:obj:`int`): The maximum legal span between the time of collecting and time of sampling.
"""
def push(next: Callable, data: Any, *args, **kwargs) -> Any:
assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[
'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
return next(data, *args, **kwargs)
def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]:
delete_index = []
for i, item in enumerate(buffer_.storage):
index, meta = item.index, item.meta
staleness = train_iter_sample_data - meta['train_iter_data_collected']
meta['staleness'] = staleness
if staleness > max_staleness:
delete_index.append(index)
for index in delete_index:
buffer_.delete(index)
data = next(*args, **kwargs)
return data
def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any:
if action == "push":
return push(next, *args, **kwargs)
elif action == "sample":
return sample(next, *args, **kwargs)
return next(*args, **kwargs)
return _staleness_check