import math
from functools import partial, lru_cache
from typing import Optional, Dict, Any

import numpy as np
import torch

from ding.compatibility import torch_ge_180
from ding.torch_utils import one_hot

num_first_one_hot = partial(one_hot, num_first=True)


def sqrt_one_hot(v: torch.Tensor, max_val: int) -> torch.Tensor:
    """
    Overview:
        Sqrt the input value ``v`` and transform it into one-hot.
    Arguments:
        - v (:obj:`torch.Tensor`): the value to be processed with `sqrt` and `one-hot`
        - max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \
            ``v`` would be clamped by (0, ``max_val``).
    Returns:
        - ret (:obj:`torch.Tensor`): the value processed after `sqrt` and `one-hot`
    """
    num = int(math.sqrt(max_val)) + 1
    v = v.float()
    v = torch.floor(torch.sqrt(torch.clamp(v, 0, max_val))).long()
    return one_hot(v, num)


def div_one_hot(v: torch.Tensor, max_val: int, ratio: int) -> torch.Tensor:
    """
    Overview:
        Divide the input value ``v`` by ``ratio`` and transform it into one-hot.
    Arguments:
        - v (:obj:`torch.Tensor`): the value to be processed with `divide` and `one-hot`
        - max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \
            ``v`` would be clamped by (0, ``max_val``).
        - ratio (:obj:`int`): input ``v`` would be divided by ``ratio``
    Returns:
        - ret (:obj:`torch.Tensor`): the value processed after `divide` and `one-hot`
    """
    num = int(max_val / ratio) + 1
    v = v.float()
    v = torch.floor(torch.clamp(v, 0, max_val) / ratio).long()
    return one_hot(v, num)


def div_func(inputs: torch.Tensor, other: float, unsqueeze_dim: int = 1) -> torch.Tensor:
    """
    Overview:
        Divide ``inputs`` by ``other`` and unsqueeze if needed.
    Arguments:
        - inputs (:obj:`torch.Tensor`): the value to be unsqueezed and divided
        - other (:obj:`float`): input would be divided by ``other``
        - unsqueeze_dim (:obj:`int`): the dim to implement unsqueeze
    Returns:
        - ret (:obj:`torch.Tensor`): the value processed after `unsqueeze` and `divide`
    """
    inputs = inputs.float()
    if unsqueeze_dim is not None:
        inputs = inputs.unsqueeze(unsqueeze_dim)
    return torch.div(inputs, other)
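
# The helper below is a minimal usage sketch of the discretization utilities above, not
# part of the original module. It assumes ``one_hot(v, num)`` appends a trailing dim of
# size ``num`` (the default ``num_first=False`` behavior of ``ding.torch_utils.one_hot``).
def _discretize_demo() -> None:
    v = torch.tensor([0, 9, 25])
    # num = int(sqrt(25)) + 1 = 6 bits; 25 -> floor(sqrt(25)) = 5 -> last bit set
    print(sqrt_one_hot(v, max_val=25).shape)  # torch.Size([3, 6])
    # num = int(25 / 5) + 1 = 6 bits; 9 -> floor(9 / 5) = 1 -> second bit set
    print(div_one_hot(v, max_val=25, ratio=5).shape)  # torch.Size([3, 6])
    # unsqueeze to shape (3, 1), then divide by 5 -> values 0, 1.8, 5
    print(div_func(v, other=5.))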
def clip_one_hot(v: torch.Tensor, num: int) -> torch.Tensor:
    """
    Overview:
        Clamp the input ``v`` in (0, num-1) and make one-hot mapping.
    Arguments:
        - v (:obj:`torch.Tensor`): the value to be processed with `clamp` and `one-hot`
        - num (:obj:`int`): number of one-hot bits
    Returns:
        - ret (:obj:`torch.Tensor`): the value processed after `clamp` and `one-hot`
    """
    v = v.clamp(0, num - 1)
    return one_hot(v, num)


def reorder_one_hot(
        v: torch.LongTensor,
        dictionary: Dict[int, int],
        num: int,
        transform: Optional[np.ndarray] = None
) -> torch.Tensor:
    """
    Overview:
        Reorder each value in input ``v`` according to reorder dict ``dictionary``, then make one-hot mapping.
    Arguments:
        - v (:obj:`torch.LongTensor`): the original value to be processed with `reorder` and `one-hot`
        - dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \
            mapping original values to new reordered indices starting from 0
        - num (:obj:`int`): number of one-hot bits
        - transform (:obj:`Optional[np.ndarray]`): an array to firstly transform the original action to general action
    Returns:
        - ret (:obj:`torch.Tensor`): one-hot data indicating the reordered index
    """
    assert isinstance(v, torch.Tensor)
    assert len(v.shape) == 1
    new_v = torch.zeros_like(v)
    for idx in range(v.shape[0]):
        if transform is None:
            val = v[idx].item()
        else:
            val = transform[v[idx].item()]
        new_v[idx] = dictionary[val]
    return one_hot(new_v, num)


def reorder_one_hot_array(
        v: torch.LongTensor,
        array: np.ndarray,
        num: int,
        transform: Optional[np.ndarray] = None
) -> torch.Tensor:
    """
    Overview:
        Reorder each value in input ``v`` according to reorder lookup array ``array``, then make one-hot mapping.
        The only difference from ``reorder_one_hot`` is that the reorder lookup data structure is an
        `np.ndarray` rather than a `dict`.
    Arguments:
        - v (:obj:`torch.LongTensor`): the value to be processed with `reorder` and `one-hot`
        - array (:obj:`np.ndarray`): a reorder lookup array, mapping original values to new reordered indices \
            starting from 0
        - num (:obj:`int`): number of one-hot bits
        - transform (:obj:`Optional[np.ndarray]`): an array to firstly transform the original action to general action
    Returns:
        - ret (:obj:`torch.Tensor`): one-hot data indicating the reordered index
    """
    v = v.numpy()
    if transform is None:
        val = array[v]
    else:
        val = array[transform[v]]
    return one_hot(torch.LongTensor(val), num)


def reorder_boolean_vector(
        v: torch.LongTensor,
        dictionary: Dict[int, int],
        num: int,
        transform: Optional[np.ndarray] = None
) -> torch.Tensor:
    """
    Overview:
        Reorder each value in input ``v`` to a new index according to reorder dict ``dictionary``,
        then set the corresponding position in the returned tensor to 1.
    Arguments:
        - v (:obj:`torch.LongTensor`): the value to be processed with `reorder`
        - dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \
            mapping original values to new reordered indices starting from 0
        - num (:obj:`int`): total number of items, should equal max index + 1
        - transform (:obj:`Optional[np.ndarray]`): an array to firstly transform the original action to general action
    Returns:
        - ret (:obj:`torch.Tensor`): boolean data containing only 0 and 1, \
            indicating whether the corresponding original value exists in input ``v``
    """
    ret = torch.zeros(num)
    for item in v:
        try:
            if transform is None:
                val = item.item()
            else:
                val = transform[item.item()]
            idx = dictionary[val]
        except KeyError as e:
            raise KeyError('{}_{}_'.format(num, e))
        ret[idx] = 1
    return ret
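
# The helper below is a minimal, illustrative sketch (not part of the original module)
# of the reorder utilities: raw ids {2, 5, 9} are remapped to contiguous indices {0, 1, 2}.
def _reorder_demo() -> None:
    dictionary = {2: 0, 5: 1, 9: 2}
    v = torch.LongTensor([5, 9])
    # one-hot rows for reordered indices 1 and 2
    print(reorder_one_hot(v, dictionary, num=3))
    # a single boolean vector: positions 1 and 2 are set -> tensor([0., 1., 1.])
    print(reorder_boolean_vector(v, dictionary, num=3))
    # the same lookup expressed as an array indexed by raw id (-1 marks unused ids)
    array = np.array([-1, -1, 0, -1, -1, 1, -1, -1, -1, 2])
    print(reorder_one_hot_array(v, array, num=3))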
@lru_cache(maxsize=32)
def get_to_and(num_bits: int) -> np.ndarray:
    """
    Overview:
        Get an np.ndarray with ``num_bits`` elements, each equal to :math:`2^n`
        (n decreases from ``num_bits - 1`` to 0). Used by ``batch_binary_encode`` to do the bit-wise `and`.
    Arguments:
        - num_bits (:obj:`int`): length of the generated array
    Returns:
        - to_and (:obj:`np.ndarray`): an array with ``num_bits`` elements, \
            each equal to :math:`2^n` (n decreases from ``num_bits - 1`` to 0)
    """
    return 2 ** np.arange(num_bits - 1, -1, -1).reshape([1, num_bits])


def batch_binary_encode(x: torch.Tensor, bit_num: int) -> torch.Tensor:
    r"""
    Overview:
        Big-endian binary encode ``x`` into a float tensor.
    Arguments:
        - x (:obj:`torch.Tensor`): the value to be binary encoded
        - bit_num (:obj:`int`): number of bits, should satisfy :math:`2^{bit\_num} > \max(x)`
    Example:
        >>> batch_binary_encode(torch.tensor([131, 71]), 10)
        tensor([[0., 0., 1., 0., 0., 0., 0., 0., 1., 1.],
                [0., 0., 0., 1., 0., 0., 0., 1., 1., 1.]])
    Returns:
        - ret (:obj:`torch.Tensor`): the binary encoded tensor, containing only `0` and `1`
    """
    x = x.numpy()
    xshape = list(x.shape)
    x = x.reshape([-1, 1])
    to_and = get_to_and(bit_num)
    return torch.FloatTensor((x & to_and).astype(bool).astype(float).reshape(xshape + [bit_num]))


def compute_denominator(x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Compute the denominator used in ``get_postion_vector``. The reciprocal is taken
        at the last step, so the result can be used directly as a multiplier.
    Arguments:
        - x (:obj:`torch.Tensor`): input tensor, generated from ``torch.arange(0, d_model)``.
    Returns:
        - ret (:obj:`torch.Tensor`): denominator result tensor.
    """
    if torch_ge_180():
        x = torch.div(x, 2, rounding_mode='trunc') * 2
    else:
        x = torch.div(x, 2) * 2
    x = torch.div(x, 64.)
    x = torch.pow(10000., x)
    x = torch.div(1., x)
    return x


def get_postion_vector(x: list) -> torch.Tensor:
    r"""
    Overview:
        Get the position embedding used in `Transformer`; the even and odd multipliers
        :math:`\alpha` are stored in ``POSITION_ARRAY``.
    Arguments:
        - x (:obj:`list`): original position index, whose length should be 32
    Returns:
        - v (:obj:`torch.Tensor`): position embedding tensor in 64 dims
    """
    # TODO use lru_cache to optimize it
    POSITION_ARRAY = compute_denominator(torch.arange(0, 64, dtype=torch.float))  # d_model = 64
    v = torch.zeros(64, dtype=torch.float)
    x = torch.FloatTensor(x)
    v[0::2] = torch.sin(x * POSITION_ARRAY[0::2])  # even
    v[1::2] = torch.cos(x * POSITION_ARRAY[1::2])  # odd
    return v


def affine_transform(
        data: Any,
        action_clip: Optional[bool] = True,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        min_val: Optional[float] = None,
        max_val: Optional[float] = None
) -> Any:
    r"""
    Overview:
        Do affine transform :math:`\alpha \times data + \beta` for data in range [-1, 1].
    Arguments:
        - data (:obj:`Any`): the input data
        - action_clip (:obj:`bool`): whether to clip the input data to [-1, 1] first
        - alpha (:obj:`float`): affine transform weight
        - beta (:obj:`float`): affine transform bias
        - min_val (:obj:`float`): min value; if both ``min_val`` and ``max_val`` are given, \
            the input data is scaled to [``min_val``, ``max_val``]
        - max_val (:obj:`float`): max value
    Returns:
        - transformed_data (:obj:`Any`): affine transformed data
    """
    if action_clip:
        data = np.clip(data, -1, 1)
    if min_val is not None:
        assert max_val is not None
        alpha = (max_val - min_val) / 2
        beta = (max_val + min_val) / 2
    assert alpha is not None
    beta = beta if beta is not None else 0.
    return data * alpha + beta


def save_frames_as_gif(frames: list, path: str) -> None:
    """
    Overview:
        Save frames as a gif to the specified path.
    Arguments:
        - frames (:obj:`List`): list of frames
        - path (:obj:`str`): the path to save the gif
    """
    try:
        import imageio
    except ImportError:
        from ditk import logging
        import sys
        logging.warning("Please install imageio first.")
        sys.exit(1)
    imageio.mimsave(path, frames, fps=20)
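
# The helper below is a minimal usage sketch (not part of the original module) tying
# together the encoding and transform utilities above; the expected values in the
# comments follow directly from the definitions of each function.
def _encoding_demo() -> None:
    # 131 -> 0b0010000011 and 71 -> 0b0001000111 in 10 big-endian bits,
    # matching the Example in the ``batch_binary_encode`` docstring.
    print(batch_binary_encode(torch.tensor([131, 71]), 10))
    # Scale actions from [-1, 1] to the range [0, 4]: alpha = 2, beta = 2 -> [0., 2., 4.]
    print(affine_transform(np.array([-1., 0., 1.]), min_val=0., max_val=4.))
    # 64-dim sinusoidal position embedding from a length-32 index list.
    print(get_postion_vector(list(range(32))).shape)  # torch.Size([64])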