gomoku / DI-engine /ding /data /shm_buffer.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
5.01 kB
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch
# Lookup table from numpy scalar types to the ctypes element types backing ShmBuffer storage.
# ShmBuffer reinterprets the shared ctypes array with ``np.frombuffer(..., dtype=<numpy type>)``,
# so the two types in each pair must have the same item size.
_NTYPE_TO_CTYPE = dict(
    (
        (np.bool_, ctypes.c_bool),
        # unsigned integers
        (np.uint8, ctypes.c_uint8),
        (np.uint16, ctypes.c_uint16),
        (np.uint32, ctypes.c_uint32),
        (np.uint64, ctypes.c_uint64),
        # signed integers
        (np.int8, ctypes.c_int8),
        (np.int16, ctypes.c_int16),
        (np.int32, ctypes.c_int32),
        (np.int64, ctypes.c_int64),
        # floating point (C float / C double)
        (np.float32, ctypes.c_float),
        (np.float64, ctypes.c_double),
    )
)
class ShmBuffer():
    """
    Overview:
        Shared memory buffer to store numpy array.
        Data lives in a ``multiprocessing.Array`` whose ctypes element type is looked up from ``dtype``;
        ``fill`` and ``get`` access it through a numpy view reshaped to ``shape``.
    """

    def __init__(
        self,
        dtype: Union[type, np.dtype],
        shape: Tuple[int],
        copy_on_get: bool = True,
        ctype: Optional[type] = None
    ) -> None:
        """
        Overview:
            Initialize the buffer.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. \
                Numpy scalar types (e.g. ``np.float32``), ``np.dtype`` instances (e.g. ``gym.spaces`` dtypes) \
                and Python builtin types (e.g. ``float``, ``bool``) are all normalized to a numpy scalar type.
            - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
            - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
        Raises:
            - KeyError: If the normalized dtype has no entry in ``_NTYPE_TO_CTYPE``.
        """
        if isinstance(dtype, (np.dtype, type)):
            # ``np.dtype(x).type`` is the identity for numpy scalar types (np.float32 -> np.float32),
            # unwraps ``np.dtype`` instances, and maps Python builtins onto numpy types
            # (float -> np.float64, bool -> np.bool_), so every accepted spelling becomes a table key.
            dtype = np.dtype(dtype).type
        if dtype not in _NTYPE_TO_CTYPE:
            raise KeyError("unsupported dtype for ShmBuffer: {}".format(dtype))
        # One flat ctypes array holding prod(shape) elements; reshaping happens on the numpy view.
        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape
        self.copy_on_get = copy_on_get
        self.ctype = ctype

    def _as_array(self) -> np.ndarray:
        """
        Overview:
            Return a numpy view (no copy) over the shared memory, reshaped to ``self.shape``.
        """
        # Rebuilt on every call rather than cached: an ndarray attribute would not survive pickling
        # as a view of the shared memory, so a copied object could silently stop sharing data.
        return np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)

    def fill(self, src_arr: np.ndarray) -> None:
        """
        Overview:
            Fill the shared memory buffer with a numpy array. (Replace the original one.)
        Arguments:
            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
        """
        assert isinstance(src_arr, np.ndarray), type(src_arr)
        # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
        # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
        # so we reshape the destination view rather than flatten src_arr
        np.copyto(self._as_array(), src_arr)

    def get(self) -> np.ndarray:
        """
        Overview:
            Get the array stored in the buffer.
        Return:
            - data (:obj:`np.ndarray`): The data stored in the buffer; a copy if ``copy_on_get`` is True, \
                otherwise a view sharing the buffer memory. Converted with ``torch.from_numpy`` when \
                ``ctype`` is ``torch.Tensor``.
        """
        data = self._as_array()
        if self.copy_on_get:
            data = data.copy()  # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
        if self.ctype is torch.Tensor:
            data = torch.from_numpy(data)
        return data
class ShmBufferContainer(object):
    """
    Overview:
        Support multiple shared memory buffers. Each key-value is name-buffer.
        A dict ``shape`` creates one nested container per key (recursively); a tuple/list ``shape``
        wraps a single ``ShmBuffer``.
    """

    def __init__(
        self,
        dtype: Union[Dict[Any, type], type, np.dtype],
        shape: Union[Dict[Any, tuple], tuple],
        copy_on_get: bool = True
    ) -> None:
        """
        Overview:
            Initialize the buffer container.
        Arguments:
            - dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data to limit the size \
                of the buffer. When ``shape`` is a dict, ``dtype`` may be a dict with the same keys (one dtype \
                per buffer) or a single dtype shared by every buffer.
            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
                multiple buffers; If `tuple`, use single buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
        Raises:
            - RuntimeError: If ``shape`` is neither a dict nor a tuple/list.
        """
        if isinstance(shape, dict):
            # Per-key dtype when a dict is given; otherwise the single dtype is shared by all keys.
            self._data = {
                k: ShmBufferContainer(dtype[k] if isinstance(dtype, dict) else dtype, v, copy_on_get)
                for k, v in shape.items()
            }
        elif isinstance(shape, (tuple, list)):
            self._data = ShmBuffer(dtype, shape, copy_on_get)
        else:
            raise RuntimeError("not support shape: {}".format(shape))
        self._shape = shape

    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
        """
        Overview:
            Fill the one or many shared memory buffer.
        Arguments:
            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. For a dict \
                container it must provide every key of the ``shape`` dict used at construction.
        """
        if isinstance(self._shape, dict):
            for k in self._shape.keys():
                self._data[k].fill(src_arr[k])
        elif isinstance(self._shape, (tuple, list)):
            self._data.fill(src_arr)

    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
        """
        Overview:
            Get the one or many arrays stored in the buffer.
        Return:
            - data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) stored in the buffer; \
                a dict with the same keys as the construction ``shape`` when that was a dict.
        """
        if isinstance(self._shape, dict):
            return {k: self._data[k].get() for k in self._shape.keys()}
        elif isinstance(self._shape, (tuple, list)):
            return self._data.get()