from typing import Optional, Union, List, Tuple, Dict
import math
import torch
import torch.nn as nn
import treetensor.torch as ttorch

import ding
from ding.torch_utils.network.normalization import build_normalization

if ding.enable_hpc_rl:
    from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
else:
    HPCLSTM = None


def is_sequence(data):
    """
    Overview:
        Determines if the input data is of type list or tuple.
    Arguments:
        - data: The input data to be checked.
    Returns:
        - boolean: True if the input is a list or a tuple, False otherwise.
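    Examples:
        >>> # Illustrative calls only; the values below are chosen for this doc, not taken from the original.
        >>> is_sequence([1, 2, 3])
        True
        >>> is_sequence((4, 5))
        True
        >>> is_sequence('abc')
        False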
    """
    return isinstance(data, list) or isinstance(data, tuple)


def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
    """
    Overview:
        Generates a boolean mask for a batch of sequences with differing lengths.
    Arguments:
        - lengths (:obj:`torch.Tensor`): A tensor with the lengths of each sequence. Shape could be (n, 1) or (n).
        - max_len (:obj:`int`, optional): The padding size. If max_len is None, the padding size is the max length of \
            sequences.
    Returns:
        - masks (:obj:`torch.BoolTensor`): A boolean mask tensor. The mask has the same device as lengths.
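    Examples:
        >>> # A minimal illustration of the mask layout; the values below are chosen for this doc, not from the original.
        >>> lengths = torch.LongTensor([2, 4, 1])
        >>> sequence_mask(lengths)
        tensor([[ True,  True, False, False],
                [ True,  True,  True,  True],
                [ True, False, False, False]])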
    """
    if len(lengths.shape) == 1:
        lengths = lengths.unsqueeze(dim=1)
    bz = lengths.numel()
    if max_len is None:
        max_len = lengths.max()
    else:
        max_len = min(max_len, lengths.max())
    return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device)


class LSTMForwardWrapper(object):
    """
    Overview:
        Class providing methods to use before and after the LSTM `forward` method.
        Wraps the LSTM `forward` method.
    Interfaces:
        ``_before_forward``, ``_after_forward``
    """

    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
        """
        Overview:
            Preprocesses the inputs and previous states before the LSTM `forward` method.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of the LSTM cell. Shape: [seq_len, batch_size, input_size]
            - prev_state (:obj:`Union[None, List[Dict]]`): Previous state tensor. Shape: [num_directions*num_layers, \
                batch_size, hidden_size]. If None, prev_state will be initialized to all zeros.
        Returns:
            - prev_state (:obj:`torch.Tensor`): Preprocessed previous state for the LSTM batch.
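        Examples:
            >>> # Shape-only sketch; the concrete sizes are illustrative, not taken from the original docs.
            >>> lstm = LSTM(input_size=4, hidden_size=8, num_layers=2, norm_type='LN')
            >>> inputs = torch.randn(5, 3, 4)
            >>> h, c = lstm._before_forward(inputs, None)
            >>> h.shape, c.shape
            (torch.Size([2, 3, 8]), torch.Size([2, 3, 8]))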
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        seq_len, batch_size = inputs.shape[:2]
        if prev_state is None:
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=inputs.dtype,
                device=inputs.device
            )
            prev_state = (zeros, zeros)
        elif is_sequence(prev_state):
            if len(prev_state) != batch_size:
                raise RuntimeError(
                    "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                )
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
            )
            state = []
            for prev in prev_state:
                if prev is None:
                    state.append([zeros, zeros])
                else:
                    if isinstance(prev, (Dict, ttorch.Tensor)):
                        state.append([v for v in prev.values()])
                    else:
                        state.append(prev)
            state = list(zip(*state))
            prev_state = [torch.cat(t, dim=1) for t in state]
        elif isinstance(prev_state, dict):
            prev_state = list(prev_state.values())
        else:
            raise TypeError("not support prev_state type: {}".format(type(prev_state)))
        return prev_state

    def _after_forward(self,
                       next_state: Tuple[torch.Tensor],
                       list_next_state: bool = False) -> Union[List[Dict], Dict[str, torch.Tensor]]:
        """
        Overview:
            Post-processes the next_state after the LSTM `forward` method.
        Arguments:
            - next_state (:obj:`Tuple[torch.Tensor]`): Tuple containing the next state (h, c).
            - list_next_state (:obj:`bool`, optional): Determines the format of the returned next_state. \
                If True, returns next_state in list format. Default is False.
        Returns:
            - next_state (:obj:`Union[List[Dict], Dict[str, torch.Tensor]]`): The post-processed next_state.
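        Examples:
            >>> # Shape-only sketch; the concrete sizes are illustrative, not taken from the original docs.
            >>> lstm = LSTM(input_size=4, hidden_size=8, num_layers=1, norm_type='LN')
            >>> h, c = torch.zeros(1, 2, 8), torch.zeros(1, 2, 8)
            >>> out = lstm._after_forward((h, c), list_next_state=True)
            >>> len(out), out[0]['h'].shape
            (2, torch.Size([1, 1, 8]))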
        """
        if list_next_state:
            h, c = next_state
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            next_state = list(zip(*next_state))
            next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
        else:
            next_state = {k: v for k, v in zip(['h', 'c'], next_state)}
        return next_state


class LSTM(nn.Module, LSTMForwardWrapper):
    """
    Overview:
        Implementation of an LSTM cell with Layer Normalization (LN).
    Interfaces:
        ``__init__``, ``forward``

    .. note::

        For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = None,
            dropout: float = 0.
    ) -> None:
        """
        Overview:
            Initialize LSTM cell parameters.
        Arguments:
            - input_size (:obj:`int`): Size of the input vector.
            - hidden_size (:obj:`int`): Size of the hidden state vector.
            - num_layers (:obj:`int`): Number of LSTM layers.
            - norm_type (:obj:`Optional[str]`): Normalization type, default is None.
            - dropout (:obj:`float`): Dropout rate, default is 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        norm_func = build_normalization(norm_type)
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        dims = [input_size] + [hidden_size] * num_layers
        for l in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _init(self):
        """
        Overview:
            Initialize the parameters of the LSTM cell.
        """

        gain = math.sqrt(1. / self.hidden_size)
        for l in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[l], -gain, gain)
            torch.nn.init.uniform_(self.wh[l], -gain, gain)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[l], -gain, gain)

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Compute output and next state given previous state and input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - x (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
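        Examples:
            >>> # Shape-only sketch; the concrete sizes are illustrative, not taken from the original docs.
            >>> lstm = LSTM(input_size=4, hidden_size=8, num_layers=2, norm_type='LN')
            >>> inputs = torch.randn(5, 3, 4)
            >>> x, next_state = lstm(inputs, None)
            >>> x.shape
            torch.Size([5, 3, 8])
            >>> len(next_state), next_state[0]['h'].shape
            (3, torch.Size([2, 1, 8]))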
        """
        seq_len, batch_size = inputs.shape[:2]
        prev_state = self._before_forward(inputs, prev_state)

        H, C = prev_state
        x = inputs
        next_state = []
        for l in range(self.num_layers):
            h, c = H[l], C[l]
            new_x = []
            for s in range(seq_len):
                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
                                        ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
                if self.bias is not None:
                    gate += self.bias[l]
                gate = list(torch.chunk(gate, 4, dim=1))
                i, f, o, u = gate
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                u = torch.tanh(u)
                c = f * c + i * u
                h = o * torch.tanh(c)
                new_x.append(h)
            next_state.append((h, c))
            x = torch.stack(new_x, dim=0)
            if self.use_dropout and l != self.num_layers - 1:
                x = self.dropout(x)
        next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]

        next_state = self._after_forward(next_state, list_next_state)
        return x, next_state


class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
    """
    Overview:
        Wrapper class for PyTorch's nn.LSTM that formats the input and output. For more details on nn.LSTM,
        refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
    Interfaces:
        ``forward``
    """

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        """
        Overview:
            Executes nn.LSTM.forward with preprocessed input.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
            - prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, \
                hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
        Returns:
            - output (:obj:`torch.Tensor`): Output from LSTM.
            - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
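        Examples:
            >>> # Shape-only sketch; the concrete sizes are illustrative, not taken from the original docs.
            >>> lstm = PytorchLSTM(input_size=4, hidden_size=8, num_layers=2)
            >>> inputs = torch.randn(5, 3, 4)
            >>> output, next_state = lstm(inputs, None)
            >>> output.shape
            torch.Size([5, 3, 8])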
        """
        prev_state = self._before_forward(inputs, prev_state)
        output, next_state = nn.LSTM.forward(self, inputs, prev_state)
        next_state = self._after_forward(next_state, list_next_state)
        return output, next_state


class GRU(nn.GRUCell, LSTMForwardWrapper):
    """
    Overview:
        This class extends the ``torch.nn.GRUCell`` and ``LSTMForwardWrapper`` classes, and formats inputs and
        outputs accordingly.
    Interfaces:
        ``__init__``, ``forward``
    Properties:
        hidden_size, num_layers

    .. note::
        For further details, refer to the official PyTorch documentation:
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU>
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
        """
        Overview:
            Initialize the GRU class with input size, hidden size, and number of layers.
        Arguments:
            - input_size (:obj:`int`): The size of the input vector.
            - hidden_size (:obj:`int`): The size of the hidden state vector.
            - num_layers (:obj:`int`): The number of GRU layers.
        """
        super(GRU, self).__init__(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self,
                inputs: torch.Tensor,
                prev_state: Optional[torch.Tensor] = None,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, List]]:
        """
        Overview:
            Wrap the ``nn.GRUCell.forward`` method.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
            - prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of \
                size [num_directions*num_layers, batch_size, hidden_size].
            - list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).
        Returns:
            - output (:obj:`torch.Tensor`): Output from GRU.
            - next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU.
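        Examples:
            >>> # Shape-only sketch with a single-step (seq_len=1) input; the sizes are illustrative.
            >>> gru = GRU(input_size=4, hidden_size=8, num_layers=1)
            >>> inputs = torch.randn(1, 3, 4)
            >>> output, next_state = gru(inputs, None)
            >>> output.shape
            torch.Size([1, 3, 8])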
        """
        prev_state, _ = self._before_forward(inputs, prev_state)
        inputs, prev_state = inputs.squeeze(0), prev_state.squeeze(0)
        next_state = nn.GRUCell.forward(self, inputs, prev_state)
        next_state = next_state.unsqueeze(0)
        x = next_state

        next_state = self._after_forward([next_state, next_state.clone()], list_next_state)
        return x, next_state


def get_lstm(
        lstm_type: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        norm_type: str = 'LN',
        dropout: float = 0.,
        seq_len: Optional[int] = None,
        batch_size: Optional[int] = None
) -> Union[LSTM, PytorchLSTM]:
    """
    Overview:
        Build and return the corresponding LSTM cell based on the provided parameters.
    Arguments:
        - lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
        - input_size (:obj:`int`): Size of the input vector.
        - hidden_size (:obj:`int`): Size of the hidden state vector.
        - num_layers (:obj:`int`): Number of LSTM layers (default is 1).
        - norm_type (:obj:`str`): Type of normalization (default is 'LN').
        - dropout (:obj:`float`): Dropout rate (default is 0.0).
        - seq_len (:obj:`Optional[int]`): Sequence length (default is None).
        - batch_size (:obj:`Optional[int]`): Batch size (default is None).
    Returns:
        - lstm (:obj:`Union[LSTM, PytorchLSTM]`): The corresponding LSTM cell.
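    Examples:
        >>> # Shape-only sketch; the concrete sizes are illustrative, not taken from the original docs.
        >>> lstm = get_lstm('normal', input_size=32, hidden_size=64, num_layers=2)
        >>> inputs = torch.randn(4, 8, 32)
        >>> output, next_state = lstm(inputs, None)
        >>> output.shape
        torch.Size([4, 8, 64])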
    """
    assert lstm_type in ['normal', 'pytorch', 'hpc', 'gru']
    if lstm_type == 'normal':
        return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
    elif lstm_type == 'pytorch':
        return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
    elif lstm_type == 'hpc':
        return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
    elif lstm_type == 'gru':
        assert num_layers == 1
        return GRU(input_size, hidden_size, num_layers)