"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE""" import torch from torch.nn import functional as F from torch import nn from abc import abstractmethod from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional from ding.utils.type_helper import Tensor class VanillaVAE(nn.Module): """ Overview: Implementation of Vanilla variational autoencoder for action reconstruction. Interfaces: ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \ ``forward``, ``loss_function`` . """ def __init__( self, action_shape: int, obs_shape: int, latent_size: int, hidden_dims: List = [256, 256], **kwargs ) -> None: super(VanillaVAE, self).__init__() self.action_shape = action_shape self.obs_shape = obs_shape self.latent_size = latent_size self.hidden_dims = hidden_dims # Build Encoder self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU()) self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU()) self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU()) self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size) self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size) # Build Decoder self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU()) self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU()) # TODO(pu): tanh self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh()) # residual prediction self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU()) self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape) self.obs_encoding = None def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]: """ Overview: Encodes the input by passing through the encoder network and returns the latent codes. Arguments: - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \ `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \ ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \ representing latent codes. Shapes: - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``. - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``. - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``. """ action_encoding = self.encode_action_head(input['action']) obs_encoding = self.encode_obs_head(input['obs']) # obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network input = obs_encoding * action_encoding # TODO(pu): what about add, cat? result = self.encode_common(input) # Split the result into mu and var components # of the latent Gaussian distribution mu = self.encode_mu_head(result) log_var = self.encode_logvar_head(result) return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding} def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]: """ Overview: Maps the given latent action and obs_encoding onto the original action space. Arguments: - z (:obj:`torch.Tensor`): the sampled latent action - obs_encoding (:obj:`torch.Tensor`): observation encoding Returns: - outputs (:obj:`Dict`): DQN forward outputs, such as q_value. ReturnsKeys: - reconstruction_action (:obj:`torch.Tensor`): reconstruction_action. - predition_residual (:obj:`torch.Tensor`): predition_residual. Shapes: - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size`` - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim`` """ action_decoding = self.decode_action_head(torch.tanh(z)) # NOTE: tanh, here z is not bounded action_obs_decoding = action_decoding * obs_encoding action_obs_decoding_tmp = self.decode_common(action_obs_decoding) reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp) predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp) predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp) return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]: """ Overview: Maps the given latent action and obs onto the original action space. Using the method self.encode_obs_head(obs) to get the obs_encoding. Arguments: - z (:obj:`torch.Tensor`): the sampled latent action - obs (:obj:`torch.Tensor`): observation Returns: - outputs (:obj:`Dict`): DQN forward outputs, such as q_value. ReturnsKeys: - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by VAE . - predition_residual (:obj:`torch.Tensor`): the observation predicted by VAE. Shapes: - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size`` - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape`` """ obs_encoding = self.encode_obs_head(obs) # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh action_decoding = self.decode_action_head(z) action_obs_decoding = action_decoding * obs_encoding action_obs_decoding_tmp = self.decode_common(action_obs_decoding) reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp) predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp) predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp) return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: """ Overview: Reparameterization trick to sample from N(mu, var) from N(0,1). Arguments: - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian - logvar (:obj:`torch.Tensor`): Standard deviation of the latent Gaussian Shapes: - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latnet_size`` - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latnet_size`` """ std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return eps * std + mu def forward(self, input: Dict[str, Tensor], **kwargs) -> dict: """ Overview: Encode the input, reparameterize `mu` and `log_var`, decode `obs_encoding`. Argumens: - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \ and `action` (:obj:`torch.Tensor`), representing the observation \ and agent's action respectively. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \ (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \ ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \ ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`). Shapes: - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``. - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \ where B is batch size and O is ``observation dim``. - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size`` """ encode_output = self.encode(input) z = self.reparameterize(encode_output['mu'], encode_output['log_var']) decode_output = self.decode(z, encode_output['obs_encoding']) return { 'recons_action': decode_output['reconstruction_action'], 'prediction_residual': decode_output['predition_residual'], 'input': input, 'mu': encode_output['mu'], 'log_var': encode_output['log_var'], 'z': z } def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]: """ Overview: Computes the VAE loss function. Arguments: - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual`` \ ``original_action``, ``mu``, ``log_var`` and ``true_residual``. - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``. Returns: - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \ ``reconstruction_loss``, ``kld_loss``, ``predict_loss``. Shapes: - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \ and A is ``action dim``. - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \ and O is ``observation dim``. - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``. - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``. - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``. """ recons_action = args['recons_action'] prediction_residual = args['prediction_residual'] original_action = args['original_action'] mu = args['mu'] log_var = args['log_var'] true_residual = args['true_residual'] kld_weight = kwargs['kld_weight'] predict_weight = kwargs['predict_weight'] recons_loss = F.mse_loss(recons_action, original_action) kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) predict_loss = F.mse_loss(prediction_residual, true_residual) loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}