|
from typing import Iterable, Any, Optional, List |
|
from collections.abc import Sequence |
|
import numbers |
|
import time |
|
import copy |
|
from threading import Thread |
|
from queue import Queue |
|
|
|
import numpy as np |
|
import torch |
|
import treetensor.torch as ttorch |
|
|
|
from ding.utils.default_helper import get_shape0 |
|
|
|
|
|
def to_device(item: Any, device: str, ignore_keys: Optional[list] = None) -> Any:
    """
    Overview:
        Transfer data to certain device, recursing into sequences and dicts.
    Arguments:
        - item (:obj:`Any`): The item to be transferred.
        - device (:obj:`str`): The device wanted.
        - ignore_keys (:obj:`Optional[list]`): Keys of the outermost dict whose values are kept as-is (not \
            transferred). ``None`` (default) means no key is ignored. Nested containers are always transferred \
            completely.
    Returns:
        - item (:obj:`Any`): The transferred item. Non-string sequences (including tuples) are returned as \
            ``list``; dicts are returned as new ``dict`` objects.
    Examples (with ignore_keys):
        >>> setup_data_dict['module'] = nn.Linear(3, 5)
        >>> device = 'cuda'
        >>> cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
        >>> assert cuda_d['module'].weight.device == torch.device('cpu')

    Examples (without ignore_keys):
        >>> setup_data_dict['module'] = nn.Linear(3, 5)
        >>> device = 'cuda'
        >>> cuda_d = to_device(setup_data_dict, device)
        >>> assert cuda_d['module'].weight.device == torch.device('cuda:0')

    .. note:

        Now supports item type: :obj:`torch.nn.Module`, :obj:`ttorch.Tensor`, :obj:`torch.Tensor`, \
        :obj:`Sequence`, :obj:`dict`, :obj:`numbers.Integral`, :obj:`numbers.Real`, :obj:`np.ndarray`, \
        :obj:`np.bool_`, :obj:`torch.distributions.Distribution`, :obj:`str` and :obj:`None`. \
        Numbers, numpy objects, strings, distributions and ``None`` are returned unchanged. \
        ``torch.nn.Module`` objects are moved with ``Module.to``, which modifies the module in place.
    """
    # Avoid a shared mutable default argument.
    if ignore_keys is None:
        ignore_keys = []
    if isinstance(item, torch.nn.Module):
        return item.to(device)
    elif isinstance(item, ttorch.Tensor):
        if 'prev_state' in item:
            # ``prev_state`` is moved separately with ``to_device`` and re-attached to the moved tree.
            # NOTE(review): presumably because it may hold values ``ttorch.Tensor.to`` cannot move; confirm.
            orig_prev_state = item.prev_state
            prev_state = to_device(orig_prev_state, device)
            del item.prev_state
            try:
                moved = item.to(device)
            finally:
                # Restore the caller's object so it is not left without ``prev_state``.
                item.prev_state = orig_prev_state
            moved.prev_state = prev_state
            return moved
        else:
            return item.to(device)
    elif isinstance(item, torch.Tensor):
        return item.to(device)
    elif isinstance(item, Sequence):
        if isinstance(item, str):
            return item
        else:
            return [to_device(t, device) for t in item]
    elif isinstance(item, dict):
        new_item = {}
        for k in item.keys():
            if k in ignore_keys:
                new_item[k] = item[k]
            else:
                new_item[k] = to_device(item[k], device)
        return new_item
    elif isinstance(item, numbers.Real):  # ``numbers.Integral`` is a subclass of ``numbers.Real``
        return item
    elif isinstance(item, (np.ndarray, np.bool_)):
        return item
    elif item is None:  # strings were already handled by the ``Sequence`` branch above
        return item
    elif isinstance(item, torch.distributions.Distribution):
        return item
    else:
        raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def to_dtype(item: Any, dtype: type) -> Any:
    """
    Overview:
        Change data to certain dtype, recursing into sequences and dicts.
    Arguments:
        - item (:obj:`Any`): The item for changing the dtype.
        - dtype (:obj:`type`): The type wanted.
    Returns:
        - item (:obj:`object`): The item with changed dtype. Sequences are returned as ``list``.
    Raises:
        - TypeError: If ``item`` (or a nested element) is not a tensor, a non-string sequence or a dict.
    Examples (tensor):
        >>> t = torch.randint(0, 10, (3, 5))
        >>> tfloat = to_dtype(t, torch.float)
        >>> assert tfloat.dtype == torch.float

    Examples (list):
        >>> tlist = [torch.randint(0, 10, (3, 5))]
        >>> tlfloat = to_dtype(tlist, torch.float)
        >>> assert tlfloat[0].dtype == torch.float

    Examples (dict):
        >>> tdict = {'t': torch.randint(0, 10, (3, 5))}
        >>> tdictf = to_dtype(tdict, torch.float)
        >>> assert tdictf['t'].dtype == torch.float

    .. note:

        Now supports item type: :obj:`torch.Tensor`, :obj:`Sequence` (except :obj:`str`), :obj:`dict`.
    """
    if isinstance(item, torch.Tensor):
        return item.to(dtype=dtype)
    elif isinstance(item, Sequence) and not isinstance(item, str):
        # A ``str`` is a ``Sequence`` whose elements are ``str`` again, so recursing into it never terminates.
        return [to_dtype(t, dtype) for t in item]
    elif isinstance(item, dict):
        return {k: to_dtype(v, dtype) for k, v in item.items()}
    else:
        raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def to_tensor(
        item: Any,
        dtype: Optional[torch.dtype] = None,
        ignore_keys: Optional[list] = None,
        transform_scalar: bool = True
) -> Any:
    """
    Overview:
        Convert ``numpy.ndarray`` object to ``torch.Tensor``.
    Arguments:
        - item (:obj:`Any`): The ``numpy.ndarray`` objects to be converted. It can be exactly a ``numpy.ndarray`` \
            object or a container (list, tuple or dict) that contains several ``numpy.ndarray`` objects.
        - dtype (:obj:`torch.dtype`): The type of wanted tensor. If set to ``None``, the dtype is kept, except \
            that ``float64`` arrays are converted to ``float32`` tensors.
        - ignore_keys (:obj:`Optional[list]`): If the ``item`` is a dict, values whose keys are in \
            ``ignore_keys`` will not be converted. ``None`` (default) means no key is ignored. The same list is \
            applied to every nested dict.
        - transform_scalar (:obj:`bool`): If set to ``True``, a scalar will be also converted to a tensor object.
    Returns:
        - item (:obj:`Any`): The converted tensors.

    Examples (scalar):
        >>> i = 10
        >>> t = to_tensor(i)
        >>> assert t.item() == i

    Examples (dict):
        >>> d = {'i': i}
        >>> dt = to_tensor(d, torch.int)
        >>> assert dt['i'].item() == i

    Examples (named tuple):
        >>> data_type = namedtuple('data_type', ['x', 'y'])
        >>> inputs = data_type(np.random.random(3), 4)
        >>> outputs = to_tensor(inputs, torch.float32)
        >>> assert type(outputs) == data_type
        >>> assert isinstance(outputs.x, torch.Tensor)
        >>> assert isinstance(outputs.y, torch.Tensor)
        >>> assert outputs.x.dtype == torch.float32
        >>> assert outputs.y.dtype == torch.float32

    .. note:

        Now supports item type: :obj:`dict`, :obj:`list`, :obj:`tuple`, :obj:`np.ndarray`, :obj:`torch.Tensor`, \
        numpy/python scalars, :obj:`bool`, :obj:`str` and :obj:`None`. ``bool`` and ``str`` are returned unchanged. \
        A list/tuple whose first element is a number is converted as a whole into one tensor; an empty \
        list/tuple becomes ``[]``; other tuples (except namedtuples) become lists.
    """
    # Avoid a shared mutable default argument.
    if ignore_keys is None:
        ignore_keys = []

    def transform(d):
        if dtype is None:
            return torch.as_tensor(d)
        else:
            return torch.tensor(d, dtype=dtype)

    if isinstance(item, dict):
        new_data = {}
        for k, v in item.items():
            if k in ignore_keys:
                new_data[k] = v
            else:
                new_data[k] = to_tensor(v, dtype, ignore_keys, transform_scalar)
        return new_data
    elif isinstance(item, (list, tuple)):
        if len(item) == 0:
            return []
        elif isinstance(item[0], numbers.Real):  # ``numbers.Integral`` is a subclass of ``numbers.Real``
            return transform(item)
        elif hasattr(item, '_fields'):  # namedtuple: keep its type, forward every option to the fields
            return type(item)(*[to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item])
        else:
            return [to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item]
    elif isinstance(item, np.ndarray):
        if dtype is None:
            if item.dtype == np.float64:
                # Deliberately downcast double precision to the default float32 tensor type.
                return torch.FloatTensor(item)
            else:
                return torch.from_numpy(item)
        else:
            return torch.from_numpy(item).to(dtype)
    elif isinstance(item, (bool, str)):
        return item
    elif np.isscalar(item):
        if transform_scalar:
            if dtype is None:
                return torch.as_tensor(item)
            else:
                return torch.as_tensor(item).to(dtype)
        else:
            return item
    elif item is None:
        return None
    elif isinstance(item, torch.Tensor):
        if dtype is None:
            return item
        else:
            return item.to(dtype)
    else:
        raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def to_ndarray(item: Any, dtype: np.dtype = None) -> Any:
    """
    Overview:
        Convert ``torch.Tensor`` to ``numpy.ndarray``.
    Arguments:
        - item (:obj:`Any`): The ``torch.Tensor`` objects to be converted. It can be exactly a ``torch.Tensor`` \
            object or a container (list, tuple or dict) that contains several ``torch.Tensor`` objects.
        - dtype (:obj:`np.dtype`): The type of wanted array. If set to ``None``, its dtype will be unchanged.
    Returns:
        - item (:obj:`object`): The changed arrays.

    Examples (ndarray):
        >>> t = torch.randn(3, 5)
        >>> tarray1 = to_ndarray(t)
        >>> assert tarray1.shape == (3, 5)
        >>> assert isinstance(tarray1, np.ndarray)

    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tarray1 = to_ndarray(t, np.float32)
        >>> assert isinstance(tarray1, list)
        >>> assert tarray1[0].shape == (5, )
        >>> assert isinstance(tarray1[0], np.ndarray)

    .. note:

        Now supports item type: :obj:`torch.Tensor`, :obj:`np.ndarray`, :obj:`dict`, :obj:`list`, :obj:`tuple`, \
        numpy/python scalars, :obj:`bool`, :obj:`str` and :obj:`None`. Tensors that require grad or live on a \
        non-CPU device are detached and copied to host memory first. An empty list/tuple is converted to \
        ``None``. A list/tuple whose first element is a number is converted as a whole into one array.
    """

    def transform(d):
        if dtype is None:
            return np.array(d)
        else:
            return np.array(d, dtype=dtype)

    if isinstance(item, dict):
        return {k: to_ndarray(v, dtype) for k, v in item.items()}
    elif isinstance(item, (list, tuple)):
        if len(item) == 0:
            return None
        elif isinstance(item[0], numbers.Real):  # ``numbers.Integral`` is a subclass of ``numbers.Real``
            return transform(item)
        elif hasattr(item, '_fields'):  # namedtuple: keep its type
            return type(item)(*[to_ndarray(t, dtype) for t in item])
        else:
            return [to_ndarray(t, dtype) for t in item]
    elif isinstance(item, torch.Tensor):
        # ``Tensor.numpy()`` refuses tensors that require grad or live on a non-CPU device. ``detach().cpu()``
        # handles both; for a plain CPU tensor the array still shares memory with ``item`` as before.
        array = item.detach().cpu().numpy()
        if dtype is None:
            return array
        else:
            return array.astype(dtype)
    elif isinstance(item, np.ndarray):
        if dtype is None:
            return item
        else:
            return item.astype(dtype)
    elif isinstance(item, (bool, str)):
        return item
    elif np.isscalar(item):
        if dtype is None:
            return np.array(item)
        else:
            return np.array(item, dtype=dtype)
    elif item is None:
        return None
    else:
        raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def to_list(item: Any) -> Any:
    """
    Overview:
        Recursively turn ``torch.Tensor`` and ``numpy.ndarray`` objects into (nested) python lists, keeping \
        element dtypes as python scalars of the matching kind.
    Arguments:
        - item (:obj:`Any`): The item to be converted.
    Returns:
        - item (:obj:`Any`): The converted item. Tuples become lists, dicts keep their keys, scalars and \
            ``None`` pass through untouched.
    Raises:
        - TypeError: If ``item`` (or a nested element) is of an unsupported type.

    Examples:
        >>> data = { \
                'tensor': torch.randn(4), \
                'list': [True, False, False], \
                'tuple': (4, 5, 6), \
                'bool': True, \
                'int': 10, \
                'float': 10., \
                'array': np.random.randn(4), \
                'str': "asdf", \
                'none': None, \
            } \
        >>> transformed_data = to_list(data)

    .. note::

        Now supports item type: :obj:`torch.Tensor`, :obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`, \
        :obj:`tuple`, numpy/python scalars (including :obj:`str`) and :obj:`None`.
    """
    if item is None:
        return None
    if isinstance(item, (torch.Tensor, np.ndarray)):
        # Both types provide ``tolist`` with the same nested-list semantics.
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return list(map(to_list, item))
    if isinstance(item, dict):
        return {key: to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def tensor_to_list(item: Any) -> Any:
    """
    Overview:
        Recursively turn ``torch.Tensor`` objects into (nested) python lists, keeping element dtypes as python \
        scalars of the matching kind.
    Arguments:
        - item (:obj:`Any`): The item to be converted.
    Returns:
        - item (:obj:`Any`): The converted item. Tuples become lists, dicts keep their keys, scalars and \
            ``None`` pass through untouched.
    Raises:
        - TypeError: If ``item`` (or a nested element) is of an unsupported type, e.g. a ``numpy.ndarray``.

    Examples (2d-tensor):
        >>> t = torch.randn(3, 5)
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5

    Examples (1d-tensor):
        >>> t = torch.randn(3, )
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3

    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5

    Examples (dict):
        >>> td = {'t': torch.randn(3, 5)}
        >>> tdlist1 = tensor_to_list(td)
        >>> assert len(tdlist1['t']) == 3
        >>> assert len(tdlist1['t'][0]) == 5

    .. note::

        Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple`, numpy/python \
        scalars (including :obj:`str`) and :obj:`None`.
    """
    if item is None:
        return None
    if isinstance(item, torch.Tensor):
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return [tensor_to_list(element) for element in item]
    if isinstance(item, dict):
        return {key: tensor_to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
|
|
|
|
|
def to_item(data: Any, ignore_error: bool = True) -> Any:
    """
    Overview:
        Convert data to python native scalar (i.e. data item), and keep their dtypes unchanged.
    Arguments:
        - data (:obj:`Any`): The data that needs to be converted.
        - ignore_error (:obj:`bool`): Whether to ignore the error when a dict value cannot be converted. If \
            ``True``, dict entries whose conversion raises ``ValueError``/``RuntimeError`` (e.g. a multi-element \
            array or tensor) are dropped from the result. The flag applies to every nested level.
    Returns:
        - data (:obj:`Any`): Converted data. Tuples are returned as lists.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type.

    Examples:
        >>> data = { \
                'tensor': torch.randn(1), \
                'list': [True, False, torch.randn(1)], \
                'tuple': (4, 5, 6), \
                'bool': True, \
                'int': 10, \
                'float': 10., \
                'array': np.random.randn(1), \
                'str': "asdf", \
                'none': None, \
            }
        >>> new_data = to_item(data)
        >>> assert np.isscalar(new_data['tensor'])
        >>> assert np.isscalar(new_data['array'])
        >>> assert np.isscalar(new_data['list'][-1])

    .. note::

        Now supports item type: :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`ttorch.Tensor`, numpy/python \
        scalars, :obj:`bool`, :obj:`str`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`. Inside a list \
        or tuple a conversion error always propagates (only dict entries can be dropped).
    """
    if data is None:
        return data
    elif isinstance(data, (bool, str)):
        return data
    elif np.isscalar(data):
        return data
    elif isinstance(data, (list, tuple)):
        return [to_item(d, ignore_error) for d in data]
    elif isinstance(data, dict):
        new_data = {}
        for k, v in data.items():
            if ignore_error:
                try:
                    new_data[k] = to_item(v, ignore_error)
                except (ValueError, RuntimeError):
                    # Best effort: drop entries that cannot be reduced to a single python scalar.
                    pass
            else:
                new_data[k] = to_item(v, ignore_error)
        return new_data
    # ``or`` chain (not a tuple) so ``ttorch`` is only looked up when the first two checks fail.
    elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor) or isinstance(data, ttorch.Tensor):
        return data.item()
    else:
        raise TypeError("not support data type: {}".format(type(data)))
|
|
|
|
|
def same_shape(data: list) -> bool:
    """
    Overview:
        Check whether every element of a list exposes an identical ``shape`` attribute.
    Arguments:
        - data (:obj:`list`): The list of data; every element must have a hashable ``shape``.
    Returns:
        - same (:obj:`bool`): ``True`` when exactly one distinct shape is present. Note that an empty list \
            therefore yields ``False``.

    Examples:
        >>> tlist = [torch.randn(3, 5) for i in range(5)]
        >>> assert same_shape(tlist)
        >>> tlist = [torch.randn(3, 5), torch.randn(4, 5)]
        >>> assert not same_shape(tlist)
    """
    assert isinstance(data, list)
    distinct_shapes = {element.shape for element in data}
    return len(distinct_shapes) == 1
|
|
|
|
|
class LogDict(dict):
    """
    Overview:
        Derived from ``dict``. Would convert ``torch.Tensor`` to ``list`` for convenient logging. Values inserted \
        through the constructor, ``__setitem__`` and ``update`` are all converted.
    Interfaces:
        ``__init__``, ``_transform``, ``__setitem__``, ``update``.
    """

    def __init__(self, data: Any = (), **kwargs: Any) -> None:
        """
        Overview:
            Build the dict, routing every initial value through ``_transform`` (``dict.__init__`` would \
            otherwise bypass the overridden ``__setitem__``).
        Arguments:
            - data (:obj:`Any`): A mapping or an iterable of ``(key, value)`` pairs, as accepted by ``dict``.
            - kwargs (:obj:`Any`): Extra key/value pairs.
        """
        super().__init__()
        self.update(data, **kwargs)

    def _transform(self, data: Any) -> Any:
        """
        Overview:
            Convert tensor objects to lists for better logging.
        Arguments:
            - data (:obj:`Any`): The input data to be converted.
        Returns:
            - new_data (:obj:`Any`): ``data.tolist()`` if ``data`` is a ``torch.Tensor``, else ``data`` itself.
        """
        if isinstance(data, torch.Tensor):
            new_data = data.tolist()
        else:
            new_data = data
        return new_data

    def __setitem__(self, key: Any, value: Any) -> None:
        """
        Overview:
            Override the ``__setitem__`` function of built-in dict so stored values are converted first.
        Arguments:
            - key (:obj:`Any`): The key of the data item.
            - value (:obj:`Any`): The value of the data item.
        """
        new_value = self._transform(value)
        super().__setitem__(key, new_value)

    def update(self, data: Any = (), **kwargs: Any) -> None:
        """
        Overview:
            Override the ``update`` function of built-in dict, converting every value with ``__setitem__``. \
            Accepts the same inputs as ``dict.update``.
        Arguments:
            - data (:obj:`Any`): The dict (or other mapping, or iterable of ``(key, value)`` pairs) for updating \
                current object.
            - kwargs (:obj:`Any`): Extra key/value pairs.
        """
        if hasattr(data, 'items'):
            pairs = data.items()
        elif hasattr(data, 'keys'):
            pairs = ((k, data[k]) for k in data.keys())
        else:
            pairs = data
        for k, v in pairs:
            self.__setitem__(k, v)
        for k, v in kwargs.items():
            self.__setitem__(k, v)
|
|
|
|
|
def build_log_buffer() -> LogDict:
    """
    Overview:
        Create an empty ``LogDict``: a ``dict`` subclass that stores ``torch.Tensor`` values as python lists, \
        which makes them convenient to log.
    Returns:
        - log_buffer (:obj:`LogDict`): A fresh, empty log buffer.
    Examples:
        >>> log_buffer = build_log_buffer()
        >>> log_buffer['not_tensor'] = torch.randn(3)
        >>> assert isinstance(log_buffer['not_tensor'], list)
        >>> assert len(log_buffer['not_tensor']) == 3
        >>> log_buffer.update({'not_tensor': 4, 'a': 5})
        >>> assert log_buffer['not_tensor'] == 4
    """
    log_buffer = LogDict()
    return log_buffer
|
|
|
|
|
class CudaFetcher(object):
    """
    Overview:
        Fetch data from source, and transfer it to a specified device. A background thread keeps a bounded \
        queue filled; once the source is exhausted, iteration stops with ``StopIteration`` instead of blocking.
    Interfaces:
        ``__init__``, ``__iter__``, ``__next__``, ``run``, ``close``.
    """

    # Sentinel put into the queue by the producer once ``data_source`` is exhausted.
    _END = object()

    def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None:
        """
        Overview:
            Initialize the CudaFetcher object using the given arguments.
        Arguments:
            - data_source (:obj:`Iterable`): The data source: an iterator, or any iterable (``iter`` is \
                called on it if it is not already an iterator).
            - device (:obj:`str`): The device to put data to, such as "cuda:0".
            - queue_size (:obj:`int`): The internal size of queue, such as 4.
            - sleep (:obj:`float`): Sleeping time when the internal queue is full.
        """
        # ``next`` needs an iterator; objects that already are iterators are used as-is.
        self._source = data_source if hasattr(data_source, '__next__') else iter(data_source)
        self._queue = Queue(maxsize=queue_size)
        self._stream = torch.cuda.Stream()
        self._producer_thread = Thread(target=self._producer, args=(), name='cuda_fetcher_producer')
        self._sleep = sleep
        self._device = device
        self._end_flag = False

    def __iter__(self) -> 'CudaFetcher':
        """
        Overview:
            Return ``self`` so the fetcher can be used directly in a ``for`` loop.
        """
        return self

    def __next__(self) -> Any:
        """
        Overview:
            Response to the request for data. Return one data item from the internal queue, blocking until \
            one is available.
        Returns:
            - item (:obj:`Any`): The data item on the required device.
        Raises:
            - StopIteration: Once the producer has exhausted ``data_source`` and all queued items are consumed.
        """
        item = self._queue.get()
        if item is self._END:
            # Re-queue the sentinel so every later call also stops instead of blocking forever.
            self._queue.put(item)
            raise StopIteration
        return item

    def run(self) -> None:
        """
        Overview:
            Start ``producer`` thread: Keep fetching data from source, change the device, and put into \
            ``queue`` for request.
        Examples:
            >>> dataloader = iter([torch.randn(3, 3) for _ in range(10)])
            >>> dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
            >>> dataloader.run()
            >>> data = next(dataloader)
        """
        self._end_flag = False
        self._producer_thread.start()

    def close(self) -> None:
        """
        Overview:
            Stop ``producer`` thread by setting ``end_flag`` to ``True``. The thread exits at its next loop check; \
            this call does not wait for it.
        """
        self._end_flag = True

    def _producer(self) -> None:
        """
        Overview:
            Keep fetching data from source, change the device, and put into ``queue`` for request. When the \
            source raises ``StopIteration``, an end-of-data sentinel is queued and the thread exits.
        """
        with torch.cuda.stream(self._stream):
            while not self._end_flag:
                if self._queue.full():
                    time.sleep(self._sleep)
                    continue
                try:
                    data = next(self._source)
                except StopIteration:
                    # This thread is the only producer and the queue was just observed not full, so this
                    # ``put`` does not block.
                    self._queue.put(self._END)
                    break
                data = to_device(data, self._device)
                self._queue.put(data)
|
|
|
|
|
def get_tensor_data(data: Any) -> Any:
    """
    Overview:
        Get pure tensor data from the given data (without disturbing grad computation graph). Every tensor is \
        detached and cloned, so the result neither requires grad nor shares memory with the input.
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict); \
            ``None`` is allowed anywhere.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type, including ``str``.
    Examples:
        >>> a = { \
                'tensor': torch.tensor([1, 2, 3.], requires_grad=True), \
                'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)], \
                'none': None \
            }
        >>> tensor_a = get_tensor_data(a)
        >>> assert not tensor_a['tensor'].requires_grad
        >>> for t in tensor_a['list']:
        >>>     assert not t.requires_grad
    """
    if isinstance(data, torch.Tensor):
        return data.detach().clone()
    elif data is None:
        return None
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # A ``str`` is a ``Sequence`` whose elements are ``str`` again, so recursing into it never terminates.
        return [get_tensor_data(d) for d in data]
    elif isinstance(data, dict):
        return {k: get_tensor_data(v) for k, v in data.items()}
    else:
        raise TypeError("not support type in get_tensor_data: {}".format(type(data)))
|
|
|
|
|
def unsqueeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Unsqueeze the tensor data, inserting a size-1 dimension at ``dim`` in every tensor of ``data``.
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
        - dim (:obj:`int`): The dimension to be unsqueezed; applied to every nested tensor.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type, including ``str``.

    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt.shape == torch.Size([1, 3, 3])

    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([1, 3, 3])

    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([1, 3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.unsqueeze(dim)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # A ``str`` is a ``Sequence`` whose elements are ``str`` again, so recursing into it never terminates.
        return [unsqueeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: unsqueeze(v, dim) for k, v in data.items()}
    else:
        raise TypeError("not support type in unsqueeze: {}".format(type(data)))
|
|
|
|
|
def squeeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Squeeze the tensor data, removing dimension ``dim`` from every tensor of ``data`` when it has size 1 \
        (``torch.Tensor.squeeze`` leaves the tensor unchanged otherwise).
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
        - dim (:obj:`int`): The dimension to be squeezed; applied to every nested tensor.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``.
    Raises:
        - TypeError: If ``data`` (or a nested element) is of an unsupported type, including ``str``.

    Examples (tensor):
        >>> t = torch.randn(1, 3, 3)
        >>> tt = squeeze(t, dim=0)
        >>> assert tt.shape == torch.Size([3, 3])

    Examples (list):
        >>> t = [torch.randn(1, 3, 3)]
        >>> tt = squeeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([3, 3])

    Examples (dict):
        >>> t = {"t": torch.randn(1, 3, 3)}
        >>> tt = squeeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.squeeze(dim)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # A ``str`` is a ``Sequence`` whose elements are ``str`` again, so recursing into it never terminates.
        return [squeeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: squeeze(v, dim) for k, v in data.items()}
    else:
        raise TypeError("not support type in squeeze: {}".format(type(data)))
|
|
|
|
|
def get_null_data(template: Any, num: int) -> List[Any]:
    """
    Overview:
        Produce ``num`` independent "null" copies of a template transition. Each copy is a deep copy of \
        ``template`` with ``'null'`` and ``'done'`` set to ``True`` and its ``'reward'`` zeroed in place.
    Arguments:
        - template (:obj:`Any`): The template data; a dict-like object whose ``'reward'`` value supports \
            ``zero_()`` (e.g. a ``torch.Tensor``).
        - num (:obj:`int`): How many null items to generate.
    Returns:
        - output (:obj:`List[Any]`): The generated null items; ``template`` itself is left untouched.

    Examples:
        >>> temp = {'obs': [1, 2, 3], 'action': 1, 'done': False, 'reward': torch.tensor(1.)}
        >>> null_data = get_null_data(temp, 2)
        >>> assert len(null_data) == 2
        >>> assert null_data[0]['null'] and null_data[0]['done']
    """

    def _make_null_item() -> Any:
        # Deep copy so neither the template nor previously generated items are affected.
        null_item = copy.deepcopy(template)
        null_item['null'] = True
        null_item['done'] = True
        null_item['reward'].zero_()
        return null_item

    return [_make_null_item() for _ in range(num)]
|
|
|
|
|
def zeros_like(h: Any) -> Any:
    """
    Overview:
        Build a structure mirroring ``h`` in which every tensor is replaced by ``torch.zeros_like`` of it.
    Arguments:
        - h (:obj:`Any`): The original data: a tensor, or a list/tuple/dict nesting tensors.
    Returns:
        - output (:obj:`Any`): The zero-filled structure. Lists and tuples both come back as ``list``.
    Raises:
        - TypeError: If ``h`` (or a nested element) is of an unsupported type.

    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = zeros_like(t)
        >>> assert tt.shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt)) < 1e-8

    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = zeros_like(t)
        >>> assert tt[0].shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt[0])) < 1e-8

    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = zeros_like(t)
        >>> assert tt["t"].shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt["t"])) < 1e-8
    """
    if isinstance(h, torch.Tensor):
        return torch.zeros_like(h)
    if isinstance(h, (list, tuple)):
        return list(map(zeros_like, h))
    if isinstance(h, dict):
        return {key: zeros_like(value) for key, value in h.items()}
    raise TypeError("not support type: {}".format(h))
|
|