File size: 11,962 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE"""

import torch
from torch.nn import functional as F
from torch import nn
from abc import abstractmethod
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional
from ding.utils.type_helper import Tensor


class VanillaVAE(nn.Module):
    """
        Overview:
            Implementation of Vanilla variational autoencoder for action reconstruction.
            The encoder embeds the action and the observation separately, fuses the two embeddings with an
            element-wise product and maps the result to the mean and log-variance of a diagonal Gaussian.
            The decoder fuses a latent sample with the observation embedding in the same way and produces
            a reconstructed action (squashed by ``tanh``) together with a prediction residual of
            observation dimension.
        Interfaces:
            ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
                ``forward``, ``loss_function`` .
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: Optional[List] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Build the encoder and decoder networks.
        Arguments:
            - action_shape (:obj:`int`): Dimension of the action vector.
            - obs_shape (:obj:`int`): Dimension of the observation vector.
            - latent_size (:obj:`int`): Dimension of the latent action.
            - hidden_dims (:obj:`Optional[List]`): Hidden layer widths. At least two entries are required \
                because both ``hidden_dims[1]`` and ``hidden_dims[-2]`` are indexed. ``None`` (default) \
                means ``[256, 256]``.
            - kwargs: Accepted for interface compatibility and ignored.
        """
        super(VanillaVAE, self).__init__()
        # A list literal as default argument would be shared by every instance (and exposed through
        # ``self.hidden_dims``), so the default is created per call and a caller-provided list is copied.
        hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)

        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        self.hidden_dims = hidden_dims

        # Build Encoder
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())

        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Build Decoder
        # NOTE: the output of ``decode_action_head`` (width ``hidden_dims[-1]``) is multiplied element-wise
        # with the output of ``encode_obs_head`` (width ``hidden_dims[0]``) during decoding, so these two
        # widths have to be broadcast-compatible.
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
        # TODO(pu): tanh
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())

        # residual prediction
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        # Not written anywhere else in this class; kept so that external readers of the attribute still work.
        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \
                `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
        """
        action_encoding = self.encode_action_head(input['action'])
        obs_encoding = self.encode_obs_head(input['obs'])
        # TODO(pu): try a dedicated network for the observation condition instead of ``encode_obs_head``.
        # Condition the action embedding on the observation by element-wise product.
        fused = obs_encoding * action_encoding  # TODO(pu): what about add, cat?
        hidden = self.encode_common(fused)

        # Two linear heads give the mean and the log-variance of the latent Gaussian distribution.
        mu = self.encode_mu_head(hidden)
        log_var = self.encode_logvar_head(hidden)

        return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding}

    def _decode_latent(self, latent: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Shared decoder pipeline used by ``decode`` and ``decode_with_obs``. ``latent`` is fed to the
            decoder as-is, i.e. any squashing has to be applied by the caller.
        Arguments:
            - latent (:obj:`torch.Tensor`): Latent action passed directly to ``decode_action_head``.
            - obs_encoding (:obj:`torch.Tensor`): Output of ``encode_obs_head``.
        Returns:
            - outputs (:obj:`Dict`): Dict with keys ``reconstruction_action`` and ``predition_residual``.
        """
        action_decoding = self.decode_action_head(latent)
        # Same element-wise conditioning on the observation as in the encoder.
        hidden = self.decode_common(action_decoding * obs_encoding)

        reconstruction_action = self.decode_reconst_action_head(hidden)
        predition_residual = self.decode_prediction_head_layer2(self.decode_prediction_head_layer1(hidden))
        # The key spelling ``predition_residual`` is kept on purpose: callers index the dict with it.
        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
         Overview:
               Maps the given latent action and obs_encoding onto the original action space.
               ``z`` is squashed with ``tanh`` before it enters the decoder.
         Arguments:
             - z (:obj:`torch.Tensor`): the sampled latent action
             - obs_encoding (:obj:`torch.Tensor`): observation encoding
         Returns:
             - outputs (:obj:`Dict`): Decoder outputs, see ``ReturnsKeys``.
         ReturnsKeys:
             - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the decoder, \
                bounded to (-1, 1) by the final ``tanh`` layer.
             - predition_residual (:obj:`torch.Tensor`): output of the residual prediction head.
         Shapes:
             - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
             - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
        """
        # NOTE: tanh, here z is not bounded
        return self._decode_latent(torch.tanh(z), obs_encoding)

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
          Overview:
                Maps the given latent action and obs onto the original action space.
                Using the method self.encode_obs_head(obs) to get the obs_encoding.
                Unlike ``decode``, ``z`` is NOT passed through ``tanh`` here.
          Arguments:
              - z (:obj:`torch.Tensor`): the latent action
              - obs (:obj:`torch.Tensor`): observation
          Returns:
              - outputs (:obj:`Dict`): Decoder outputs, see ``ReturnsKeys``.
          ReturnsKeys:
              - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by VAE .
              - predition_residual (:obj:`torch.Tensor`): output of the residual prediction head.
          Shapes:
              - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
              - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
        """
        obs_encoding = self.encode_obs_head(obs)
        # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
        return self._decode_latent(z, obs_encoding)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
         Overview:
              Reparameterization trick to sample from N(mu, var) from N(0,1).
         Arguments:
             - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian
             - logvar (:obj:`torch.Tensor`): Log-variance of the latent Gaussian
         Returns:
             - z (:obj:`torch.Tensor`): ``mu + eps * std`` with ``eps`` drawn from N(0, 1).
         Shapes:
             - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
             - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
         """
        # std = sqrt(var) = exp(0.5 * log(var))
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Any]:
        """
        Overview:
            Encode the input, sample a latent action from ``mu`` and ``log_var`` with the
            reparameterization trick and decode it together with ``obs_encoding``.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \
                and `action` (:obj:`torch.Tensor`), representing the observation \
                and agent's action respectively.
            - kwargs: Accepted for interface compatibility and ignored.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                ``input`` (the ``input`` dict that was passed in), ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
                where B is batch size and O is ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        encode_output = self.encode(input)
        mu, log_var = encode_output['mu'], encode_output['log_var']
        z = self.reparameterize(mu, log_var)
        decode_output = self.decode(z, encode_output['obs_encoding'])
        return {
            'recons_action': decode_output['reconstruction_action'],
            'prediction_residual': decode_output['predition_residual'],
            'input': input,
            'mu': mu,
            'log_var': log_var,
            'z': z
        }

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss function:
            ``loss = reconstruction_loss + kld_weight * kld_loss + predict_weight * predict_loss``.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual`` \
                ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``. \
                Both are required.
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \
                ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
        Raises:
            - KeyError: If a required key is missing from ``args`` or ``kwargs``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
                and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
                and O is ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
        """
        recons_action = args['recons_action']
        prediction_residual = args['prediction_residual']
        original_action = args['original_action']
        mu = args['mu']
        log_var = args['log_var']
        true_residual = args['true_residual']

        kld_weight = kwargs['kld_weight']
        predict_weight = kwargs['predict_weight']

        recons_loss = F.mse_loss(recons_action, original_action)
        # KL divergence between N(mu, exp(log_var)) and N(0, 1): summed over dim 1, averaged over dim 0.
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
        kld_loss = torch.mean(kld_per_sample, dim=0)
        predict_loss = F.mse_loss(prediction_residual, true_residual)

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}