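# NOTE: the lines below this header were originally scraped from a file-viewer page, not source code:
# uploaded by zjowowen; commit 079c32c ("init space"); views: raw / history / blame; file size 38.2 kB.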
import os
import copy
import time
from typing import Union, Any, Optional, List, Dict, Tuple
import numpy as np
import hickle
from ding.worker.replay_buffer import IBuffer
from ding.utils import SumSegmentTree, MinSegmentTree, BUFFER_REGISTRY
from ding.utils import LockContext, LockContextType, build_logger, get_rank
from ding.utils.autolog import TickTime
from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController
def to_positive_index(idx: Optional[int], size: int) -> Optional[int]:
    """
    Overview:
        Convert a possibly negative, Python-style index into its non-negative equivalent.
    Arguments:
        - idx (:obj:`Optional[int]`): Index to convert. ``None`` (e.g. an omitted slice bound) is passed through.
        - size (:obj:`int`): Length of the sequence the index refers to.
    Returns:
        - index (:obj:`Optional[int]`): ``idx`` itself if it is ``None`` or non-negative, otherwise ``size + idx``.
          No range check is done, so out-of-range inputs give out-of-range outputs.
    """
    if idx is None or idx >= 0:
        return idx
    return size + idx
@BUFFER_REGISTRY.register('advanced')
class AdvancedReplayBuffer(IBuffer):
    r"""
    Overview:
        Prioritized replay buffer derived from ``NaiveReplayBuffer``.
        This replay buffer adds:

            1) Prioritized experience replay implemented by segment tree.
            2) Data quality monitor. Monitor use count and staleness of each data.
            3) Throughput monitor and control.
            4) Logger. Log 2) and 3) in tensorboard or text.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        beta, replay_buffer_size, push_count
    """

    config = dict(
        type='advanced',
        # Max length of the buffer.
        replay_buffer_size=4096,
        # Max use times of one data in the buffer. Data will be removed once used for too many times.
        max_use=float("inf"),
        # Max staleness of one data in the buffer, measured in learner iterations
        # (``cur_learner_iter - collect_iter``). Data will be removed if it is too stale.
        max_staleness=float("inf"),
        # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
        alpha=0.6,
        # (Float type) How much correction is used: 0 means no correction while 1 means full correction
        beta=0.4,
        # Anneal step for beta: 0 means no annealing
        anneal_step=int(1e5),
        # Whether to track the used data. Used data means they are removed out of buffer and would never be used again.
        enable_track_used_data=False,
        # Whether to deepcopy data when willing to insert and sample data. For security purpose.
        deepcopy=False,
        thruput_controller=dict(
            # Rate limit. The ratio of "Sample Count" to "Push Count" should be in [min, max] range.
            # If greater than max ratio, return `None` when calling ``sample``;
            # If smaller than min ratio, throw away the new data when calling ``push``.
            push_sample_rate_limit=dict(
                max=float("inf"),
                min=0,
            ),
            # Controller will take how many seconds into account, i.e. For the past `window_seconds` seconds,
            # sample_push_rate will be calculated and compared with `push_sample_rate_limit`.
            window_seconds=30,
            # The minimum ratio that buffer must satisfy before anything can be sampled.
            # The ratio is calculated by "Valid Count" divided by "Batch Size".
            # E.g. sample_min_limit_ratio = 2.0, valid_count = 50, batch_size = 32, it is forbidden to sample.
            sample_min_limit_ratio=1,
        ),
        # Monitor configuration for monitor and logger to use. This part does not affect buffer's function.
        monitor=dict(
            sampled_data_attr=dict(
                # Past data will be used for moving average.
                average_range=5,
                # Print data attributes every `print_freq` samples.
                print_freq=200,  # times
            ),
            periodic_thruput=dict(
                # Every `seconds` seconds, thruput(push/sample/remove count) will be printed.
                seconds=60,
            ),
        ),
    )

    def __init__(
            self,
            cfg: dict,
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'buffer',
    ) -> None:
        """
        Overview:
            Initialize the buffer: storage, priority segment trees, throughput controller, monitors and loggers.
        Arguments:
            - cfg (:obj:`dict`): Config dict (attribute access is used, e.g. ``cfg.replay_buffer_size``).
            - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
            - exp_name (:obj:`Optional[str]`): Name of this experiment.
            - instance_name (:obj:`Optional[str]`): Name of this instance.
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._end_flag = False
        self._cfg = cfg
        self._rank = get_rank()
        self._replay_buffer_size = self._cfg.replay_buffer_size
        self._deepcopy = self._cfg.deepcopy
        # ``_data`` is a circular queue to store data (full data or meta data)
        self._data = [None for _ in range(self._replay_buffer_size)]
        # Current valid data count, indicating how many elements in ``self._data`` are valid.
        self._valid_count = 0
        # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
        self._push_count = 0
        # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position.
        self._tail = 0
        # Is used to generate a unique id for each data: If a new data is inserted, its unique id will be this.
        self._next_unique_id = 0
        # Lock to guarantee thread safety.
        self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
        # Point to the head of the circular queue. The true data is the stalest(oldest) data in this queue.
        # Because buffer would remove data due to staleness or use count, and at the beginning when queue is not
        # filled with data head would always be 0, so ``head`` may be not equal to ``tail``;
        # Otherwise, they two should be the same. Head is used to optimize staleness check in ``_sample_check``.
        self._head = 0
        # use_count is {position_idx: use_count}
        self._use_count = {idx: 0 for idx in range(self._cfg.replay_buffer_size)}
        # Max priority till now. Is used to initialize a data's priority if "priority" is not passed in with the data.
        self._max_priority = 1.0
        # A small positive number to avoid edge-case, e.g. "priority" == 0.
        self._eps = 1e-5
        # Data check function list, used in ``_append`` and ``_extend``. This buffer requires data to be dict.
        self.check_list = [lambda x: isinstance(x, dict)]

        self._max_use = self._cfg.max_use
        self._max_staleness = self._cfg.max_staleness
        self.alpha = self._cfg.alpha
        assert 0 <= self.alpha <= 1, self.alpha
        self._beta = self._cfg.beta
        assert 0 <= self._beta <= 1, self._beta
        self._anneal_step = self._cfg.anneal_step
        if self._anneal_step != 0:
            # Linear annealing: beta reaches 1 after ``anneal_step`` calls to ``_sample_with_indices``.
            self._beta_anneal_step = (1 - self._beta) / self._anneal_step

        # Prioritized sample.
        # Capacity needs to be the power of 2.
        capacity = int(np.power(2, np.ceil(np.log2(self.replay_buffer_size))))
        # Sum segtree and min segtree are used to sample data according to priority.
        self._sum_tree = SumSegmentTree(capacity)
        self._min_tree = MinSegmentTree(capacity)

        # Thruput controller
        push_sample_rate_limit = self._cfg.thruput_controller.push_sample_rate_limit
        self._always_can_push = True if push_sample_rate_limit['max'] == float('inf') else False
        self._always_can_sample = True if push_sample_rate_limit['min'] == 0 else False
        self._use_thruput_controller = not self._always_can_push or not self._always_can_sample
        if self._use_thruput_controller:
            self._thruput_controller = ThruputController(self._cfg.thruput_controller)
        self._sample_min_limit_ratio = self._cfg.thruput_controller.sample_min_limit_ratio
        assert self._sample_min_limit_ratio >= 1

        # Monitor & Logger
        monitor_cfg = self._cfg.monitor
        if self._rank == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name),
                    self._instance_name,
                )
        else:
            # Non-zero ranks only keep a text logger; tensorboard output is rank 0's job.
            self._logger, _ = build_logger(
                './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
            )
            self._tb_logger = None
        self._start_time = time.time()
        # Sampled data attributes.
        self._cur_learner_iter = -1
        self._cur_collector_envstep = -1
        self._sampled_data_attr_print_count = 0
        self._sampled_data_attr_monitor = SampledDataAttrMonitor(
            TickTime(), expire=monitor_cfg.sampled_data_attr.average_range
        )
        self._sampled_data_attr_print_freq = monitor_cfg.sampled_data_attr.print_freq
        # Periodic thruput.
        if self._rank == 0:
            self._periodic_thruput_monitor = PeriodicThruputMonitor(
                self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger
            )

        # Used data remover
        self._enable_track_used_data = self._cfg.enable_track_used_data
        if self._enable_track_used_data:
            self._used_data_remover = UsedDataRemover()

    def start(self) -> None:
        """
        Overview:
            Start the buffer's used_data_remover thread if enables track_used_data.
        """
        if self._enable_track_used_data:
            self._used_data_remover.start()

    def close(self) -> None:
        """
        Overview:
            Clear the buffer; Join the buffer's used_data_remover thread if enables track_used_data.
            Join periodic throughput monitor, flush and close the tensorboard logger (rank 0 only).
            Calling it more than once is a no-op.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self.clear()
        if self._rank == 0:
            self._periodic_thruput_monitor.close()
            self._tb_logger.flush()
            self._tb_logger.close()
        if self._enable_track_used_data:
            self._used_data_remover.close()

    def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]:
        """
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
            - sample_range (:obj:`slice`): Slice over physical buffer positions to sample from, such as \
                `slice(-10, None)`, which restricts sampling to the last 10 slots of the underlying storage.
        Returns:
            - sample_data (:obj:`list`): A list of data with length ``size``, ``[]`` if ``size == 0``, or ``None`` \
                if sampling is refused (not enough valid data, or throughput limit reached).
        ReturnsKeys:
            - necessary: original keys(e.g. `obs`, `action`, `next_obs`, `reward`, `info`), \
                `replay_unique_id`, `replay_buffer_idx`
            - optional(if use priority): `IS`, `priority`
        """
        if size == 0:
            return []
        can_sample_staleness, staleness_info = self._sample_check(size, cur_learner_iter)
        if self._always_can_sample:
            can_sample_thruput, thruput_info = True, "Always can sample because push_sample_rate_limit['min'] == 0"
        else:
            can_sample_thruput, thruput_info = self._thruput_controller.can_sample(size)
        if not can_sample_staleness or not can_sample_thruput:
            self._logger.info(
                'Refuse to sample due to -- \nstaleness: {}, {} \nthruput: {}, {}'.format(
                    not can_sample_staleness, staleness_info, not can_sample_thruput, thruput_info
                )
            )
            return None
        with self._lock:
            indices = self._get_indices(size, sample_range)
            result = self._sample_with_indices(indices, cur_learner_iter)
            # Without deepcopy, every occurrence of a repeated index refers to the very same dict stored in the
            # buffer. Give each repeated occurrence (all but the first) its own deep copy so the sampled items
            # are independent objects. One pass with a "seen" set replaces the former O(size^2) scan.
            if not self._deepcopy and len(indices) != len(set(indices)):
                seen = set()
                for j, index in enumerate(indices):
                    if index in seen:
                        result[j] = copy.deepcopy(result[j])
                    else:
                        seen.add(index)
            self._monitor_update_of_sample(result, cur_learner_iter)
            return result

    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        r"""
        Overview:
            Push a data into buffer. The push is silently dropped (only logged) if the throughput controller refuses it.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
                (in `Any` type), or many(in `List[Any]` type).
            - cur_collector_envstep (:obj:`int`): Collector's current env step.
        """
        push_size = len(data) if isinstance(data, list) else 1
        if self._always_can_push:
            can_push, push_info = True, "Always can push because push_sample_rate_limit['max'] == float('inf')"
        else:
            can_push, push_info = self._thruput_controller.can_push(push_size)
        if not can_push:
            self._logger.info('Refuse to push because {}'.format(push_info))
            return
        if isinstance(data, list):
            self._extend(data, cur_collector_envstep)
        else:
            self._append(data, cur_collector_envstep)

    def save_data(self, file_name: str) -> None:
        """
        Overview:
            Dump the raw storage list ``self._data`` (including empty ``None`` slots) to ``file_name`` with hickle.
        Arguments:
            - file_name (:obj:`str`): Destination path. Missing parent directories are created.
        """
        dir_name = os.path.dirname(file_name)
        if dir_name:
            # ``exist_ok`` avoids the check-then-create race of testing ``os.path.exists`` first.
            os.makedirs(dir_name, exist_ok=True)
        hickle.dump(py_obj=self._data, file_obj=file_name)

    def load_data(self, file_name: str) -> None:
        """
        Overview:
            Load data previously written by ``save_data`` and push it with ``cur_collector_envstep=0``.
            Loading goes through ``push``, so it is subject to the throughput controller, and items that fail
            ``_data_check`` (e.g. the ``None`` slots) are dropped. Presumably ``hickle.load`` returns the saved
            list here — verify if the file may come from elsewhere.
        Arguments:
            - file_name (:obj:`str`): Path of the file to load.
        """
        self.push(hickle.load(file_name), 0)

    def _sample_check(self, size: int, cur_learner_iter: int) -> Tuple[bool, str]:
        r"""
        Overview:
            Do preparations for sampling and check whether data is enough for sampling.
            Preparation includes removing stale data in ``self._data``.
            Check includes judging whether ``valid_count / size`` reaches ``sample_min_limit_ratio``.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled. Must be positive.
            - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
        Returns:
            - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
            - str_info (:obj:`str`): Str type info, explaining why cannot sample. (If can sample, return "Can sample")

        .. note::
            This function must be called before data sample.
        """
        staleness_remove_count = 0
        with self._lock:
            if self._max_staleness != float("inf"):
                p = self._head
                while True:
                    if self._data[p] is not None:
                        staleness = self._calculate_staleness(p, cur_learner_iter)
                        if staleness >= self._max_staleness:
                            self._remove(p)
                            staleness_remove_count += 1
                        else:
                            # Since the circular queue ``self._data`` guarantees that data's staleness is decreasing
                            # from index self._head to index self._tail - 1, we can jump out of the loop as soon as
                            # meeting a fresh enough data
                            break
                    p = (p + 1) % self._replay_buffer_size
                    if p == self._tail:
                        # Traverse a circle and go back to the tail, which means can stop staleness checking now
                        break
            str_info = "Remove {} elements due to staleness. ".format(staleness_remove_count)
            if self._valid_count / size < self._sample_min_limit_ratio:
                str_info += "Not enough for sampling. valid({}) / sample({}) < sample_min_limit_ratio({})".format(
                    self._valid_count, size, self._sample_min_limit_ratio
                )
                return False, str_info
            else:
                str_info += "Can sample."
                return True, str_info

    def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Append a data item into queue.
            Add two keys in data:

                - replay_unique_id: The data item's unique id, using ``generate_id`` to generate it.
                - replay_buffer_idx: The data item's position index in the queue, this position may already have an \
                    old element, then it would be replaced by this new input one. using ``self._tail`` to locate.
        Arguments:
            - ori_data (:obj:`Any`): The data which will be inserted.
            - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard.
        """
        with self._lock:
            if self._deepcopy:
                data = copy.deepcopy(ori_data)
            else:
                data = ori_data
            # Explicit check instead of ``assert``: asserts are stripped under ``python -O``, which would let
            # illegal data through and crash on the key assignments below.
            if not self._data_check(data):
                # If data check fails, log it and return without any operations.
                self._logger.info('Illegal data type [{}], reject it...'.format(type(data)))
                return
            self._push_count += 1
            # remove->set weight->set data
            if self._data[self._tail] is not None:
                # Overwriting the oldest slot: the queue head moves to the next position.
                self._head = (self._tail + 1) % self._replay_buffer_size
            self._remove(self._tail)
            data['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id)
            data['replay_buffer_idx'] = self._tail
            self._set_weight(data)
            self._data[self._tail] = data
            self._valid_count += 1
            if self._rank == 0:
                self._periodic_thruput_monitor.valid_count = self._valid_count
            self._tail = (self._tail + 1) % self._replay_buffer_size
            self._next_unique_id += 1
            self._monitor_update_of_push(1, cur_collector_envstep)

    def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Extend a data list into queue. Items failing ``_data_check`` are dropped.
            Add two keys in each data item, you can refer to ``_append`` for more details.
        Arguments:
            - ori_data (:obj:`List[Any]`): The data list.
            - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard.
        """
        with self._lock:
            if self._deepcopy:
                data = copy.deepcopy(ori_data)
            else:
                data = ori_data
            check_result = [self._data_check(d) for d in data]
            # Only keep data items that pass ``_data_check``.
            valid_data = [d for d, flag in zip(data, check_result) if flag]
            length = len(valid_data)
            # When updating ``_data`` and ``_use_count``, should consider two cases regarding
            # the relationship between "tail + data length" and "queue max length" to check whether
            # data will exceed beyond queue's max length limitation.
            if self._tail + length <= self._replay_buffer_size:
                for j in range(self._tail, self._tail + length):
                    if self._data[j] is not None:
                        self._head = (j + 1) % self._replay_buffer_size
                    self._remove(j)
                for i in range(length):
                    valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i)
                    valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size
                    self._set_weight(valid_data[i])
                    self._push_count += 1
                self._data[self._tail:self._tail + length] = valid_data
            else:
                # Wrap around the end of the circular queue, writing chunk by chunk.
                data_start = self._tail
                valid_data_start = 0
                residual_num = len(valid_data)
                while True:
                    space = self._replay_buffer_size - data_start
                    L = min(space, residual_num)
                    for j in range(data_start, data_start + L):
                        if self._data[j] is not None:
                            self._head = (j + 1) % self._replay_buffer_size
                        self._remove(j)
                    for i in range(valid_data_start, valid_data_start + L):
                        valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i)
                        valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size
                        self._set_weight(valid_data[i])
                        self._push_count += 1
                    self._data[data_start:data_start + L] = valid_data[valid_data_start:valid_data_start + L]
                    residual_num -= L
                    if residual_num <= 0:
                        break
                    else:
                        data_start = 0
                        valid_data_start += L
            self._valid_count += len(valid_data)
            if self._rank == 0:
                self._periodic_thruput_monitor.valid_count = self._valid_count
            # Update ``tail`` and ``next_unique_id`` after the whole list is pushed into buffer.
            self._tail = (self._tail + length) % self._replay_buffer_size
            self._next_unique_id += length
            self._monitor_update_of_push(length, cur_collector_envstep)

    def update(self, info: dict) -> None:
        r"""
        Overview:
            Update data priorities. Use `replay_buffer_idx` to locate, and use `replay_unique_id` to verify.
            Entries whose slot is empty or now holds a different item are skipped (debug-logged).
        Arguments:
            - info (:obj:`dict`): Info dict containing all necessary keys for priority update.
        ArgumentsKeys:
            - necessary: `replay_unique_id`, `replay_buffer_idx`, `priority`. All values are lists with the same length.
        """
        with self._lock:
            if 'priority' not in info:
                return
            data = [info['replay_unique_id'], info['replay_buffer_idx'], info['priority']]
            for id_, idx, priority in zip(*data):
                # Only if the data still exists in the queue, will the update operation be done.
                if self._data[idx] is not None \
                        and self._data[idx]['replay_unique_id'] == id_:  # Verify the same transition(data)
                    assert priority >= 0, priority
                    assert self._data[idx]['replay_buffer_idx'] == idx
                    self._data[idx]['priority'] = priority + self._eps  # Add epsilon to avoid priority == 0
                    self._set_weight(self._data[idx])
                    # Update max priority
                    self._max_priority = max(self._max_priority, priority)
                else:
                    # Log the id actually stored at ``idx`` next to the id from the update info
                    # (previously the incoming id and the priority were logged in these fields).
                    id_in_buffer = None if self._data[idx] is None else self._data[idx]['replay_unique_id']
                    self._logger.debug(
                        '[Skip Update]: buffer_idx: {}; id_in_buffer: {}; id_in_update_info: {}'.format(
                            idx, id_in_buffer, id_
                        )
                    )

    def clear(self) -> None:
        """
        Overview:
            Clear all the data and reset the related variables (head, tail, max priority).
            ``push_count`` and the unique-id counter are not reset.
        """
        with self._lock:
            for i in range(len(self._data)):
                self._remove(i)
            assert self._valid_count == 0, self._valid_count
            self._head = 0
            self._tail = 0
            self._max_priority = 1.0

    def __del__(self) -> None:
        """
        Overview:
            Call ``close`` to delete the object.
        """
        if not self._end_flag:
            self.close()

    def _set_weight(self, data: Dict) -> None:
        r"""
        Overview:
            Set sumtree and mintree's weight of the input data according to its priority.
            If input data does not have key "priority", it would set to ``self._max_priority`` instead.
        Arguments:
            - data (:obj:`Dict`): The data whose priority(weight) in segment tree should be set/updated.
        """
        if 'priority' not in data.keys() or data['priority'] is None:
            data['priority'] = self._max_priority
        weight = data['priority'] ** self.alpha
        idx = data['replay_buffer_idx']
        self._sum_tree[idx] = weight
        self._min_tree[idx] = weight

    def _data_check(self, d: Any) -> bool:
        r"""
        Overview:
            Data legality check, using rules(functions) in ``self.check_list``.
        Arguments:
            - d (:obj:`Any`): The data which needs to be checked.
        Returns:
            - result (:obj:`bool`): Whether the data passes the check.
        """
        # only the data passes all the check functions, would the check return True
        return all([fn(d) for fn in self.check_list])

    def _get_indices(self, size: int, sample_range: slice = None) -> list:
        r"""
        Overview:
            Get the sample index list according to the priority probability (stratified sampling over the
            sum tree's prefix sums).
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled
            - sample_range (:obj:`slice`): Optional slice of physical buffer positions to restrict sampling to. \
                Negative bounds count from ``replay_buffer_size``; an omitted bound means the buffer edge.
        Returns:
            - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
        """
        # Divide [0, 1) into size intervals on average
        intervals = np.array([i * 1.0 / size for i in range(size)])
        # Uniformly sample within each interval
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        if sample_range is None:
            # Rescale to [0, S), where S is the sum of all data's priority (root value of sum tree)
            mass *= self._sum_tree.reduce()
        else:
            # Rescale to [a, b), the prefix-sum interval covering positions [start, end)
            start = to_positive_index(sample_range.start, self._replay_buffer_size)
            end = to_positive_index(sample_range.stop, self._replay_buffer_size)
            # An omitted (``None``) or zero start means "from position 0", i.e. a lower bound of 0. Passing
            # ``None`` on to ``reduce`` would instead make ``a`` the total sum and invert the range.
            a = self._sum_tree.reduce(0, start) if start else 0.0
            b = self._sum_tree.reduce(0, end)
            mass = mass * (b - a) + a
        # Find prefix sum index to sample with probability
        return [self._sum_tree.find_prefixsum_idx(m) for m in mass]

    def _remove(self, idx: int, use_too_many_times: bool = False) -> None:
        r"""
        Overview:
            Remove a data(set the element in the list to ``None``) and update corresponding variables,
            e.g. sum_tree, min_tree, valid_count.
        Arguments:
            - idx (:obj:`int`): Data at this position will be removed.
            - use_too_many_times (:obj:`bool`): Whether the removal is caused by exceeding ``max_use``. When used \
                data is being tracked, such data is kept in place but given zero sampling weight instead.
        """
        if use_too_many_times:
            if self._enable_track_used_data:
                # Must track this data, but in parallel mode.
                # Do not remove it, but make sure it will not be sampled.
                self._data[idx]['priority'] = 0
                self._sum_tree[idx] = self._sum_tree.neutral_element
                self._min_tree[idx] = self._min_tree.neutral_element
                return
            elif idx == self._head:
                # Correct `self._head` when the queue head is removed due to use_count
                self._head = (self._head + 1) % self._replay_buffer_size
        if self._data[idx] is not None:
            if self._enable_track_used_data:
                self._used_data_remover.add_used_data(self._data[idx])
            self._valid_count -= 1
            if self._rank == 0:
                self._periodic_thruput_monitor.valid_count = self._valid_count
                self._periodic_thruput_monitor.remove_data_count += 1
            self._data[idx] = None
            self._sum_tree[idx] = self._sum_tree.neutral_element
            self._min_tree[idx] = self._min_tree.neutral_element
            self._use_count[idx] = 0

    def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
        r"""
        Overview:
            Sample data with ``indices``; Remove a data item if it is used for too many times.
            Each sampled item gets ``staleness``, ``use`` and ``IS`` (normalized importance-sampling weight) keys.
            Also anneals ``beta`` once per call when annealing is enabled.
        Arguments:
            - indices (:obj:`List[int]`): A list including all the sample indices.
            - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
        Returns:
            - data (:obj:`list`) Sampled data.
        """
        # Calculate max weight for normalizing IS
        sum_tree_root = self._sum_tree.reduce()
        p_min = self._min_tree.reduce() / sum_tree_root
        max_weight = (self._valid_count * p_min) ** (-self._beta)
        data = []
        for idx in indices:
            assert self._data[idx] is not None
            assert self._data[idx]['replay_buffer_idx'] == idx, (self._data[idx]['replay_buffer_idx'], idx)
            if self._deepcopy:
                copy_data = copy.deepcopy(self._data[idx])
            else:
                copy_data = self._data[idx]
            # Store staleness, use and IS(importance sampling weight for gradient step) for monitor and outer use
            self._use_count[idx] += 1
            copy_data['staleness'] = self._calculate_staleness(idx, cur_learner_iter)
            copy_data['use'] = self._use_count[idx]
            p_sample = self._sum_tree[idx] / sum_tree_root
            weight = (self._valid_count * p_sample) ** (-self._beta)
            copy_data['IS'] = weight / max_weight
            data.append(copy_data)
        if self._max_use != float("inf"):
            # Remove data whose "use count" reaches ``max_use``
            for idx in indices:
                if self._use_count[idx] >= self._max_use:
                    self._remove(idx, use_too_many_times=True)
        # Beta annealing
        if self._anneal_step != 0:
            self._beta = min(1.0, self._beta + self._beta_anneal_step)
        return data

    def _monitor_update_of_push(self, add_count: int, cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Update values in monitor, then update text logger and tensorboard logger.
            Called in ``_append`` and ``_extend``.
        Arguments:
            - add_count (:obj:`int`): How many data are added into buffer.
            - cur_collector_envstep (:obj:`int`): Collector envstep, passed in by collector.
        """
        if self._rank == 0:
            self._periodic_thruput_monitor.push_data_count += add_count
        if self._use_thruput_controller:
            self._thruput_controller.history_push_count += add_count
        self._cur_collector_envstep = cur_collector_envstep

    def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) -> None:
        r"""
        Overview:
            Update values in monitor, then update text logger and tensorboard logger.
            Called in ``sample``.
        Arguments:
            - sample_data (:obj:`list`): Sampled data. Used to get sample length and data's attributes, \
                e.g. use, priority, staleness, etc.
            - cur_learner_iter (:obj:`int`): Learner iteration, passed in by learner.
        """
        if self._rank == 0:
            self._periodic_thruput_monitor.sample_data_count += len(sample_data)
        if self._use_thruput_controller:
            self._thruput_controller.history_sample_count += len(sample_data)
        self._cur_learner_iter = cur_learner_iter
        use_avg = sum([d['use'] for d in sample_data]) / len(sample_data)
        use_max = max([d['use'] for d in sample_data])
        priority_avg = sum([d['priority'] for d in sample_data]) / len(sample_data)
        priority_max = max([d['priority'] for d in sample_data])
        priority_min = min([d['priority'] for d in sample_data])
        staleness_avg = sum([d['staleness'] for d in sample_data]) / len(sample_data)
        staleness_max = max([d['staleness'] for d in sample_data])
        self._sampled_data_attr_monitor.use_avg = use_avg
        self._sampled_data_attr_monitor.use_max = use_max
        self._sampled_data_attr_monitor.priority_avg = priority_avg
        self._sampled_data_attr_monitor.priority_max = priority_max
        self._sampled_data_attr_monitor.priority_min = priority_min
        self._sampled_data_attr_monitor.staleness_avg = staleness_avg
        self._sampled_data_attr_monitor.staleness_max = staleness_max
        self._sampled_data_attr_monitor.time.step()
        out_dict = {
            'use_avg': self._sampled_data_attr_monitor.avg['use'](),
            'use_max': self._sampled_data_attr_monitor.max['use'](),
            'priority_avg': self._sampled_data_attr_monitor.avg['priority'](),
            'priority_max': self._sampled_data_attr_monitor.max['priority'](),
            'priority_min': self._sampled_data_attr_monitor.min['priority'](),
            'staleness_avg': self._sampled_data_attr_monitor.avg['staleness'](),
            'staleness_max': self._sampled_data_attr_monitor.max['staleness'](),
            'beta': self._beta,
        }
        if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0 and self._rank == 0:
            self._logger.info("=== Sample data {} Times ===".format(self._sampled_data_attr_print_count))
            self._logger.info(self._logger.get_tabulate_vars_hor(out_dict))
            for k, v in out_dict.items():
                iter_metric = self._cur_learner_iter if self._cur_learner_iter != -1 else None
                step_metric = self._cur_collector_envstep if self._cur_collector_envstep != -1 else None
                if iter_metric is not None:
                    self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, iter_metric)
                if step_metric is not None:
                    self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, step_metric)
        self._sampled_data_attr_print_count += 1

    def _calculate_staleness(self, pos_index: int, cur_learner_iter: int) -> Optional[int]:
        r"""
        Overview:
            Calculate a data's staleness according to its own attribute ``collect_iter``
            and input parameter ``cur_learner_iter``.
        Arguments:
            - pos_index (:obj:`int`): The position index. Staleness of the data at this index will be calculated.
            - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
        Returns:
            - staleness (:obj:`int`): Staleness of data at position ``pos_index``.
        Raises:
            - ValueError: If the slot at ``pos_index`` is empty.

        .. note::
            Caller should guarantee that data at ``pos_index`` is not None; Otherwise this function raises an error.
        """
        if self._data[pos_index] is None:
            raise ValueError("Prioritized's data at index {} is None".format(pos_index))
        else:
            # Data without ``collect_iter`` defaults to ``cur_learner_iter + 1``, i.e. staleness -1.
            collect_iter = self._data[pos_index].get('collect_iter', cur_learner_iter + 1)
            if isinstance(collect_iter, list):
                # Timestep transition's collect_iter is a list
                collect_iter = min(collect_iter)
            # ``staleness`` might be -1, means invalid, e.g. collector does not report collecting model iter,
            # or it is a demonstration buffer(which means data is not generated by collector) etc.
            staleness = cur_learner_iter - collect_iter
            return staleness

    def count(self) -> int:
        """
        Overview:
            Count how many valid data there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data.
        """
        return self._valid_count

    @property
    def beta(self) -> float:
        """Current importance-sampling correction exponent (annealed towards 1 during sampling)."""
        return self._beta

    @beta.setter
    def beta(self, beta: float) -> None:
        self._beta = beta

    def state_dict(self) -> dict:
        """
        Overview:
            Provide a state dict to keep a record of current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With the dict, one can easily reproduce the buffer. Values are references, not copies.
        """
        return {
            'data': self._data,
            'use_count': self._use_count,
            'tail': self._tail,
            'max_priority': self._max_priority,
            'anneal_step': self._anneal_step,
            'beta': self._beta,
            'head': self._head,
            'next_unique_id': self._next_unique_id,
            'valid_count': self._valid_count,
            'push_count': self._push_count,
            'sum_tree': self._sum_tree,
            'min_tree': self._min_tree,
        }

    def load_state_dict(self, _state_dict: dict, deepcopy: bool = False) -> None:
        """
        Overview:
            Load state dict to reproduce the buffer. A dict containing only ``data`` is pushed via ``_extend``;
            otherwise every entry ``k`` is assigned to attribute ``_k``.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer.
            - deepcopy (:obj:`bool`): Whether to deep-copy each value before assigning it.
        """
        assert 'data' in _state_dict
        if set(_state_dict.keys()) == set(['data']):
            self._extend(_state_dict['data'])
        else:
            for k, v in _state_dict.items():
                if deepcopy:
                    setattr(self, '_{}'.format(k), copy.deepcopy(v))
                else:
                    setattr(self, '_{}'.format(k), v)

    @property
    def replay_buffer_size(self) -> int:
        """Maximum number of items the buffer can hold."""
        return self._replay_buffer_size

    @property
    def push_count(self) -> int:
        """Total number of items accepted into the buffer since creation (not reset by ``clear``)."""
        return self._push_count