from typing import Iterable, Any, Optional, List
from collections.abc import Sequence
import numbers
import time
import copy
from threading import Thread
from queue import Queue
import numpy as np
import torch
import treetensor.torch as ttorch
from ding.utils.default_helper import get_shape0
def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
"""
Overview:
        Transfer data to a certain device.
Arguments:
- item (:obj:`Any`): The item to be transferred.
        - device (:obj:`str`): The target device.
        - ignore_keys (:obj:`list`): The keys to be ignored during transfer; defaults to an empty list.
Returns:
- item (:obj:`Any`): The transferred item.
Examples:
>>> setup_data_dict['module'] = nn.Linear(3, 5)
>>> device = 'cuda'
>>> cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
>>> assert cuda_d['module'].weight.device == torch.device('cpu')
Examples:
>>> setup_data_dict['module'] = nn.Linear(3, 5)
>>> device = 'cuda'
>>> cuda_d = to_device(setup_data_dict, device)
>>> assert cuda_d['module'].weight.device == torch.device('cuda:0')
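    Examples (nested container, an illustrative sketch assuming a CUDA device is available):
        >>> # hypothetical nested input, used only for illustration
        >>> nested = {'obs': [torch.randn(4) for _ in range(2)], 'step': 3}
        >>> cuda_nested = to_device(nested, 'cuda')
        >>> assert all(t.device.type == 'cuda' for t in cuda_nested['obs'])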
    .. note::
Now supports item type: :obj:`torch.nn.Module`, :obj:`torch.Tensor`, :obj:`Sequence`, \
:obj:`dict`, :obj:`numbers.Integral`, :obj:`numbers.Real`, :obj:`np.ndarray`, :obj:`str` and :obj:`None`.
"""
if isinstance(item, torch.nn.Module):
return item.to(device)
elif isinstance(item, ttorch.Tensor):
if 'prev_state' in item:
prev_state = to_device(item.prev_state, device)
del item.prev_state
item = item.to(device)
item.prev_state = prev_state
return item
else:
return item.to(device)
elif isinstance(item, torch.Tensor):
return item.to(device)
elif isinstance(item, Sequence):
if isinstance(item, str):
return item
else:
return [to_device(t, device) for t in item]
elif isinstance(item, dict):
new_item = {}
for k in item.keys():
if k in ignore_keys:
new_item[k] = item[k]
else:
new_item[k] = to_device(item[k], device)
return new_item
elif isinstance(item, numbers.Integral) or isinstance(item, numbers.Real):
return item
elif isinstance(item, np.ndarray) or isinstance(item, np.bool_):
return item
elif item is None or isinstance(item, str):
return item
elif isinstance(item, torch.distributions.Distribution): # for compatibility
return item
else:
raise TypeError("not support item type: {}".format(type(item)))
def to_dtype(item: Any, dtype: type) -> Any:
"""
Overview:
        Change data to a certain dtype.
Arguments:
- item (:obj:`Any`): The item for changing the dtype.
        - dtype (:obj:`type`): The target dtype.
Returns:
- item (:obj:`object`): The item with changed dtype.
Examples (tensor):
>>> t = torch.randint(0, 10, (3, 5))
>>> tfloat = to_dtype(t, torch.float)
>>> assert tfloat.dtype == torch.float
Examples (list):
>>> tlist = [torch.randint(0, 10, (3, 5))]
>>> tlfloat = to_dtype(tlist, torch.float)
>>> assert tlfloat[0].dtype == torch.float
Examples (dict):
>>> tdict = {'t': torch.randint(0, 10, (3, 5))}
>>> tdictf = to_dtype(tdict, torch.float)
>>> assert tdictf['t'].dtype == torch.float
    .. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`Sequence`, :obj:`dict`.
"""
if isinstance(item, torch.Tensor):
return item.to(dtype=dtype)
elif isinstance(item, Sequence):
return [to_dtype(t, dtype) for t in item]
elif isinstance(item, dict):
return {k: to_dtype(item[k], dtype) for k in item.keys()}
else:
raise TypeError("not support item type: {}".format(type(item)))
def to_tensor(
item: Any, dtype: Optional[torch.dtype] = None, ignore_keys: list = [], transform_scalar: bool = True
) -> Any:
"""
Overview:
Convert ``numpy.ndarray`` object to ``torch.Tensor``.
Arguments:
- item (:obj:`Any`): The ``numpy.ndarray`` objects to be converted. It can be exactly a ``numpy.ndarray`` \
object or a container (list, tuple or dict) that contains several ``numpy.ndarray`` objects.
        - dtype (:obj:`torch.dtype`): The desired dtype of the output tensor. If set to ``None``, the dtype is unchanged.
- ignore_keys (:obj:`list`): If the ``item`` is a dict, values whose keys are in ``ignore_keys`` will not \
be converted.
        - transform_scalar (:obj:`bool`): If set to ``True``, a scalar will also be converted to a tensor object.
Returns:
- item (:obj:`Any`): The converted tensors.
Examples (scalar):
>>> i = 10
>>> t = to_tensor(i)
>>> assert t.item() == i
Examples (dict):
>>> d = {'i': i}
>>> dt = to_tensor(d, torch.int)
>>> assert dt['i'].item() == i
Examples (named tuple):
>>> data_type = namedtuple('data_type', ['x', 'y'])
>>> inputs = data_type(np.random.random(3), 4)
>>> outputs = to_tensor(inputs, torch.float32)
>>> assert type(outputs) == data_type
>>> assert isinstance(outputs.x, torch.Tensor)
>>> assert isinstance(outputs.y, torch.Tensor)
>>> assert outputs.x.dtype == torch.float32
>>> assert outputs.y.dtype == torch.float32
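    Examples (ndarray with ignore_keys, an illustrative sketch with hypothetical key names):
        >>> d = {'obs': np.zeros((4, ), dtype=np.float32), 'raw': np.zeros((2, ))}
        >>> dt = to_tensor(d, ignore_keys=['raw'])
        >>> assert isinstance(dt['obs'], torch.Tensor)
        >>> assert isinstance(dt['raw'], np.ndarray)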
    .. note::
        Now supports item type: :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple`, \
        :obj:`bool`, :obj:`str`, scalar types and :obj:`None`.
"""
def transform(d):
if dtype is None:
return torch.as_tensor(d)
else:
return torch.tensor(d, dtype=dtype)
if isinstance(item, dict):
new_data = {}
for k, v in item.items():
if k in ignore_keys:
new_data[k] = v
else:
new_data[k] = to_tensor(v, dtype, ignore_keys, transform_scalar)
return new_data
elif isinstance(item, list) or isinstance(item, tuple):
if len(item) == 0:
return []
elif isinstance(item[0], numbers.Integral) or isinstance(item[0], numbers.Real):
return transform(item)
elif hasattr(item, '_fields'): # namedtuple
            return type(item)(*[to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item])
else:
new_data = []
for t in item:
new_data.append(to_tensor(t, dtype, ignore_keys, transform_scalar))
return new_data
elif isinstance(item, np.ndarray):
if dtype is None:
if item.dtype == np.float64:
return torch.FloatTensor(item)
else:
return torch.from_numpy(item)
else:
return torch.from_numpy(item).to(dtype)
elif isinstance(item, bool) or isinstance(item, str):
return item
elif np.isscalar(item):
if transform_scalar:
if dtype is None:
return torch.as_tensor(item)
else:
return torch.as_tensor(item).to(dtype)
else:
return item
elif item is None:
return None
elif isinstance(item, torch.Tensor):
if dtype is None:
return item
else:
return item.to(dtype)
else:
raise TypeError("not support item type: {}".format(type(item)))
def to_ndarray(item: Any, dtype: np.dtype = None) -> Any:
"""
Overview:
Convert ``torch.Tensor`` to ``numpy.ndarray``.
Arguments:
- item (:obj:`Any`): The ``torch.Tensor`` objects to be converted. It can be exactly a ``torch.Tensor`` \
object or a container (list, tuple or dict) that contains several ``torch.Tensor`` objects.
        - dtype (:obj:`np.dtype`): The desired dtype of the output array. If set to ``None``, the dtype is unchanged.
Returns:
- item (:obj:`object`): The changed arrays.
Examples (ndarray):
>>> t = torch.randn(3, 5)
>>> tarray1 = to_ndarray(t)
>>> assert tarray1.shape == (3, 5)
>>> assert isinstance(tarray1, np.ndarray)
Examples (list):
>>> t = [torch.randn(5, ) for i in range(3)]
>>> tarray1 = to_ndarray(t, np.float32)
>>> assert isinstance(tarray1, list)
>>> assert tarray1[0].shape == (5, )
>>> assert isinstance(tarray1[0], np.ndarray)
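    Examples (dict, an illustrative sketch):
        >>> td = {'t': torch.randn(3, 5)}
        >>> tdarray = to_ndarray(td)
        >>> assert isinstance(tdarray['t'], np.ndarray)
        >>> assert tdarray['t'].shape == (3, 5)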
    .. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
"""
def transform(d):
if dtype is None:
return np.array(d)
else:
return np.array(d, dtype=dtype)
if isinstance(item, dict):
new_data = {}
for k, v in item.items():
new_data[k] = to_ndarray(v, dtype)
return new_data
elif isinstance(item, list) or isinstance(item, tuple):
if len(item) == 0:
return None
elif isinstance(item[0], numbers.Integral) or isinstance(item[0], numbers.Real):
return transform(item)
elif hasattr(item, '_fields'): # namedtuple
return type(item)(*[to_ndarray(t, dtype) for t in item])
else:
new_data = []
for t in item:
new_data.append(to_ndarray(t, dtype))
return new_data
elif isinstance(item, torch.Tensor):
if dtype is None:
return item.numpy()
else:
return item.numpy().astype(dtype)
elif isinstance(item, np.ndarray):
if dtype is None:
return item
else:
return item.astype(dtype)
elif isinstance(item, bool) or isinstance(item, str):
return item
elif np.isscalar(item):
if dtype is None:
return np.array(item)
else:
return np.array(item, dtype=dtype)
elif item is None:
return None
else:
raise TypeError("not support item type: {}".format(type(item)))
def to_list(item: Any) -> Any:
"""
Overview:
Convert ``torch.Tensor``, ``numpy.ndarray`` objects to ``list`` objects, and keep their dtypes unchanged.
Arguments:
- item (:obj:`Any`): The item to be converted.
Returns:
- item (:obj:`Any`): The list after conversion.
Examples:
>>> data = { \
'tensor': torch.randn(4), \
'list': [True, False, False], \
'tuple': (4, 5, 6), \
'bool': True, \
'int': 10, \
'float': 10., \
'array': np.random.randn(4), \
'str': "asdf", \
'none': None, \
} \
>>> transformed_data = to_list(data)
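    Examples (illustrative checks on the converted container above):
        >>> assert isinstance(transformed_data['tensor'], list)
        >>> assert transformed_data['tuple'] == [4, 5, 6]
        >>> assert transformed_data['none'] is None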
.. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`, \
:obj:`tuple` and :obj:`None`.
"""
if item is None:
return item
elif isinstance(item, torch.Tensor):
return item.tolist()
elif isinstance(item, np.ndarray):
return item.tolist()
elif isinstance(item, list) or isinstance(item, tuple):
return [to_list(t) for t in item]
elif isinstance(item, dict):
return {k: to_list(v) for k, v in item.items()}
elif np.isscalar(item):
return item
else:
raise TypeError("not support item type: {}".format(type(item)))
def tensor_to_list(item: Any) -> Any:
"""
Overview:
Convert ``torch.Tensor`` objects to ``list``, and keep their dtypes unchanged.
Arguments:
- item (:obj:`Any`): The item to be converted.
Returns:
- item (:obj:`Any`): The lists after conversion.
Examples (2d-tensor):
>>> t = torch.randn(3, 5)
>>> tlist1 = tensor_to_list(t)
>>> assert len(tlist1) == 3
>>> assert len(tlist1[0]) == 5
Examples (1d-tensor):
>>> t = torch.randn(3, )
>>> tlist1 = tensor_to_list(t)
>>> assert len(tlist1) == 3
    Examples (list):
>>> t = [torch.randn(5, ) for i in range(3)]
>>> tlist1 = tensor_to_list(t)
>>> assert len(tlist1) == 3
>>> assert len(tlist1[0]) == 5
Examples (dict):
>>> td = {'t': torch.randn(3, 5)}
>>> tdlist1 = tensor_to_list(td)
>>> assert len(tdlist1['t']) == 3
>>> assert len(tdlist1['t'][0]) == 5
.. note::
Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
"""
if item is None:
return item
elif isinstance(item, torch.Tensor):
return item.tolist()
elif isinstance(item, list) or isinstance(item, tuple):
return [tensor_to_list(t) for t in item]
elif isinstance(item, dict):
return {k: tensor_to_list(v) for k, v in item.items()}
elif np.isscalar(item):
return item
else:
raise TypeError("not support item type: {}".format(type(item)))
def to_item(data: Any, ignore_error: bool = True) -> Any:
"""
Overview:
Convert data to python native scalar (i.e. data item), and keep their dtypes unchanged.
Arguments:
- data (:obj:`Any`): The data that needs to be converted.
        - ignore_error (:obj:`bool`): Whether to ignore errors when a data type is not supported. That is to say, \
            only data that can be transformed into a python native scalar will be returned.
Returns:
- data (:obj:`Any`): Converted data.
Examples:
        >>> data = { \
'tensor': torch.randn(1), \
'list': [True, False, torch.randn(1)], \
'tuple': (4, 5, 6), \
'bool': True, \
'int': 10, \
'float': 10., \
'array': np.random.randn(1), \
'str': "asdf", \
'none': None, \
}
        >>> new_data = to_item(data)
        >>> assert np.isscalar(new_data['tensor'])
        >>> assert np.isscalar(new_data['array'])
        >>> assert np.isscalar(new_data['list'][-1])
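    Examples (ignore_error, an illustrative sketch with hypothetical key names):
        >>> mixed = {'scalar': torch.tensor(1.), 'vector': torch.randn(3)}
        >>> converted = to_item(mixed)
        >>> # the multi-element 'vector' cannot become a python scalar, so it is silently skipped
        >>> assert 'vector' not in converted and converted['scalar'] == 1.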
.. note::
        Now supports item type: :obj:`torch.Tensor`, :obj:`np.ndarray`, :obj:`ttorch.Tensor`, \
        :obj:`bool`, :obj:`str`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
"""
if data is None:
return data
elif isinstance(data, bool) or isinstance(data, str):
return data
elif np.isscalar(data):
return data
elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor) or isinstance(data, ttorch.Tensor):
return data.item()
    elif isinstance(data, list) or isinstance(data, tuple):
        return [to_item(d, ignore_error) for d in data]
    elif isinstance(data, dict):
        new_data = {}
        for k, v in data.items():
            if ignore_error:
                try:
                    new_data[k] = to_item(v, ignore_error)
                except (ValueError, RuntimeError):
                    pass
            else:
                new_data[k] = to_item(v, ignore_error)
        return new_data
else:
        raise TypeError("not support data type: {}".format(type(data)))
def same_shape(data: list) -> bool:
"""
Overview:
        Judge whether all data elements in a list have the same shape.
Arguments:
- data (:obj:`list`): The list of data.
Returns:
- same (:obj:`bool`): Whether the list of data all have the same shape.
Examples:
>>> tlist = [torch.randn(3, 5) for i in range(5)]
>>> assert same_shape(tlist)
>>> tlist = [torch.randn(3, 5), torch.randn(4, 5)]
>>> assert not same_shape(tlist)
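    Examples (ndarray, an illustrative sketch):
        >>> alist = [np.random.randn(3, 5) for i in range(5)]
        >>> assert same_shape(alist)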
"""
assert (isinstance(data, list))
shapes = [t.shape for t in data]
return len(set(shapes)) == 1
class LogDict(dict):
"""
Overview:
        Derived from ``dict``. It converts ``torch.Tensor`` values to ``list`` for convenient logging.
Interfaces:
``_transform``, ``__setitem__``, ``update``.
"""
    def _transform(self, data: Any) -> Any:
"""
Overview:
Convert tensor objects to lists for better logging.
Arguments:
            - data (:obj:`Any`): The input data to be converted.
        Returns:
            - new_data (:obj:`Any`): The converted data.
        """
if isinstance(data, torch.Tensor):
new_data = data.tolist()
else:
new_data = data
return new_data
def __setitem__(self, key: Any, value: Any) -> None:
"""
Overview:
Override the ``__setitem__`` function of built-in dict.
Arguments:
- key (:obj:`Any`): The key of the data item.
- value (:obj:`Any`): The value of the data item.
"""
new_value = self._transform(value)
super().__setitem__(key, new_value)
def update(self, data: dict) -> None:
"""
Overview:
Override the ``update`` function of built-in dict.
Arguments:
- data (:obj:`dict`): The dict for updating current object.
"""
for k, v in data.items():
self.__setitem__(k, v)
def build_log_buffer() -> LogDict:
"""
Overview:
Build log buffer, a subclass of dict, which can convert the input data into log format.
Returns:
- log_buffer (:obj:`LogDict`): Log buffer dict.
Examples:
>>> log_buffer = build_log_buffer()
>>> log_buffer['not_tensor'] = torch.randn(3)
>>> assert isinstance(log_buffer['not_tensor'], list)
>>> assert len(log_buffer['not_tensor']) == 3
>>> log_buffer.update({'not_tensor': 4, 'a': 5})
>>> assert log_buffer['not_tensor'] == 4
"""
return LogDict()
class CudaFetcher(object):
"""
Overview:
Fetch data from source, and transfer it to a specified device.
Interfaces:
``__init__``, ``__next__``, ``run``, ``close``.
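    Examples:
        >>> # an illustrative end-to-end sketch, assuming a CUDA device is available
        >>> source = iter([torch.randn(3, 3) for _ in range(10)])
        >>> fetcher = CudaFetcher(source, device='cuda', queue_size=4, sleep=0.1)
        >>> fetcher.run()
        >>> batch = next(fetcher)
        >>> assert batch.device.type == 'cuda'
        >>> fetcher.close()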
"""
def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None:
"""
Overview:
Initialize the CudaFetcher object using the given arguments.
Arguments:
- data_source (:obj:`Iterable`): The iterable data source.
- device (:obj:`str`): The device to put data to, such as "cuda:0".
- queue_size (:obj:`int`): The internal size of queue, such as 4.
- sleep (:obj:`float`): Sleeping time when the internal queue is full.
"""
self._source = data_source
self._queue = Queue(maxsize=queue_size)
self._stream = torch.cuda.Stream()
self._producer_thread = Thread(target=self._producer, args=(), name='cuda_fetcher_producer')
self._sleep = sleep
        self._device = device
        # the fetcher is considered closed until ``run`` is called, so ``close`` is safe to call at any time
        self._end_flag = True
def __next__(self) -> Any:
"""
Overview:
            Respond to a data request by returning one data item from the internal queue.
Returns:
- item (:obj:`Any`): The data item on the required device.
"""
return self._queue.get()
def run(self) -> None:
"""
Overview:
Start ``producer`` thread: Keep fetching data from source, change the device, and put into \
``queue`` for request.
Examples:
>>> timer = EasyTimer()
>>> dataloader = iter([torch.randn(3, 3) for _ in range(10)])
>>> dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
>>> dataloader.run()
>>> data = next(dataloader)
"""
self._end_flag = False
self._producer_thread.start()
def close(self) -> None:
"""
Overview:
            Stop the ``producer`` thread by setting ``end_flag`` to ``True``.
"""
self._end_flag = True
def _producer(self) -> None:
"""
Overview:
Keep fetching data from source, change the device, and put into ``queue`` for request.
"""
with torch.cuda.stream(self._stream):
while not self._end_flag:
if self._queue.full():
time.sleep(self._sleep)
else:
data = next(self._source)
data = to_device(data, self._device)
self._queue.put(data)
def get_tensor_data(data: Any) -> Any:
"""
Overview:
Get pure tensor data from the given data (without disturbing grad computation graph).
Arguments:
- data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
Returns:
- output (:obj:`Any`): The output data.
Examples:
>>> a = { \
'tensor': torch.tensor([1, 2, 3.], requires_grad=True), \
'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)], \
'none': None \
}
>>> tensor_a = get_tensor_data(a)
>>> assert not tensor_a['tensor'].requires_grad
>>> for t in tensor_a['list']:
>>> assert not t.requires_grad
"""
if isinstance(data, torch.Tensor):
return data.data.clone()
elif data is None:
return None
elif isinstance(data, Sequence):
return [get_tensor_data(d) for d in data]
elif isinstance(data, dict):
return {k: get_tensor_data(v) for k, v in data.items()}
else:
raise TypeError("not support type in get_tensor_data: {}".format(type(data)))
def unsqueeze(data: Any, dim: int = 0) -> Any:
"""
Overview:
Unsqueeze the tensor data.
Arguments:
- data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
- dim (:obj:`int`): The dimension to be unsqueezed.
Returns:
- output (:obj:`Any`): The output data.
Examples (tensor):
>>> t = torch.randn(3, 3)
>>> tt = unsqueeze(t, dim=0)
        >>> assert tt.shape == torch.Size([1, 3, 3])
Examples (list):
>>> t = [torch.randn(3, 3)]
>>> tt = unsqueeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([1, 3, 3])
Examples (dict):
>>> t = {"t": torch.randn(3, 3)}
>>> tt = unsqueeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([1, 3, 3])
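    Examples (non-default dim, an illustrative sketch):
        >>> t = torch.randn(3, 3)
        >>> tt = unsqueeze(t, dim=1)
        >>> assert tt.shape == torch.Size([3, 1, 3])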
"""
if isinstance(data, torch.Tensor):
return data.unsqueeze(dim)
elif isinstance(data, Sequence):
        return [unsqueeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: unsqueeze(v, dim) for k, v in data.items()}
else:
raise TypeError("not support type in unsqueeze: {}".format(type(data)))
def squeeze(data: Any, dim: int = 0) -> Any:
"""
Overview:
Squeeze the tensor data.
Arguments:
- data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
        - dim (:obj:`int`): The dimension to be squeezed.
Returns:
- output (:obj:`Any`): The output data.
Examples (tensor):
>>> t = torch.randn(1, 3, 3)
>>> tt = squeeze(t, dim=0)
        >>> assert tt.shape == torch.Size([3, 3])
Examples (list):
>>> t = [torch.randn(1, 3, 3)]
>>> tt = squeeze(t, dim=0)
        >>> assert tt[0].shape == torch.Size([3, 3])
Examples (dict):
>>> t = {"t": torch.randn(1, 3, 3)}
>>> tt = squeeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([3, 3])
"""
if isinstance(data, torch.Tensor):
return data.squeeze(dim)
elif isinstance(data, Sequence):
        return [squeeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: squeeze(v, dim) for k, v in data.items()}
else:
raise TypeError("not support type in squeeze: {}".format(type(data)))
def get_null_data(template: Any, num: int) -> List[Any]:
"""
Overview:
Get null data given an input template.
Arguments:
- template (:obj:`Any`): The template data.
- num (:obj:`int`): The number of null data items to generate.
Returns:
- output (:obj:`List[Any]`): The generated null data.
Examples:
>>> temp = {'obs': [1, 2, 3], 'action': 1, 'done': False, 'reward': torch.tensor(1.)}
>>> null_data = get_null_data(temp, 2)
        >>> assert len(null_data) == 2
>>> assert null_data[0]['null'] and null_data[0]['done']
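        >>> # illustrative follow-up: each copy's reward tensor is zeroed in place
        >>> assert null_data[1]['reward'].item() == 0.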
"""
ret = []
for _ in range(num):
data = copy.deepcopy(template)
data['null'] = True
data['done'] = True
data['reward'].zero_()
ret.append(data)
return ret
def zeros_like(h: Any) -> Any:
"""
Overview:
Generate zero-tensors like the input data.
Arguments:
- h (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
Returns:
- output (:obj:`Any`): The output zero-tensors.
Examples (tensor):
>>> t = torch.randn(3, 3)
>>> tt = zeros_like(t)
        >>> assert tt.shape == torch.Size([3, 3])
>>> assert torch.sum(torch.abs(tt)) < 1e-8
Examples (list):
>>> t = [torch.randn(3, 3)]
>>> tt = zeros_like(t)
        >>> assert tt[0].shape == torch.Size([3, 3])
>>> assert torch.sum(torch.abs(tt[0])) < 1e-8
Examples (dict):
>>> t = {"t": torch.randn(3, 3)}
>>> tt = zeros_like(t)
        >>> assert tt["t"].shape == torch.Size([3, 3])
>>> assert torch.sum(torch.abs(tt["t"])) < 1e-8
"""
if isinstance(h, torch.Tensor):
return torch.zeros_like(h)
elif isinstance(h, (list, tuple)):
return [zeros_like(t) for t in h]
elif isinstance(h, dict):
return {k: zeros_like(v) for k, v in h.items()}
else:
        raise TypeError("not support type: {}".format(type(h)))