from typing import Union, Optional
from easydict import EasyDict
import torch
import torch.nn as nn
import treetensor.torch as ttorch
from copy import deepcopy
from ding.utils import SequenceType, squeeze
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
    FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
from ding.torch_utils import MLP, fc_block


class DiscretePolicyHead(nn.Module):
    """
    Overview:
        Actor head for discrete action spaces: an ``MLP`` with ``layer_num`` layers of width ``hidden_size``,
        followed by ``fc_block(hidden_size, output_size)``. Its output is used as the action ``logit`` by
        ``PPOFModel``.
    """

    def __init__(
            self,
            hidden_size: int,
            output_size: int,
            layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Arguments:
            - hidden_size (:obj:`int`): Size of the input embedding and of every hidden layer of the ``MLP``.
            - output_size (:obj:`int`): Output size of the final ``fc_block`` (number of discrete actions).
            - layer_num (:obj:`int`): Number of layers in the ``MLP``.
            - activation (:obj:`nn.Module`): Activation module passed to the ``MLP``.
            - norm_type (:obj:`str`): Normalization type passed to the ``MLP``.
        """
        super(DiscretePolicyHead, self).__init__()
        self.main = nn.Sequential(
            MLP(
                hidden_size,
                hidden_size,
                hidden_size,
                layer_num,
                layer_fn=nn.Linear,
                activation=activation,
                norm_type=norm_type
            ),
            fc_block(hidden_size, output_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP and the output projection to the embedding ``x``."""
        return self.main(x)


class PPOFModel(nn.Module):
    """
    Overview:
        Actor-critic network for PPOF. Supports ``'discrete'``, ``'continuous'`` and ``'hybrid'`` action
        spaces, with either one encoder shared by actor and critic or two independent encoders.
    Interfaces:
        ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str = 'discrete',
            share_encoder: bool = True,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            fixed_sigma_value: Optional[float] = 0.3,
            bound_type: Optional[str] = None,
            encoder: Optional[torch.nn.Module] = None,
            popart_head: bool = False,
    ) -> None:
        """
        Arguments:
            - obs_shape: Observation shape. An int or 1-element shape builds an ``FCEncoder``; a 3-element \
                shape builds a ``ConvEncoder``; anything else requires a custom ``encoder``.
            - action_shape: Action shape. An int for a single discrete/continuous head, a sequence for a \
                multi-head discrete action, or an ``EasyDict`` with ``action_type_shape`` and \
                ``action_args_shape`` for the hybrid action space.
            - action_space: One of ``'discrete'``, ``'continuous'``, ``'hybrid'``.
            - share_encoder: Whether actor and critic share one encoder.
            - encoder_hidden_size_list: ``hidden_size_list`` passed to the built-in encoders (not mutated here).
            - actor_head_hidden_size / actor_head_layer_num: Input size and depth of the actor head.
            - critic_head_hidden_size / critic_head_layer_num: Input size and depth of the critic head.
            - activation / norm_type: Forwarded to the encoders and heads.
            - sigma_type / fixed_sigma_value / bound_type: Forwarded to ``ReparameterizationHead`` for the \
                continuous and hybrid action spaces.
            - encoder: Optional custom encoder module. With ``share_encoder=False`` it is used for the actor and \
                a deep copy of it is used for the critic.
            - popart_head: Use ``PopArtVHead`` instead of ``RegressionHead`` as the critic head.
        Raises:
            - ValueError: If ``encoder`` is given but is not an ``nn.Module``.
            - RuntimeError: If no ``encoder`` is given and ``obs_shape`` is not supported by the built-in encoders.
        """
        super(PPOFModel, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.share_encoder = share_encoder

        # Encoder Type
        def new_encoder(outsize):
            # NOTE(review): ``outsize`` is not used; the built-in encoders' output width comes from
            # ``encoder_hidden_size_list``, which presumably must end with the matching head hidden size — confirm.
            if isinstance(obs_shape, int) or len(obs_shape) == 1:
                return FCEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            elif len(obs_shape) == 3:
                return ConvEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                raise RuntimeError(
                    "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                    format(obs_shape)
                )

        # Compare against None rather than relying on truthiness: container modules such as nn.Sequential /
        # nn.ModuleList define __len__, so an empty one would otherwise be silently replaced.
        if self.share_encoder:
            assert actor_head_hidden_size == critic_head_hidden_size, \
                "actor and critic network head should have same size."
            if encoder is not None:
                if not isinstance(encoder, torch.nn.Module):
                    raise ValueError("illegal encoder instance.")
                self.encoder = encoder
            else:
                self.encoder = new_encoder(actor_head_hidden_size)
        else:
            if encoder is not None:
                if not isinstance(encoder, torch.nn.Module):
                    raise ValueError("illegal encoder instance.")
                self.actor_encoder = encoder
                # The critic gets its own copy so that the two encoders do not share parameters.
                self.critic_encoder = deepcopy(encoder)
            else:
                self.actor_encoder = new_encoder(actor_head_hidden_size)
                self.critic_encoder = new_encoder(critic_head_hidden_size)

        # Head Type
        critic_head_cls = PopArtVHead if popart_head else RegressionHead
        self.critic_head = critic_head_cls(
            critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
        )

        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                # Forwarded here as in the hybrid branch; previously a user-supplied value was ignored.
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        elif self.action_space == 'discrete':
            # A non-int action_shape means one discrete head per action dimension.
            self.multi_head = not isinstance(action_shape, int)
            if self.multi_head:
                self.actor_head = MultiHead(
                    DiscretePolicyHead,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = DiscretePolicyHead(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':  # HPPO
            # hybrid action space: action_type(discrete) + action_args(continuous),
            # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
            self.multi_head = False
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
            )
            actor_action_type = DiscretePolicyHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            # Index 0 is the discrete action_type head, index 1 the continuous action_args head.
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        # Group each role's encoder and head so that e.g. ``self.critic.parameters()`` is convenient.
        # The same submodules are also registered under their own attribute names, so ``print(self)``
        # lists them more than once.
        if self.share_encoder:
            actor_encoder = critic_encoder = self.encoder
        else:
            actor_encoder, critic_encoder = self.actor_encoder, self.critic_encoder
        self.actor = nn.ModuleList([actor_encoder, self.actor_head])
        self.critic = nn.ModuleList([critic_encoder, self.critic_head])

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        """
        Overview:
            Dispatch to ``compute_actor``, ``compute_critic`` or ``compute_actor_critic`` according to ``mode``.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def _actor_output(self, embedding):
        """
        Overview:
            Run the actor head(s) on an already-encoded ``embedding`` and return the raw (not
            treetensor-wrapped) output.
        Returns:
            - discrete: the actor head output.
            - continuous: the ``ReparameterizationHead`` output (mu, sigma).
            - hybrid: ``{'action_type': ..., 'action_args': ...}``.
        """
        if self.action_space in ('discrete', 'continuous'):
            return self.actor_head(embedding)
        elif self.action_space == 'hybrid':
            return {'action_type': self.actor_head[0](embedding), 'action_args': self.actor_head[1](embedding)}
        raise ValueError("not support action_space: {}".format(self.action_space))

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Overview:
            Encode ``x`` with the actor (or shared) encoder and return the actor output. The discrete output is
            returned as produced by the head; continuous and hybrid outputs are wrapped with ``ttorch.as_tensor``.
        """
        encoder = self.encoder if self.share_encoder else self.actor_encoder
        output = self._actor_output(encoder(x))
        if self.action_space == 'discrete':
            return output
        return ttorch.as_tensor(output)

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Overview:
            Encode ``x`` with the critic (or shared) encoder and return the raw critic head output (a mapping
            whose ``'pred'`` entry holds the value estimate).
        """
        encoder = self.encoder if self.share_encoder else self.critic_encoder
        return self.critic_head(encoder(x))

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Overview:
            Compute actor and critic outputs in one pass. With a shared encoder, ``x`` is encoded only once.
        Returns:
            - ``ttorch.Tensor`` with keys ``'logit'`` (actor output, see ``_actor_output``) and ``'value'`` \
                (the critic head's ``'pred'``).
        """
        if self.share_encoder:
            actor_embedding = critic_embedding = self.encoder(x)
        else:
            actor_embedding = self.actor_encoder(x)
            critic_embedding = self.critic_encoder(x)
        value = self.critic_head(critic_embedding)
        logit = self._actor_output(actor_embedding)
        return ttorch.as_tensor({'logit': logit, 'value': value['pred']})