from typing import Union, Dict, Optional, List from easydict import EasyDict import numpy as np import torch import torch.nn as nn from ding.utils import SequenceType, squeeze, MODEL_REGISTRY from ..common import RegressionHead, ReparameterizationHead from .vae import VanillaVAE @MODEL_REGISTRY.register('bcq') class BCQ(nn.Module): """ Overview: Model of BCQ (Batch-Constrained deep Q-learning). Off-Policy Deep Reinforcement Learning without Exploration. https://arxiv.org/abs/1812.02900 Interface: ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval`` Property: ``mode`` """ mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval'] def __init__( self, obs_shape: Union[int, SequenceType], action_shape: Union[int, SequenceType, EasyDict], actor_head_hidden_size: List = [400, 300], critic_head_hidden_size: List = [400, 300], activation: Optional[nn.Module] = nn.ReLU(), vae_hidden_dims: List = [750, 750], phi: float = 0.05 ) -> None: """ Overview: Initialize neural network, i.e. agent Q network and actor. Arguments: - obs_shape (:obj:`int`): the dimension of observation state - action_shape (:obj:`int`): the dimension of action shape - actor_hidden_size (:obj:`list`): the list of hidden size of actor - critic_hidden_size (:obj:'list'): the list of hidden size of critic - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU(). - vae_hidden_dims (:obj:`list`): the list of hidden size of vae """ super(BCQ, self).__init__() obs_shape: int = squeeze(obs_shape) action_shape = squeeze(action_shape) self.action_shape = action_shape self.input_size = obs_shape self.phi = phi critic_input_size = self.input_size + action_shape self.critic = nn.ModuleList() for _ in range(2): net = [] d = critic_input_size for dim in critic_head_hidden_size: net.append(nn.Linear(d, dim)) net.append(activation) d = dim net.append(nn.Linear(d, 1)) self.critic.append(nn.Sequential(*net)) net = [] d = critic_input_size for dim in actor_head_hidden_size: net.append(nn.Linear(d, dim)) net.append(activation) d = dim net.append(nn.Linear(d, 1)) self.actor = nn.Sequential(*net) self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims) def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]: """ Overview: The unique execution (forward) method of BCQ method, and one can indicate different modes to implement \ different computation graph, including ``compute_actor`` and ``compute_critic`` in BCQ. Mode compute_actor: Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - output (:obj:`Dict`): Output dict data, including action tensor. Mode compute_critic: Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - output (:obj:`Dict`): Output dict data, including q_value tensor. Mode compute_vae: Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \ (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \ ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \ ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`). Mode compute_eval: Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - output (:obj:`Dict`): Output dict data, including action tensor. Examples: >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} >>> model = BCQ(32, 6) >>> outputs = model(inputs, mode='compute_actor') >>> outputs = model(inputs, mode='compute_critic') >>> outputs = model(inputs, mode='compute_vae') >>> outputs = model(inputs, mode='compute_eval') .. note:: For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively. """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) return getattr(self, mode)(inputs) def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Overview: Use critic network to compute q value. Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``q_value`` (:obj:`torch.Tensor`). Shapes: - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. - outputs (:obj:`Dict`): :math:`(B, N)`. Examples: >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} >>> model = BCQ(32, 6) >>> outputs = model.compute_critic(inputs) """ obs, action = inputs['obs'], inputs['action'] if len(action.shape) == 1: # (B, ) -> (B, 1) action = action.unsqueeze(1) x = torch.cat([obs, action], dim=-1) x = [m(x).squeeze() for m in self.critic] return {'q_value': x} def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]: """ Overview: Use actor network to compute action. Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`). Shapes: - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. - outputs (:obj:`Dict`): :math:`(B, N)`. Examples: >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} >>> model = BCQ(32, 6) >>> outputs = model.compute_actor(inputs) """ input = torch.cat([inputs['obs'], inputs['action']], -1) x = self.actor(input) action = self.phi * 1 * torch.tanh(x) action = (action + inputs['action']).clamp(-1, 1) return {'action': action} def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Overview: Use vae network to compute action. Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \ ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), \ ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`). Shapes: - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. - outputs (:obj:`Dict`): :math:`(B, N)`. Examples: >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} >>> model = BCQ(32, 6) >>> outputs = model.compute_vae(inputs) """ return self.vae.forward(inputs) def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Overview: Use actor network to compute action. Arguments: - inputs (:obj:`Dict`): Input dict data, including obs and action tensor. Returns: - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`). Shapes: - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension. - outputs (:obj:`Dict`): :math:`(B, N)`. Examples: >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)} >>> model = BCQ(32, 6) >>> outputs = model.compute_eval(inputs) """ obs = inputs['obs'] obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0) z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5) sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action'] action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action'] q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0] idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1) idx = idx.repeat_interleave(action.shape[-1], dim=-1) action = action.gather(0, idx).squeeze() return {'action': action}