from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
import copy
from ditk import logging
import random
from functools import lru_cache  # note: in Python >= 3.9, functools.cache is also available
import numpy as np
import torch
import treetensor.torch as ttorch
def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
"""
Overview:
Get shape[0] of data's torch tensor or treetensor
Arguments:
- data (:obj:`Union[List,Dict,torch.Tensor,ttorch.Tensor]`): data to be analysed
Returns:
- shape[0] (:obj:`int`): first dimension length of data, usually the batchsize.
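    Examples:
        >>> # A minimal illustration; the tensors below are assumed inputs.
        >>> get_shape0(torch.randn(4, 3))
        4
        >>> get_shape0({'obs': torch.randn(4, 3)})
        4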
"""
    if isinstance(data, (list, tuple)):
        return get_shape0(data[0])
    elif isinstance(data, dict):
        # use the first value in the dict as the representative element
        return get_shape0(next(iter(data.values())))
    elif isinstance(data, torch.Tensor):
        return data.shape[0]
    elif isinstance(data, ttorch.Tensor):

        def fn(t):
            # recursively descend into the treetensor's shape structure until a scalar is found
            item = list(t.values())[0]
            if np.isscalar(item[0]):
                return item[0]
            else:
                return fn(item)

        return fn(data.shape)
    else:
        raise TypeError("Error in getting shape0, unsupported type: {}".format(type(data)))
def lists_to_dicts(
data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
"""
Overview:
Transform a list of dicts to a dict of lists.
Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`):
            The list of dicts (or namedtuples) to be transformed.
        - recursive (:obj:`bool`): Whether to transform dict elements recursively.
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): The transformed dict of lists.
Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}
"""
if len(data) == 0:
raise ValueError("empty data")
if isinstance(data[0], dict):
if recursive:
new_data = {}
for k in data[0].keys():
                # 'prev_state' (e.g. RNN hidden states) is kept as a plain list rather than merged recursively
                if isinstance(data[0][k], dict) and k != 'prev_state':
tmp = [data[b][k] for b in range(len(data))]
new_data[k] = lists_to_dicts(tmp)
else:
new_data[k] = [data[b][k] for b in range(len(data))]
else:
new_data = {k: [data[b][k] for b in range(len(data))] for k in data[0].keys()}
elif isinstance(data[0], tuple) and hasattr(data[0], '_fields'): # namedtuple
new_data = type(data[0])(*list(zip(*data)))
else:
raise TypeError("not support element type: {}".format(type(data[0])))
return new_data
def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
"""
Overview:
Transform a dict of lists to a list of dicts.
Arguments:
        - data (:obj:`Mapping[object, list]`): The dict of lists to be transformed.
    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): The transformed list of dicts.
Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]
"""
    new_data = list(data.values())
    new_data = [{k: v for k, v in zip(data.keys(), t)} for t in zip(*new_data)]
return new_data
def override(cls: type) -> Callable[[
Callable,
], Callable]:
"""
Overview:
Annotation for documenting method overrides.
Arguments:
        - cls (:obj:`type`): The superclass that provides the overridden method. If ``cls`` \
            does not actually have the method, an error is raised.
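    Examples:
        >>> # A minimal sketch; ``Base`` and ``Child`` are hypothetical classes.
        >>> class Base:
        >>>     def step(self):
        >>>         pass
        >>> class Child(Base):
        >>>     @override(Base)
        >>>     def step(self):
        >>>         pass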
"""
def check_override(method: Callable) -> Callable:
if method.__name__ not in dir(cls):
raise NameError("{} does not override any method of {}".format(method, cls))
return method
return check_override
def squeeze(data: object) -> object:
"""
Overview:
Squeeze data from tuple, list or dict to single object
Arguments:
- data (:obj:`object`): data to be squeezed
Example:
>>> a = (4, )
>>> a = squeeze(a)
>>> print(a)
        4
"""
    if isinstance(data, (tuple, list)):
if len(data) == 1:
return data[0]
else:
return tuple(data)
elif isinstance(data, dict):
if len(data) == 1:
return list(data.values())[0]
return data
# record the names that have already fallen back to a default value, so that each one warns only once
default_get_set = set()
def default_get(
data: dict,
name: str,
default_value: Optional[Any] = None,
default_fn: Optional[Callable] = None,
judge_fn: Optional[Callable] = None
) -> Any:
"""
Overview:
Getting the value by input, checks generically on the inputs with \
at least ``data`` and ``name``. If ``name`` exists in ``data``, \
get the value at ``name``; else, add ``name`` to ``default_get_set``\
with value generated by \
``default_fn`` (or directly as ``default_value``) that \
is checked by `` judge_fn`` to be legal.
Arguments:
- data(:obj:`dict`): Data input dictionary
- name(:obj:`str`): Key name
- default_value(:obj:`Optional[Any]`) = None,
- default_fn(:obj:`Optional[Callable]`) = Value
- judge_fn(:obj:`Optional[Callable]`) = None
Returns:
- ret(:obj:`list`): Splitted data
- residual(:obj:`list`): Residule list
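    Examples:
        >>> # Illustrative values; the second call falls back to the default and logs a warning.
        >>> cfg = {'lr': 0.001}
        >>> default_get(cfg, 'lr', default_value=0.01)
        0.001
        >>> default_get(cfg, 'batch_size', default_value=32)
        32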
"""
if name in data:
return data[name]
else:
assert default_value is not None or default_fn is not None
value = default_fn() if default_fn is not None else default_value
if judge_fn:
            assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
if name not in default_get_set:
logging.warning("{} use default value {}".format(name, value))
default_get_set.add(name)
return value
def list_split(data: list, step: int) -> List[list]:
"""
Overview:
Split list of data by step.
Arguments:
        - data (:obj:`list`): The list of data to be split.
        - step (:obj:`int`): The number of elements in each chunk.
    Returns:
        - ret (:obj:`list`): The list of split chunks.
        - residual (:obj:`list`): The residual list. This value is ``None`` when ``step`` evenly divides ``len(data)``.
Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])
"""
if len(data) < step:
return [], data
ret = []
divide_num = len(data) // step
for i in range(divide_num):
start, end = i * step, (i + 1) * step
ret.append(data[start:end])
if divide_num * step < len(data):
residual = data[divide_num * step:]
else:
residual = None
return ret, residual
def error_wrapper(fn, default_ret, warning_msg=""):
"""
Overview:
wrap the function, so that any Exception in the function will be catched and return the default_ret
Arguments:
- fn (:obj:`Callable`): the function to be wraped
- default_ret (:obj:`obj`): the default return when an Exception occurred in the function
Returns:
- wrapper (:obj:`Callable`): the wrapped function
Examples:
>>> # Used to checkfor Fakelink (Refer to utils.linklink_dist_helper.py)
>>> def get_rank(): # Get the rank of linklink model, return 0 if use FakeLink.
>>> if is_fake_link:
>>> return 0
>>> return error_wrapper(link.get_rank, 0)()
"""
def wrapper(*args, **kwargs):
try:
ret = fn(*args, **kwargs)
except Exception as e:
ret = default_ret
if warning_msg != "":
one_time_warning(warning_msg, "\ndefault_ret = {}\terror = {}".format(default_ret, e))
return ret
return wrapper
class LimitedSpaceContainer:
"""
    Overview:
        A bounded counter that simulates acquiring and releasing pieces of a limited resource space.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``acquire_space``, ``release_space``, \
        ``increase_space``, ``decrease_space``
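    Examples:
        >>> # An illustrative sketch of the container as a bounded counter.
        >>> container = LimitedSpaceContainer(0, 2)
        >>> container.acquire_space()
        True
        >>> container.acquire_space()
        True
        >>> container.acquire_space()
        False
        >>> container.release_space()
        >>> container.acquire_space()
        True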
"""
def __init__(self, min_val: int, max_val: int) -> None:
"""
Overview:
Set ``min_val`` and ``max_val`` of the container, also set ``cur`` to ``min_val`` for initialization.
Arguments:
- min_val (:obj:`int`): Min volume of the container, usually 0.
- max_val (:obj:`int`): Max volume of the container.
"""
self.min_val = min_val
self.max_val = max_val
assert (max_val >= min_val)
self.cur = self.min_val
def get_residual_space(self) -> int:
"""
Overview:
Get all residual pieces of space. Set ``cur`` to ``max_val``
Arguments:
- ret (:obj:`int`): Residual space, calculated by ``max_val`` - ``cur``.
"""
ret = self.max_val - self.cur
self.cur = self.max_val
return ret
def acquire_space(self) -> bool:
"""
Overview:
Try to get one pice of space. If there is one, return True; Otherwise return False.
Returns:
- flag (:obj:`bool`): Whether there is any piece of residual space.
"""
if self.cur < self.max_val:
self.cur += 1
return True
else:
return False
def release_space(self) -> None:
"""
Overview:
            Release only one piece of space. Decrement ``cur``, but ensure it won't go below ``min_val``.
"""
self.cur = max(self.min_val, self.cur - 1)
def increase_space(self) -> None:
"""
Overview:
Increase one piece in space. Increment ``max_val``.
"""
self.max_val += 1
def decrease_space(self) -> None:
"""
Overview:
Decrease one piece in space. Decrement ``max_val``.
"""
self.max_val -= 1
def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
"""
Overview:
Merge two dicts by calling ``deep_update``
Arguments:
        - original (:obj:`dict`): The original dict, whose copy is used as the merge base.
        - new_dict (:obj:`dict`): The new dict, whose values override those in ``original``.
    Returns:
        - merged_dict (:obj:`dict`): A new dict that is the result of deeply merging ``original`` and ``new_dict``.
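    Examples:
        >>> # Illustrative dicts; ``new_dict`` values are merged into a copy of ``original``.
        >>> deep_merge_dicts({'a': {'b': 1}, 'c': 3}, {'a': {'d': 2}})
        {'a': {'b': 1, 'd': 2}, 'c': 3}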
"""
original = original or {}
new_dict = new_dict or {}
merged = copy.deepcopy(original)
    if new_dict:  # only merge when new_dict is non-empty
deep_update(merged, new_dict, True, [])
return merged
def deep_update(
original: dict,
new_dict: dict,
new_keys_allowed: bool = False,
whitelist: Optional[List[str]] = None,
override_all_if_type_changes: Optional[List[str]] = None
):
"""
Overview:
Update original dict with values from new_dict recursively.
Arguments:
- original (:obj:`dict`): Dictionary with default values.
        - new_dict (:obj:`dict`): Dictionary with values to be updated.
- new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
- whitelist (:obj:`Optional[List[str]]`):
List of keys that correspond to dict
values where new subkeys can be introduced. This is only at the top
level.
- override_all_if_type_changes(:obj:`Optional[List[str]]`):
List of top level
keys with value=dict, for which we always simply override the
entire value (:obj:`dict`), if the "type" key in that value dict changes.
.. note::
        If a new key is introduced in ``new_dict`` and ``new_keys_allowed`` is not
        True, an error will be thrown. Further, for sub-dicts, if the key is
        in the whitelist, then new subkeys can be introduced.
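    Examples:
        >>> # Illustrative configs; the changed "type" key triggers a full override of the sub-dict.
        >>> base = {'model': {'type': 'mlp', 'hidden': 64}}
        >>> deep_update(base, {'model': {'type': 'conv'}}, override_all_if_type_changes=['model'])
        {'model': {'type': 'conv'}}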
"""
whitelist = whitelist or []
override_all_if_type_changes = override_all_if_type_changes or []
for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys()))
# Both original value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check the old type vs the new one. If different, override the entire value.
if k in override_all_if_type_changes and \
"type" in value and "type" in original[k] and \
value["type"] != original[k]["type"]:
original[k] = value
# Whitelisted key -> ok to add new subkeys.
elif k in whitelist:
deep_update(original[k], value, True)
# Non-whitelisted key.
else:
deep_update(original[k], value, new_keys_allowed)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
original[k] = value
return original
def flatten_dict(data: dict, delimiter: str = "/") -> dict:
"""
Overview:
        Flatten a nested dict into a single-level dict whose keys join the nesting path with ``delimiter``; \
        see the example below.
    Arguments:
        - data (:obj:`dict`): The original nested dict.
        - delimiter (:obj:`str`): The delimiter used to join the keys of the new dict.
    Returns:
        - data (:obj:`dict`): The flattened dict.
Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}
"""
data = copy.deepcopy(data)
while any(isinstance(v, dict) for v in data.values()):
remove = []
add = {}
for key, value in data.items():
if isinstance(value, dict):
for subkey, v in value.items():
add[delimiter.join([key, subkey])] = v
remove.append(key)
data.update(add)
for k in remove:
del data[k]
return data
def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
"""
Overview:
        Side effect function to set seed for ``random``, ``numpy random``, and ``torch's manual seed``. \
        This is usually used in an entry script, in the section that sets the random seed for all packages and instances.
    Arguments:
        - seed (:obj:`int`): The random seed to set.
        - use_cuda (:obj:`bool`): Whether to also seed cuda.
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        >>> # Set random seed for all packages and instances
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        >>> # Set up RL Policy, etc.
        >>> ...
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
@lru_cache()
def one_time_warning(warning_msg: str) -> None:
"""
Overview:
Print warning message only once.
Arguments:
- warning_msg (:obj:`str`): Warning message.
"""
logging.warning(warning_msg)
def split_fn(data, indices, start, end):
"""
Overview:
        Split data by indices.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): The data to be split.
        - indices (:obj:`np.ndarray`): The indices used to select the split slice.
        - start (:obj:`int`): The start index of the slice.
        - end (:obj:`int`): The end index of the slice.
"""
if data is None:
return None
elif isinstance(data, list):
return [split_fn(d, indices, start, end) for d in data]
elif isinstance(data, dict):
return {k1: split_fn(v1, indices, start, end) for k1, v1 in data.items()}
    elif isinstance(data, str):
        # strings (e.g. useless buffer info such as 'buffer_id') are returned unchanged
        return data
else:
return data[indices[start:end]]
def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
"""
    Overview:
        Split data into batches of size ``split_size`` and yield them one by one.
    Arguments:
        - data (:obj:`dict`): The data to be split.
        - split_size (:obj:`int`): The number of samples in each yielded batch.
        - shuffle (:obj:`bool`): Whether to shuffle the sample indices before splitting.
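    Examples:
        >>> # An illustrative sketch: split a batch of 4 samples into batches of 2.
        >>> data = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 2)}
        >>> for batch in split_data_generator(data, 2, shuffle=False):
        >>>     print(batch['obs'].shape)
        torch.Size([2, 8])
        torch.Size([2, 8])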
"""
assert isinstance(data, dict), type(data)
length = []
for k, v in data.items():
if v is None:
continue
        elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
            # hidden states are stored as plain lists, so use the list length directly
            length.append(len(v))
elif isinstance(v, list) or isinstance(v, tuple):
if isinstance(v[0], str):
# some buffer data contains useless string infos, such as 'buffer_id',
# which should not be split, so we just skip it
continue
else:
length.append(get_shape0(v[0]))
elif isinstance(v, dict):
length.append(len(v[list(v.keys())[0]]))
else:
length.append(len(v))
assert len(length) > 0
# assert len(set(length)) == 1, "data values must have the same length: {}".format(length)
# if continuous action, data['logit'] is list of length 2
length = length[0]
assert split_size >= 1
if shuffle:
indices = np.random.permutation(length)
else:
indices = np.arange(length)
    for i in range(0, length, split_size):
        if i + split_size > length:
            # shift the last window back so that every batch has exactly split_size samples (it may overlap)
            i = length - split_size
batch = split_fn(data, indices, i, i + split_size)
yield batch
class RunningMeanStd(object):
"""
    Overview:
        Wrapper to update the running mean, variance, and count over batches of data.
    Interfaces:
        ``__init__``, ``update``, ``reset``, ``new_shape``
    Properties:
        - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
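    Examples:
        >>> # An illustrative sketch: track running statistics over batches of shape (B, 3).
        >>> rms = RunningMeanStd(shape=(3, ))
        >>> rms.update(np.random.randn(16, 3))
        >>> rms.update(np.random.randn(16, 3))
        >>> isinstance(rms.mean, torch.Tensor)
        True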
"""
def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
"""
Overview:
Initialize ``self.`` See ``help(type(self))`` for accurate \
signature; setup the properties.
Arguments:
- env (:obj:`gym.Env`): the environment to wrap.
- epsilon (:obj:`Float`): the epsilon used for self for the std output
- shape (:obj: `np.array`): the np array shape used for the expression \
of this wrapper on attibutes of mean and variance
"""
self._epsilon = epsilon
self._shape = shape
self._device = device
self.reset()
def update(self, x):
"""
Overview:
Update mean, variable, and count
Arguments:
- ``x``: the batch
"""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
new_count = batch_count + self._count
mean_delta = batch_mean - self._mean
new_mean = self._mean + mean_delta * batch_count / new_count
        # this parallel variance update may be numerically unstable
m_a = self._var * self._count
m_b = batch_var * batch_count
m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
new_var = m2 / new_count
self._mean = new_mean
self._var = new_var
self._count = new_count
def reset(self):
"""
Overview:
Resets the state of the environment and reset properties: ``_mean``, ``_var``, ``_count``
"""
if len(self._shape) > 0:
self._mean = np.zeros(self._shape, 'float32')
self._var = np.ones(self._shape, 'float32')
else:
self._mean, self._var = 0., 1.
self._count = self._epsilon
@property
    def mean(self) -> Union[float, torch.Tensor]:
"""
Overview:
Property ``mean`` gotten from ``self._mean``
"""
if np.isscalar(self._mean):
return self._mean
else:
return torch.FloatTensor(self._mean).to(self._device)
@property
    def std(self) -> Union[float, torch.Tensor]:
"""
Overview:
            Property ``std`` calculated from ``self._var`` plus a small constant (1e-8) for numerical stability.
"""
std = np.sqrt(self._var + 1e-8)
if np.isscalar(std):
return std
else:
return torch.FloatTensor(std).to(self._device)
@staticmethod
def new_shape(obs_shape, act_shape, rew_shape):
"""
Overview:
            Get the new shape of observation, action, and reward; in this case unchanged.
Arguments:
obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
Returns:
obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
"""
return obs_shape, act_shape, rew_shape
def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
"""
    Overview:
        Make the keys of a dict into legal python identifier strings so that they are
        compatible with some python magic methods such as ``__getattr__``.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Returns:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
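    Examples:
        >>> # Illustrative keys containing a dot and a leading digit.
        >>> make_key_as_identifier({'a.b': 1, '1st': 2})
        {'a_b': 1, '_1st': 2}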
"""
def legalization(s: str) -> str:
if s[0].isdigit():
s = '_' + s
return s.replace('.', '_')
new_data = {}
for k in data:
new_k = legalization(k)
new_data[new_k] = data[k]
return new_data
def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
"""
    Overview:
        Remove illegal items from the dict info, such as str values, which are not compatible with Tensor.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Returns:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with illegal items removed.
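    Examples:
        >>> # Illustrative data; the str-valued 'buffer_id' entry is dropped.
        >>> remove_illegal_item({'reward': 1.0, 'buffer_id': 'buffer_0'})
        {'reward': 1.0}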
"""
    new_data = {}
    for k, v in data.items():
        if isinstance(v, str):
            continue
        new_data[k] = v
    return new_data