|
from typing import Union, Dict, Optional |
|
import torch |
|
import torch.nn as nn |
|
|
|
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY |
|
from ..common import RegressionHead, ReparameterizationHead, DistributionHead |
|
|
|
|
|
@MODEL_REGISTRY.register('qac_dist')
class QACDIST(nn.Module):
    """
    Overview:
        The QAC model with distributional Q-value. The actor maps observations to actions (``regression``) or \
        to Gaussian parameters (``reparameterization``); the critic maps (observation, action) pairs to a \
        categorical Q-value distribution.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            action_space: str = "regression",
            critic_head_type: str = "categorical",
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            v_min: Optional[float] = -10,
            v_max: Optional[float] = 10,
            n_atom: Optional[int] = 51,
    ) -> None:
        """
        Overview:
            Init the QAC Distributional Model according to arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. It is squeezed and used as the \
                input size of the first linear layer of both the actor and the critic.
            - action_shape (:obj:`Union[int, SequenceType]`): Action's space. It is squeezed and used as the \
                actor's output size and as part of the critic's input size.
            - action_space (:obj:`str`): Whether choose ``regression`` or ``reparameterization``.
            - critic_head_type (:obj:`str`): Only ``categorical``.
            - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor-nn's ``Head``.
            - actor_head_layer_num (:obj:`int`): The num of layers passed to the actor's head.
            - critic_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to critic-nn's ``Head``.
            - critic_head_layer_num (:obj:`int`): The num of layers passed to the critic's head.
            - activation (:obj:`Optional[nn.Module]`): The activation module placed after the first linear \
                layer of the actor and the critic, and passed to both heads. If ``None``, ``nn.ReLU()`` is used.
            - norm_type (:obj:`Optional[str]`): The type of normalization passed to the heads, see \
                ``ding.torch_utils.fc_block`` for more details.
            - v_min (:obj:`float`): Value of the smallest atom.
            - v_max (:obj:`float`): Value of the largest atom.
            - n_atom (:obj:`int`): Number of atoms in the support.
        """
        super(QACDIST, self).__init__()
        # Honor the documented fallback: passing ``None`` straight into ``nn.Sequential`` is accepted at
        # construction time but fails on the first forward pass, when the ``None`` entry is called.
        if activation is None:
            activation = nn.ReLU()
        obs_shape: int = squeeze(obs_shape)
        action_shape: int = squeeze(action_shape)
        self.action_space = action_space
        assert self.action_space in ['regression', 'reparameterization'], self.action_space
        # The linear layer is constructed before the head in both branches (as before) so parameter
        # initialization consumes the RNG in the same order and ``state_dict`` keys stay unchanged.
        if self.action_space == 'regression':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                RegressionHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    final_tanh=True,
                    activation=activation,
                    norm_type=norm_type
                )
            )
        elif self.action_space == 'reparameterization':
            self.actor = nn.Sequential(
                nn.Linear(obs_shape, actor_head_hidden_size), activation,
                ReparameterizationHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    sigma_type='conditioned',
                    activation=activation,
                    norm_type=norm_type
                )
            )
        self.critic_head_type = critic_head_type
        assert self.critic_head_type in ['categorical'], self.critic_head_type
        if self.critic_head_type == 'categorical':
            # The critic consumes the concatenation of observation and action, and outputs a single
            # Q-value (output size 1) together with its distribution over ``n_atom`` atoms.
            self.critic = nn.Sequential(
                nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
                DistributionHead(
                    critic_head_hidden_size,
                    1,
                    critic_head_layer_num,
                    n_atom=n_atom,
                    v_min=v_min,
                    v_max=v_max,
                    activation=activation,
                    norm_type=norm_type
                )
            )

    def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
        """
        Overview:
            Run the actor or the critic, dispatching to ``compute_actor`` / ``compute_critic`` by ``mode``.
        Arguments:
            - inputs (:obj:`Union[torch.Tensor, Dict]`): With ``'compute_actor'``, the observation tensor. \
                With ``'compute_critic'``, a dict with keys ``obs`` and ``action``.
            - mode (:obj:`str`): Name of the forward mode, one of ``self.mode``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of network forward.

            Forward with ``'compute_actor'``, Keys (either):
                - action (:obj:`torch.Tensor`): Action tensor (``regression``).
                - logit (:obj:`list`): ``[mu, sigma]`` (``reparameterization``).

            Forward with ``'compute_critic'``, Keys:
                - q_value (:obj:`torch.Tensor`): Q value tensor.
                - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
        Actor Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 is ``action_shape``
            - logit (:obj:`list`): 2 elements, mu and sigma, each of shape :math:`(B, N1)`
        Critic Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(B, )`, where N1 is ``action_shape``
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N2)`, where N2 is ``n_atom``
        Actor Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization Mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> actor_outputs['logit'][0].shape  # mu
            torch.Size([4, 64])
            >>> actor_outputs['logit'][1].shape  # sigma
            torch.Size([4, 64])
        Critic Examples:
            >>> # Categorical mode
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(8, ), action_shape=1, action_space='regression', \
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, inputs: torch.Tensor) -> Dict:
        """
        Overview:
            Run the actor network on a batch of observations.
        Arguments:
            - inputs (:obj:`torch.Tensor`): The observation tensor; its last dimension must equal \
                ``obs_shape`` because the actor starts with ``nn.Linear(obs_shape, actor_head_hidden_size)``.
        Returns:
            - outputs (:obj:`Dict`): Outputs of the actor head.
        ReturnsKeys (either):
            - action (:obj:`torch.Tensor`): Continuous action tensor, returned in ``regression`` mode \
                (the head is built with ``final_tanh=True``).
            - logit (:obj:`list`): ``[mu, sigma]``, returned in ``reparameterization`` mode.
        Shapes:
            - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 is ``action_shape``
            - logit (:obj:`list`): 2 elements, mu and sigma, each of shape :math:`(B, N1)`
        Examples:
            >>> # Regression mode
            >>> model = QACDIST(64, 64, 'regression')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
            >>> # Reparameterization Mode
            >>> model = QACDIST(64, 64, 'reparameterization')
            >>> inputs = torch.randn(4, 64)
            >>> actor_outputs = model(inputs, 'compute_actor')
            >>> actor_outputs['logit'][0].shape  # mu
            torch.Size([4, 64])
            >>> actor_outputs['logit'][1].shape  # sigma
            torch.Size([4, 64])
        """
        x = self.actor(inputs)
        if self.action_space == 'regression':
            return {'action': x['pred']}
        elif self.action_space == 'reparameterization':
            return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict) -> Dict:
        """
        Overview:
            Run the critic network on a batch of (observation, action) pairs.
        Arguments:
            - inputs (:obj:`Dict`): Dict with keys ``obs`` and ``action``. ``obs`` must be 2-D; a 1-D \
                ``action`` is unsqueezed to :math:`(B, 1)` before being concatenated with ``obs``.
        Returns:
            - outputs (:obj:`Dict`): Q-value output and distribution.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor (the head's ``logit`` output).
            - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, where B is batch size and N0 is ``obs_shape``
            - action (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(B, )`, where N1 is ``action_shape``
            - q_value (:obj:`torch.FloatTensor`): :math:`(B, 1)`
            - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N2)`, where N2 is ``n_atom``
        Examples:
            >>> # Categorical mode
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = QACDIST(obs_shape=(8, ), action_shape=1, action_space='regression', \
            ...                 critic_head_type='categorical', n_atom=51)
            >>> q_value = model(inputs, mode='compute_critic')  # q value
            >>> assert q_value['q_value'].shape == torch.Size([4, 1])
            >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
        """
        obs, action = inputs['obs'], inputs['action']
        assert len(obs.shape) == 2, obs.shape
        # Scalar-per-sample actions arrive as (B, ); give them a feature dim so they concatenate with obs.
        if len(action.shape) == 1:
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=1)
        x = self.critic(x)
        return {'q_value': x['logit'], 'distribution': x['distribution']}
|
|