File size: 9,132 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from typing import Union, Optional, Dict
from easydict import EasyDict
import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze
from ding.utils import MODEL_REGISTRY
@MODEL_REGISTRY.register('edac')
class EDAC(nn.Module):
    """
    Overview:
        The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC.
        The actor is a single linear layer followed by a ``ReparameterizationHead``; the critic is an
        ``EnsembleHead`` that evaluates ``ensemble_num`` Q-networks in one call.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    # Names of the methods that ``forward`` is allowed to dispatch to.
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            ensemble_num: int = 2,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Initialize the EDAC Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ).
              NOTE(review): the critic input size is computed as ``obs_shape + action_shape`` after
              ``squeeze``, so both are expected to reduce to plain ints; an ``EasyDict`` action shape does
              not look usable here - confirm against callers.
            - ensemble_num (:obj:`int`): Q-net number.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the actor head.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value
              output for critic head.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP``
              after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv),
              see ``ding.torch_utils.network`` for more details.
            - kwargs: Extra keyword arguments are accepted and ignored.
        """
        super(EDAC, self).__init__()
        # Honour the documented contract: ``None`` means "use ReLU". Without this, ``None`` would be
        # placed inside ``nn.Sequential`` below and fail as soon as the actor is called.
        if activation is None:
            activation = nn.ReLU()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.ensemble_num = ensemble_num

        # Actor: obs -> hidden -> activation -> head producing ``mu`` and ``sigma``.
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size),
            activation,
            ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        )

        # Critic: every ensemble member consumes the concatenated (obs, action) vector and emits 1 value.
        critic_input_size = obs_shape + action_shape
        self.critic = EnsembleHead(
            critic_input_size,
            1,
            critic_head_hidden_size,
            critic_head_layer_num,
            self.ensemble_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of EDAC method, and one can indicate different modes to
            implement different computation graph, including ``compute_actor`` and ``compute_critic`` in EDAC.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, containing the ``logit`` key.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        Raises:
            - AssertionError: If ``mode`` is not one of the names listed in ``self.mode``.
        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic``
            respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        # Dispatch by name to ``compute_actor`` / ``compute_critic``.
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            The forward computation graph of compute_actor mode, uses observation tensor to produce the
            reparameterization ``logit`` (mean and standard deviation of a Gaussian policy).
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data,
              i.e. ``(B, obs_shape)``.
        Returns:
            - outputs (:obj:`Dict`): Actor output with a single key ``logit``.
        ReturnsKeys:
            - logit (:obj:`List[torch.Tensor]`): Two-element list ``[mu, sigma]`` taken from the actor head
              output, where ``mu`` is the mean and ``sigma`` the standard deviation of the parameterized
              Gaussian distribution.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - logit[0] / mu (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 corresponds to ``action_shape``.
            - logit[1] / sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 corresponds to ``action_shape``.
        Examples:
            >>> model = EDAC(64, 64,)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs,'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        """
        head_out = self.actor(obs)
        return {'logit': [head_out['mu'], head_out['sigma']]}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The forward computation graph of compute_critic mode, uses observation and action tensor to
            produce critic output, such as ``q_value``.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and
              ``action`` tensor.
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
        ArgumentsKeys:
            - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
            - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``. A 1-dim
              action of shape ``(B, )`` is expanded to ``(B, 1)``.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q value tensor, one row per ensemble member.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or '(Ensemble_num, B, N1)', where B is batch size and
              N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or '(Ensemble_num, B, N2)', where B is batch size
              and N2 is ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = EDAC(obs_shape=(8, ),action_shape=1)
            >>> q_value = model(inputs, mode='compute_critic')['q_value']
            >>> assert q_value.shape == torch.Size([2, 4])  # (Ensemble_num, B) with the default ensemble_num=2
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if len(obs.shape) < 3:
            # Shared input for all members: tile the features once per ensemble member.
            # [batch_size, dim] -> [batch_size, Ensemble_num * dim] -> [batch_size, Ensemble_num * dim, 1]
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
        else:
            # Per-member input: lay the members' features side by side along dim 1.
            # [Ensemble_num, batch_size, dim] -> [batch_size, Ensemble_num, dim]
            #     -> [batch_size, Ensemble_num * dim, 1]
            x = x.transpose(0, 1)
            batch_size = obs.shape[1]
            x = x.reshape(batch_size, -1, 1)
        # The head output is treated as 2-D below; presumably [batch_size, Ensemble_num] - confirm
        # against ``EnsembleHead``.
        x = self.critic(x)['pred']
        # [batch_size, 1 * Ensemble_num] -> [Ensemble_num, batch_size]
        x = x.permute(1, 0)
        return {'q_value': x}
|