File size: 9,132 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from typing import Union, Optional, Dict
from easydict import EasyDict

import torch
import torch.nn as nn
from ding.model.common import ReparameterizationHead, EnsembleHead
from ding.utils import SequenceType, squeeze

from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('edac')
class EDAC(nn.Module):
    """
    Overview:
        The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC. The actor is a \
        reparameterization (Gaussian) policy head and the critic is an ensemble of ``ensemble_num`` Q-networks.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
    """
    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            ensemble_num: int = 2,
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Initialize the EDAC Model according to input arguments.
        Arguments:
            - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
            - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
                EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
            - ensemble_num (:obj:`int`): Q-net number.
            - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
            - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for actor head.
            - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
            - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
                for critic head.
            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
                after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
            - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \
                see ``ding.torch_utils.network`` for more details.
            - kwargs: Accepted for config compatibility and otherwise ignored.

        .. note::
            The critic input size is computed as ``obs_shape + action_shape`` after ``squeeze``, so both must \
            reduce to plain ints (1-dim observation vector and continuous action vector).
        """
        super(EDAC, self).__init__()
        # Honour the documented ``None`` fallback. Without it, ``None`` would be registered inside the actor's
        # ``nn.Sequential`` below and the actor forward pass would fail when trying to call it.
        if activation is None:
            activation = nn.ReLU()
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.ensemble_num = ensemble_num
        self.actor = nn.Sequential(
            nn.Linear(obs_shape, actor_head_hidden_size), activation,
            ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type='conditioned',
                activation=activation,
                norm_type=norm_type
            )
        )

        # The critic consumes the concatenation of observation and action (see ``compute_critic``).
        critic_input_size = obs_shape + action_shape
        self.critic = EnsembleHead(
            critic_input_size,
            1,
            critic_head_hidden_size,
            critic_head_layer_num,
            self.ensemble_num,
            activation=activation,
            norm_type=norm_type
        )

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of EDAC method, and one can indicate different modes to implement \
            different computation graph, including ``compute_actor`` and ``compute_critic`` in EDAC.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, containing the ``logit`` key ([mu, sigma]).
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.

        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            The forward computation graph of compute_actor mode, uses observation tensor to produce the parameters \
            of the reparameterized Gaussian policy.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
                i.e. ``(B, obs_shape)``.
        Returns:
            - outputs (:obj:`Dict[str, list]`): Actor output for the ``reparameterization`` action space.
        ReturnsKeys:
            - logit (:obj:`List[torch.Tensor]`): Two-element list ``[mu, sigma]``.
                - mu (:obj:`torch.Tensor`): Mean of the Gaussian distribution.
                - sigma (:obj:`torch.Tensor`): Standard deviation of the Gaussian distribution.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
            - logit[0] (mu) (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 corresponds to ``action_shape``.
            - logit[1] (sigma) (:obj:`torch.Tensor`): :math:`(B, N1)`, N1 corresponds to ``action_shape``.
        Examples:
            >>> model = EDAC(64, 64,)
            >>> obs = torch.randn(4, 64)
            >>> actor_outputs = model(obs,'compute_actor')
            >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64])  # mu
            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 64])  # sigma
        """
        x = self.actor(obs)
        return {'logit': [x['mu'], x['sigma']]}

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The forward computation graph of compute_critic mode, uses observation and action tensor to produce \
            one Q value per ensemble member.
        Arguments:
            - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
                  ``action`` tensor.
        Returns:
            - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
        ArgumentsKeys:
            - obs: (:obj:`torch.Tensor`): Observation tensor data, either one batch shared by all ensemble members \
                  or one batch per member (see Shapes).
            - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``. A 1-dim \
                  ``(B, )`` action is reshaped to ``(B, 1)``.
        ReturnKeys:
            - q_value (:obj:`torch.Tensor`): Q values of every ensemble member.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble\\_num, B, N1)`, where B is batch size \
                  and N1 is ``obs_shape``.
            - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble\\_num, B, N2)` (must have the same \
                  leading dims as ``obs``), where B is batch size and N2 is ``action_shape``.
            - q_value (:obj:`torch.Tensor`): :math:`(Ensemble\\_num, B)`, where B is batch size.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
            >>> model = EDAC(obs_shape=(8, ), action_shape=1)
            >>> q_value = model(inputs, mode='compute_critic')['q_value']
            >>> assert q_value.shape == torch.Size([2, 4])  # (ensemble_num, B) with the default ensemble_num=2
        """

        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        if len(obs.shape) < 3:
            # Every member sees the same batch: tile the features once per member.
            # [batch_size, dim] -> [batch_size, Ensemble_num * dim, 1]
            x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
        else:
            # One batch per member: move the ensemble axis next to the feature axis before flattening, so that
            # the features of each member stay contiguous.
            # [Ensemble_num, batch_size, dim] -> [batch_size, Ensemble_num, dim] -> [batch_size, Ensemble_num * dim, 1]
            x = x.transpose(0, 1)
            batch_size = obs.shape[1]
            x = x.reshape(batch_size, -1, 1)
        # EnsembleHead output with one scalar per member: [batch_size, 1 * Ensemble_num]
        x = self.critic(x)['pred']
        # [batch_size, 1 * Ensemble_num] -> [Ensemble_num, batch_size]
        x = x.permute(1, 0)
        return {'q_value': x}