gomoku / DI-engine /ding /data /buffer /deque_buffer.py
zjowowen's picture
init space
079c32c
raw
history blame
15.7 kB
import os
import itertools
import random
import uuid
from ditk import logging
import hickle
from typing import Any, Iterable, List, Optional, Tuple, Union
from collections import Counter
from collections import defaultdict, deque, OrderedDict
from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data
class BufferIndex():
    """
    Overview:
        Save index string and offset in key value pair. Every key is mapped to an ever-growing
        insertion offset; ``get`` translates that offset into the key's current position inside a
        bounded container of size ``maxlen`` whose oldest entries are dropped first.
    """

    def __init__(self, maxlen: int, *args, **kwargs):
        """
        Overview:
            Create the index, optionally pre-filled with key/offset pairs.
        Arguments:
            - maxlen (:obj:`int`): The maximum number of keys kept at the same time.
            - args, kwargs: Forwarded unchanged to ``OrderedDict`` to build the initial mapping.
        """
        self.maxlen = maxlen
        self.__map = OrderedDict(*args, **kwargs)
        # The most recently inserted key; its offset plus one is handed to the next key.
        self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
        # Number of keys ever added (evicted ones included) since creation or the last clear.
        self._cumlen = len(self.__map)

    def get(self, key: str) -> int:
        """
        Overview:
            Return the current position of ``key``.
        Arguments:
            - key (:obj:`str`): The index string to look up.
        Returns:
            - position (:obj:`int`): The stored offset, shifted back by the number of evicted keys.
        Raises:
            - KeyError: If ``key`` is not stored.
        """
        offset = self.__map[key]
        # Once more than ``maxlen`` keys have been added, ``_cumlen - maxlen`` of them have been
        # evicted from the front, so every surviving offset is shifted back by that amount.
        evicted_shift = min(0, self.maxlen - self._cumlen)
        return offset % self._cumlen + evicted_shift

    def __len__(self) -> int:
        """
        Overview:
            Return the number of keys currently stored.
        """
        return len(self.__map)

    def has(self, key: str) -> bool:
        """
        Overview:
            Tell whether ``key`` is currently stored.
        """
        return key in self.__map

    def append(self, key: str):
        """
        Overview:
            Register ``key`` with the next offset and evict the oldest key when ``maxlen`` is exceeded.
        Arguments:
            - key (:obj:`str`): The index string to add.
        """
        # Compare against None explicitly: a falsy key (e.g. an empty string) is still a valid
        # predecessor and must not reset the offset sequence to 0.
        if self._last_key is not None:
            self.__map[key] = self.__map[self._last_key] + 1
        else:
            self.__map[key] = 0
        self._last_key = key
        self._cumlen += 1
        if len(self) > self.maxlen:
            # Drop the oldest entry (insertion order).
            self.__map.popitem(last=False)

    def clear(self):
        """
        Overview:
            Remove every key and restart the offset sequence from 0.
        """
        self.__map = OrderedDict()
        self._last_key = None
        self._cumlen = 0
class DequeBuffer(Buffer):
    """
    Overview:
        A buffer implementation based on the deque structure. Records live in a bounded ``deque``
        (``storage``), ``indices`` maps every record's index string to its position in ``storage``,
        and ``meta_index`` caches, per meta key used for grouping, a deque of meta values that is
        kept aligned with ``storage``.
    """

    def __init__(self, size: int, sliced: bool = False) -> None:
        """
        Overview:
            The initialization method of DequeBuffer.
        Arguments:
            - size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): The flag whether slice data by unroll_len when sample by group
        """
        super().__init__(size=size)
        self.storage = deque(maxlen=size)
        self.indices = BufferIndex(maxlen=size)
        self.sliced = sliced
        # Meta index is a dict which uses deque as values
        self.meta_index = {}

    @apply_middleware("push")
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            The method that input the objects and the related meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The input object which can be in any format.
            - meta (:obj:`Optional[dict]`): A dict that helps describe data, such as\
                category, label, priority, etc. Default to ``None``.
        Returns:
            - buffered_data (:obj:`BufferedData`): The record that has been stored.
        """
        return self._push(data, meta)

    @apply_middleware("sample")
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            The method that randomly sample data from the buffer or retrieve certain data by indices.
        Arguments:
            - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
                If ``indices`` is not specified, the ``size`` is required to randomly sample the\
                corresponding number of objects from the buffer.
            - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
                Default to ``None``.
            - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
                determines whether the previous samples will be put back into the buffer for subsequent\
                sampling. Default to ``False``, it means that duplicate samples will not appear in one\
                ``sample`` call.
            - sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``,\
                it means no restrictions on the range of indices for the sampling process. The slice is\
                applied with ``itertools.islice``, which does not accept negative values.
            - ignore_insufficient (:obj:`bool`): Whether to only log a warning instead of raising\
                ``ValueError`` if the sampled size is smaller than the required size. Default to ``False``.
            - groupby (:obj:`Optional[str]`): If this parameter is activated, the method will return a\
                target size of object groups.
            - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the\
                ``groupby`` is activated.
        Returns:
            - sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
        Raises:
            - AssertionError: If the combination of ``size``, ``indices``, ``groupby`` and ``unroll_len``\
                is inconsistent.
            - KeyError: If one of ``indices`` is not found in the (range restricted) storage.
            - ValueError: If fewer records/groups than ``size`` are available and\
                ``ignore_insufficient`` is ``False``.
        """
        storage = self.storage
        if sample_range:
            storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))

        # Size and indices
        assert size or indices, "One of size and indices must not be empty."
        if (size and indices) and (size != len(indices)):
            raise AssertionError("Size and indices length must be equal.")
        if not size:
            size = len(indices)
        # Indices and groupby
        assert not (indices and groupby), "Cannot use groupby and indicex at the same time."
        # Groupby and unroll_len
        assert not unroll_len or (
            unroll_len and groupby
        ), "Parameter unroll_len needs to be used in conjunction with groupby."

        value_error = None
        sampled_data = []
        if indices:
            wanted = set(indices)
            found = {item.index: item for item in storage if item.index in wanted}
            # Re-sample and return in indices order
            sampled_data = [found[index] for index in indices]
        elif groupby:
            sampled_data = self._sample_by_group(
                size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
            )
        else:
            if replace:
                sampled_data = random.choices(storage, k=size)
            else:
                try:
                    sampled_data = random.sample(storage, k=size)
                except ValueError as e:
                    # Not enough records; handled together with the size check below.
                    value_error = e

        if value_error or len(sampled_data) != size:
            if ignore_insufficient:
                logging.warning(
                    "Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
                    format(self.count(), size)
                )
            else:
                raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))

        sampled_data = self._independence(sampled_data)

        return sampled_data

    @apply_middleware("update")
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            the method that update data and the related meta information with a certain index.
        Arguments:
            - index (:obj:`str`): The index string of the record to update.
            - data (:obj:`Any`): The data which is supposed to replace the old one. If you set it\
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new dict which replaces the old meta of the record.\
                If you set it to ``None``, the old meta is kept.
        Returns:
            - success (:obj:`bool`): ``False`` if ``index`` is unknown to the buffer, otherwise ``True``.
        """
        if not self.indices.has(index):
            return False
        i = self.indices.get(index)
        item = self.storage[i]
        if data is not None:
            item.data = data
        if meta is not None:
            item.meta = meta
            # Keep the cached group values in sync with the new meta.
            for key in self.meta_index:
                self.meta_index[key][i] = meta[key] if key in meta else None
        return True

    @apply_middleware("delete")
    def delete(self, indices: Union[str, Iterable[str]]) -> None:
        """
        Overview:
            The method that delete the data and related meta information by specific indices.
            Unknown indices are ignored.
        Arguments:
            - indices (Union[str, Iterable[str]]): Where the data to be cleared in the buffer.
        """
        if isinstance(indices, str):
            indices = [indices]
        # A set drops duplicated indices, so that no position is deleted twice.
        positions = {self.indices.get(index) for index in indices if self.indices.has(index)}
        if not positions:
            return
        # Delete from the back so that the positions still to be processed stay valid.
        for position in sorted(positions, reverse=True):
            del self.storage[position]
            # The cached meta values must stay aligned with storage.
            for meta_values in self.meta_index.values():
                if position < len(meta_values):
                    del meta_values[position]
        # Rebuild the index over every remaining record (not over the number of requested indices).
        key_value_pairs = [(item.index, offset) for offset, item in enumerate(self.storage)]
        self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)

    def save_data(self, file_name: str):
        """
        Overview:
            Dump ``storage``, ``indices`` and ``meta_index`` into ``file_name`` with hickle.
        Arguments:
            - file_name (:obj:`str`): The path of the target file.
        """
        directory = os.path.dirname(file_name)
        if directory != "":
            # If the folder for the specified file does not exist, it will be created.
            os.makedirs(directory, exist_ok=True)
        hickle.dump(
            py_obj=(
                self.storage,
                self.indices,
                self.meta_index,
            ), file_obj=file_name
        )

    def load_data(self, file_name: str):
        """
        Overview:
            Replace ``storage``, ``indices`` and ``meta_index`` with the content of ``file_name``.
        Arguments:
            - file_name (:obj:`str`): The path of a file written by ``save_data``.
        """
        # NOTE(review): the file content replaces the buffer state as is; only load trusted files.
        self.storage, self.indices, self.meta_index = hickle.load(file_name)

    def count(self) -> int:
        """
        Overview:
            The method that returns the current length of the buffer.
        """
        return len(self.storage)

    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            The method that returns the BufferedData object given a specific index.
        Arguments:
            - idx (:obj:`int`): The position of the record inside the storage.
        """
        return self.storage[idx]

    @apply_middleware("clear")
    def clear(self) -> None:
        """
        Overview:
            The method that clear all data, indices, and the meta information in the buffer.
        """
        self.storage.clear()
        self.indices.clear()
        self.meta_index = {}

    def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Wrap ``data`` and ``meta`` into a ``BufferedData`` with a freshly generated index and
            append it to the storage, the index map and every cached meta index.
        """
        index = uuid.uuid1().hex
        if meta is None:
            meta = {}
        buffered = BufferedData(data=data, index=index, meta=meta)
        self.storage.append(buffered)
        self.indices.append(index)
        # Add meta index
        for key in self.meta_index:
            self.meta_index[key].append(meta[key] if key in meta else None)

        return buffered

    def _independence(
        self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Make sure that each record is different from each other, but remember that this function
            is different from clone_object. You may change the data in the buffer by modifying a record.
            The first occurrence of a record is kept as is, every further occurrence is replaced
            in place by a copy.
        Arguments:
            - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data,
                can be nested if groupby has been set.
        Raises:
            - Exception: If an element is neither a list nor a ``BufferedData``.
        """
        if len(buffered_samples) == 0:
            return buffered_samples
        occurred = defaultdict(int)

        for i, buffered in enumerate(buffered_samples):
            if isinstance(buffered, list):
                sampled_list = buffered
                # Loop over nested samples
                for j, nested in enumerate(sampled_list):
                    occurred[nested.index] += 1
                    if occurred[nested.index] > 1:
                        sampled_list[j] = fastcopy.copy(nested)
            elif isinstance(buffered, BufferedData):
                occurred[buffered.index] += 1
                if occurred[buffered.index] > 1:
                    buffered_samples[i] = fastcopy.copy(buffered)
            else:
                raise Exception("Get unexpected buffered type {}".format(type(buffered)))
        return buffered_samples

    def _sample_by_group(
            self,
            size: int,
            groupby: str,
            replace: bool = False,
            unroll_len: Optional[int] = None,
            storage: deque = None,
            sliced: bool = False
    ) -> List[List[BufferedData]]:
        """
        Overview:
            Sampling by `group` instead of records, the result will be a collection
            of lists with a length of `size`, but the length of each list may be different from other lists.
        Arguments:
            - size (:obj:`int`): The number of groups to sample.
            - groupby (:obj:`str`): The meta key whose value defines the group of a record.
            - replace (:obj:`bool`): Whether the same group may be sampled more than once.
            - unroll_len (:obj:`Optional[int]`): If set, every group is cut to a window of this length.
            - storage (:obj:`deque`): The records to collect from, default to the whole buffer.
            - sliced (:obj:`bool`): Whether windows are aligned to multiples of ``unroll_len``.
        Returns:
            - sampled_data (:obj:`List[List[BufferedData]]`): One list of records per sampled group,\
                or an empty list if no group is available.
        Raises:
            - ValueError: If ``replace`` is ``False`` and fewer than ``size`` groups are available.
        """
        if storage is None:
            storage = self.storage
        if groupby not in self.meta_index:
            self._create_index(groupby)

        # NOTE(review): the group names are taken from the meta index of the whole buffer, even
        # when ``storage`` is a restricted range, so a sampled group may be empty in that case.
        def filter_by_unroll_len():
            "Filter groups by unroll len, ensure count of items in each group is at least unroll_len."
            group_count = Counter(self.meta_index[groupby])
            return [key for key, count in group_count.items() if count >= unroll_len]

        if unroll_len and unroll_len > 1:
            group_names = filter_by_unroll_len()
            if len(group_names) == 0:
                return []
        else:
            group_names = list(set(self.meta_index[groupby]))

        sampled_groups = []
        if replace:
            if len(group_names) == 0:
                # random.choices cannot draw from an empty population.
                return []
            sampled_groups = random.choices(group_names, k=size)
        else:
            try:
                sampled_groups = random.sample(group_names, k=size)
            except ValueError:
                raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names)))

        # Build dict like {"group name": [records]}
        wanted_groups = set(sampled_groups)
        sampled_data = defaultdict(list)
        for buffered in storage:
            meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
            if meta_value in wanted_groups:
                # Use meta_value as the key: records without the groupby key belong to group None.
                sampled_data[meta_value].append(buffered)

        final_sampled_data = []
        for group in sampled_groups:
            seq_data = sampled_data[group]
            # Filter records by unroll_len
            if unroll_len:
                # Slice by unroll_len. Without slicing it is more likely to obtain duplicate data,
                # and the training will easily crash.
                if sliced:
                    start_indice = random.choice(range(max(1, len(seq_data))))
                    start_indice = start_indice // unroll_len
                    if start_indice == (len(seq_data) - 1) // unroll_len:
                        # The last chunk may be incomplete, take the tail of the group instead.
                        seq_data = seq_data[-unroll_len:]
                    else:
                        seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
                else:
                    # Valid window starts are 0 .. len - unroll_len (inclusive).
                    start_indice = random.choice(range(max(1, len(seq_data) - unroll_len + 1)))
                    seq_data = seq_data[start_indice:start_indice + unroll_len]
            else:
                # A group sampled more than once must not share one list object between results.
                seq_data = list(seq_data)

            final_sampled_data.append(seq_data)

        return final_sampled_data

    def _create_index(self, meta_key: str):
        """
        Overview:
            Build the cached deque of ``meta_key`` values for all records currently in storage,
            ``None`` stands for records without this key.
        """
        self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
        for data in self.storage:
            self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)

    def __iter__(self) -> Iterable[BufferedData]:
        """
        Overview:
            Iterate over the stored records in storage order.
        """
        return iter(self.storage)

    def __copy__(self) -> "DequeBuffer":
        """
        Overview:
            Create a new buffer object which shares storage, indices and meta index with this one.
        """
        buffer = type(self)(size=self.storage.maxlen)
        # Keep the group sampling behaviour of the source buffer.
        buffer.sliced = self.sliced
        buffer.storage = self.storage
        buffer.meta_index = self.meta_index
        buffer.indices = self.indices
        return buffer