from typing import Union, Optional, Dict

from easydict import EasyDict
import torch
import torch.nn as nn

from ding.torch_utils import get_lstm
from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, RegressionHead

@MODEL_REGISTRY.register('pdqn')
class PDQN(nn.Module):
    """
    Overview:
        The neural network and computation graph of the PDQN (https://arxiv.org/abs/1810.06394v1) and \
        MPDQN (https://arxiv.org/abs/1905.04388) algorithms for parameterized action spaces. \
        This model supports a parameterized action space with a discrete ``action_type`` and continuous \
        ``action_args``. In principle, PDQN consists of an x network (continuous action parameter network) \
        and a Q network (discrete action type network). For simplicity, the code is instead organized into \
        ``encoder`` and ``actor_head`` modules, which contain the encoders and heads of the two networks \
        respectively.
    Interface:
        ``__init__``, ``forward``, ``compute_discrete``, ``compute_continuous``.
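    Examples:
        A minimal two-step inference sketch (the shapes below are illustrative assumptions, not requirements):

        >>> act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 5})
        >>> model = PDQN(4, act_shape)
        >>> obs = torch.randn(8, 4)
        >>> # x network: predict continuous action args from the observation
        >>> action_args = model.forward(obs, mode='compute_continuous')['action_args']
        >>> # Q network: score each discrete action type given the observation and the predicted args
        >>> outputs = model.forward({'state': obs, 'action_args': action_args}, mode='compute_discrete')
        >>> assert outputs['logit'].shape == torch.Size([8, 3])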
    """
    mode = ['compute_discrete', 'compute_continuous']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: EasyDict,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            dueling: bool = True,
            head_hidden_size: Optional[int] = None,
            head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            multi_pass: Optional[bool] = False,
            action_mask: Optional[list] = None
    ) -> None:
""" |
|
Overview: |
|
Init the PDQN (encoder + head) Model according to input arguments. |
|
Arguments: |
|
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. |
|
- action_shape (:obj:`EasyDict`): Action space shape in dict type, such as \ |
|
EasyDict({'action_type_shape': 3, 'action_args_shape': 5}). |
|
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \ |
|
the last element must match ``head_hidden_size``. |
|
- dueling (:obj:`dueling`): Whether choose ``DuelingHead`` or ``DiscreteHead(default)``. |
|
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. |
|
- head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output. |
|
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ |
|
if ``None`` then default set it to ``nn.ReLU()``. |
|
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ |
|
``ding.torch_utils.fc_block`` for more details. |
|
- multi_pass (:obj:`Optional[bool]`): Whether to use multi pass version. |
|
- action_mask: (:obj:`Optional[list]`): An action mask indicating how action args are \ |
|
associated to each discrete action. For example, if there are 3 discrete action, \ |
|
4 continous action args, and the first discrete action associates with the first \ |
|
continuous action args, the second discrete action associates with the second continuous \ |
|
action args, and the third discrete action associates with the remaining 2 action args, \ |
|
the action mask will be like: [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4. |
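        Examples:
            A construction sketch; the shapes and the action mask below are illustrative assumptions:

            >>> act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 4})
            >>> # plain PDQN
            >>> model = PDQN(obs_shape=8, action_shape=act_shape)
            >>> # multi-pass (MPDQN) version with an explicit 3*4 action mask
            >>> mp_model = PDQN(
            ...     obs_shape=8,
            ...     action_shape=EasyDict({'action_type_shape': 3, 'action_args_shape': 4}),
            ...     multi_pass=True,
            ...     action_mask=[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
            ... )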
        """
        super(PDQN, self).__init__()
        self.multi_pass = multi_pass
        if self.multi_pass:
            assert isinstance(
                action_mask, list
            ), 'Please indicate action mask in list form if you set multi_pass to True'
            self.action_mask = torch.LongTensor(action_mask)
            nonzero = torch.nonzero(self.action_mask)
            # For each continuous action arg, record the index of the discrete action it is associated with;
            # this index is later used in ``compute_discrete`` to scatter the args into a per-action-type tensor.
            index = torch.zeros(action_shape.action_args_shape).long()
            index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
            self.action_scatter_index = index

        # Squeeze action shapes, e.g. from (3, ) to 3.
        action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
        action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
        self.action_args_shape = action_shape.action_args_shape
        self.action_type_shape = action_shape.action_type_shape

        if head_hidden_size is None:
            head_hidden_size = encoder_hidden_size_list[-1]

        obs_shape = squeeze(obs_shape)

        # Encoder: transform the raw observation into a hidden state embedding.
        # FC Encoder
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            self.dis_encoder = FCEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
            self.cont_encoder = FCEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        # Conv Encoder
        elif len(obs_shape) == 3:
            self.dis_encoder = ConvEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
            self.cont_encoder = ConvEncoder(
                obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
            )
        else:
            raise RuntimeError(
                "Pre-defined encoder does not support obs_shape {}, please customize your own PDQN.".format(obs_shape)
            )

        # Continuous action args head (x network head): regress the action args with a tanh-bounded output.
        self.cont_head = RegressionHead(
            head_hidden_size,
            action_shape.action_args_shape,
            head_layer_num,
            final_tanh=True,
            activation=activation,
            norm_type=norm_type
        )

        # Discrete action type head (Q network head): its input is the concatenation of the state
        # embedding and the continuous action args.
        if dueling:
            dis_head_cls = DuelingHead
        else:
            dis_head_cls = DiscreteHead
        self.dis_head = dis_head_cls(
            head_hidden_size + action_shape.action_args_shape,
            action_shape.action_type_shape,
            head_layer_num,
            activation=activation,
            norm_type=norm_type
        )

        self.actor_head = nn.ModuleList([self.dis_head, self.cont_head])
        # Note: both entries reuse ``cont_encoder``, i.e. the discrete and continuous branches share the
        # continuous encoder, so ``dis_encoder`` defined above is not used in the forward pass.
        self.encoder = nn.ModuleList([self.cont_encoder, self.cont_encoder])

    def forward(self, inputs: Union[torch.Tensor, Dict, EasyDict], mode: str) -> Dict:
        """
        Overview:
            PDQN forward computation graph. Input the observation tensor to predict the q_value of \
            discrete actions and the values of continuous action_args.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict, EasyDict]`): Inputs including the observation and \
                other info according to ``mode``.
            - mode (:obj:`str`): Name of the forward mode.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_continuous(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Use the observation tensor to predict the continuous action args.
        Arguments:
            - inputs (:obj:`torch.Tensor`): Observation inputs.
        Returns:
            - outputs (:obj:`Dict`): A dict with key 'action_args'.
                - 'action_args' (:obj:`torch.Tensor`): The continuous action args.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
            - action_args (:obj:`torch.Tensor`): :math:`(B, M)`, where M is ``action_args_shape``.
        Examples:
            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
            >>> model = PDQN(4, act_shape)
            >>> inputs = torch.randn(64, 4)
            >>> outputs = model.forward(inputs, mode='compute_continuous')
            >>> assert outputs['action_args'].shape == torch.Size([64, 5])
        """
        cont_x = self.encoder[1](inputs)  # size (B, encoded_state_shape)
        action_args = self.actor_head[1](cont_x)['pred']  # size (B, action_args_shape)
        outputs = {'action_args': action_args}
        return outputs

    def compute_discrete(self, inputs: Union[Dict, EasyDict]) -> Dict:
        """
        Overview:
            Use the observation tensor and the continuous action args to predict the discrete action type.
        Arguments:
            - inputs (:obj:`Union[Dict, EasyDict]`): A dict with keys 'state' and 'action_args'.
                - state (:obj:`torch.Tensor`): Observation inputs.
                - action_args (:obj:`torch.Tensor`): The action parameters, which are concatenated with the \
                    observation and serve as input to the discrete action type network.
        Returns:
            - outputs (:obj:`Dict`): A dict with keys 'logit' and 'action_args'.
                - 'logit': The logit value for each discrete action.
                - 'action_args': The continuous action args (same as ``inputs['action_args']``) for later usage.
        Examples:
            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
            >>> model = PDQN(4, act_shape)
            >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 5)}
            >>> outputs = model.forward(inputs, mode='compute_discrete')
            >>> assert outputs['logit'].shape == torch.Size([64, 3])
            >>> assert outputs['action_args'].shape == torch.Size([64, 5])
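
            A multi-pass (MPDQN) sketch; the hand-written action mask below, which assigns the 5 action
            args to the 3 discrete actions, is an illustrative assumption:

            >>> mp_model = PDQN(
            ...     4,
            ...     EasyDict({'action_type_shape': 3, 'action_args_shape': 5}),
            ...     multi_pass=True,
            ...     action_mask=[[1, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
            ... )
            >>> mp_outputs = mp_model.forward(inputs, mode='compute_discrete')
            >>> assert mp_outputs['logit'].shape == torch.Size([64, 3])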
        """
        dis_x = self.encoder[0](inputs['state'])  # size (B, encoded_state_shape)
        action_args = inputs['action_args']  # size (B, action_args_shape)

        if self.multi_pass:  # MPDQN
            # Evaluate the Q network once per discrete action type, each time exposing only the action args
            # associated with that type; the remaining slots are filled with the out-of-range constant -2
            # (the args produced by the model are tanh-bounded in [-1, 1]), which reduces the influence of
            # irrelevant action args on each discrete action's Q value.
            # size (B, action_args_shape, action_type_shape)
            mp_action = torch.full(
                (dis_x.shape[0], self.action_args_shape, self.action_type_shape),
                fill_value=-2,
                device=dis_x.device,
                dtype=dis_x.dtype
            )
            index = self.action_scatter_index.view(1, -1, 1).repeat(dis_x.shape[0], 1, 1).to(dis_x.device)

            # Scatter each action arg into the column of the discrete action it belongs to.
            mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1))
            mp_action = mp_action.permute(0, 2, 1)  # size (B, action_type_shape, action_args_shape)

            mp_state = dis_x.unsqueeze(1).repeat(1, self.action_type_shape, 1)
            mp_state_action_cat = torch.cat([mp_state, mp_action], dim=-1)

            # size (B, action_type_shape, action_type_shape)
            logit = self.actor_head[0](mp_state_action_cat)['logit']

            # Keep only the Q value of each discrete action evaluated with its own action args.
            logit = torch.diagonal(logit, dim1=-2, dim2=-1)  # size (B, action_type_shape)
        else:  # PDQN
            if len(action_args.shape) == 1:  # (B, ) -> (B, 1)
                action_args = action_args.unsqueeze(1)
            state_action_cat = torch.cat((dis_x, action_args), dim=-1)
            logit = self.actor_head[0](state_action_cat)['logit']  # size (B, action_type_shape)

        outputs = {'logit': logit, 'action_args': action_args}
        return outputs