|
from typing import Optional, Tuple |
|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
from ding.torch_utils import MLP, ResBlock |
|
from ding.utils import MODEL_REGISTRY, SequenceType |
|
|
|
from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork |
|
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean |
|
|
|
|
|
|
|
@MODEL_REGISTRY.register('StochasticMuZeroModel')
class StochasticMuZeroModel(nn.Module):
    """
    Overview:
        The neural network model of Stochastic MuZero, proposed in the paper
        https://openreview.net/pdf?id=X6D9bAHhBQ1. Besides the standard MuZero
        representation/dynamics/prediction networks, it adds a chance encoder, an
        afterstate dynamics network and an afterstate prediction network, so that
        stochastic environment transitions can be modeled via chance outcomes.
    """

    def __init__(
        self,
        observation_shape: SequenceType = (12, 96, 96),
        action_space_size: int = 6,
        chance_space_size: int = 2,
        num_res_blocks: int = 1,
        num_channels: int = 64,
        reward_head_channels: int = 16,
        value_head_channels: int = 16,
        policy_head_channels: int = 16,
        fc_reward_layers: SequenceType = [32],
        fc_value_layers: SequenceType = [32],
        fc_policy_layers: SequenceType = [32],
        reward_support_size: int = 601,
        value_support_size: int = 601,
        proj_hid: int = 1024,
        proj_out: int = 1024,
        pred_hid: int = 512,
        pred_out: int = 1024,
        self_supervised_learning_loss: bool = False,
        categorical_distribution: bool = True,
        activation: nn.Module = nn.ReLU(inplace=True),
        last_linear_layer_init_zero: bool = True,
        state_norm: bool = False,
        downsample: bool = False,
        *args,
        **kwargs
    ):
        """
        Overview:
            The definition of the neural network model used in Stochastic MuZero,
            which is proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1.
            Stochastic MuZero model consists of a representation network, a dynamics network and a prediction network.
            The networks are built on convolution residual blocks and fully connected layers.
        Arguments:
            - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari.
            - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
            - chance_space_size: (:obj:`int`): Chance space size, the action space for decision node, usually an integer \
                number for discrete action space.
            - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
            - num_channels (:obj:`int`): The channels of hidden states.
            - reward_head_channels (:obj:`int`): The channels of reward head.
            - value_head_channels (:obj:`int`): The channels of value head.
            - policy_head_channels (:obj:`int`): The channels of policy head.
            - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head).
            - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
            - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
            - reward_support_size (:obj:`int`): The size of categorical reward output
            - value_support_size (:obj:`int`): The size of categorical value output.
            - proj_hid (:obj:`int`): The size of projection hidden layer.
            - proj_out (:obj:`int`): The size of projection output layer.
            - pred_hid (:obj:`int`): The size of prediction hidden layer.
            - pred_out (:obj:`int`): The size of prediction output layer.
            - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \
                in Stochastic MuZero model, default set it to False.
            - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \
                distribution for value and reward.
            - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
                operation to speedup, e.g. ReLU(inplace=True).
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \
                dynamics/prediction mlp, default set it to True.
            - state_norm (:obj:`bool`): Whether to use normalization for hidden states, default set it to False.
            - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
                defaults to True. This option is often used in video games like Atari. In board games like go, \
                we don't need this module.
        """
        super(StochasticMuZeroModel, self).__init__()
        self.categorical_distribution = categorical_distribution
        if self.categorical_distribution:
            self.reward_support_size = reward_support_size
            self.value_support_size = value_support_size
        else:
            # Scalar regression targets: a single output unit for reward and value.
            self.reward_support_size = 1
            self.value_support_size = 1

        self.action_space_size = action_space_size
        self.chance_space_size = chance_space_size

        self.proj_hid = proj_hid
        self.proj_out = proj_out
        self.pred_hid = pred_hid
        self.pred_out = pred_out
        self.self_supervised_learning_loss = self_supervised_learning_loss
        self.last_linear_layer_init_zero = last_linear_layer_init_zero
        self.state_norm = state_norm
        self.downsample = downsample

        # When downsampling is enabled, the spatial size of the latent state is
        # H/16 x W/16 (ceil), so the flattened head inputs shrink accordingly.
        flatten_output_size_for_reward_head = (
            (reward_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (reward_head_channels * observation_shape[1] * observation_shape[2])
        )
        flatten_output_size_for_value_head = (
            (value_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (value_head_channels * observation_shape[1] * observation_shape[2])
        )
        flatten_output_size_for_policy_head = (
            (policy_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (policy_head_channels * observation_shape[1] * observation_shape[2])
        )

        self.representation_network = RepresentationNetwork(
            observation_shape,
            num_res_blocks,
            num_channels,
            downsample,
        )

        # Encodes observations into a (categorical) chance outcome, see ``chance_encode``.
        self.chance_encoder = ChanceEncoder(
            observation_shape, chance_space_size
        )
        # ``num_channels + 1``: one extra input channel carries the option (action/chance) encoding.
        self.dynamics_network = DynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        self.prediction_network = PredictionNetwork(
            observation_shape,
            action_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )

        self.afterstate_dynamics_network = AfterstateDynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        # The afterstate prediction head outputs logits over chance outcomes
        # (``chance_space_size``) instead of agent actions.
        self.afterstate_prediction_network = AfterstatePredictionNetwork(
            chance_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )

        if self.self_supervised_learning_loss:
            # Projection and prediction heads for the SimSiam-style consistency loss,
            # used only when the self-supervised loss is enabled (see ``project``).
            if self.downsample:
                ceil_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)
                self.projection_input_dim = num_channels * ceil_size
            else:
                self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

            self.projection = nn.Sequential(
                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
            )
            self.prediction_head = nn.Sequential(
                nn.Linear(self.proj_out, self.pred_hid),
                nn.BatchNorm1d(self.pred_hid),
                activation,
                nn.Linear(self.pred_hid, self.pred_out),
            )

    def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
        """
        Overview:
            Initial inference of Stochastic MuZero model, which is the first step of the Stochastic MuZero model.
            To perform the initial inference, we first use the representation network to obtain the ``latent_state``.
            Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``.
        Arguments:
            - obs (:obj:`torch.Tensor`): The 2D image observation data.
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
            - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \
                In initial inference, we set it to zero vector.
            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
        """
        batch_size = obs.size(0)
        latent_state = self._representation(obs)
        policy_logits, value = self._prediction(latent_state)
        return MZNetworkOutput(
            value,
            # No reward is predicted at the root: use a zero placeholder per sample.
            [0. for _ in range(batch_size)],
            policy_logits,
            latent_state,
        )

    def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor,
                            afterstate: bool = False) -> MZNetworkOutput:
        """
        Overview:
            Recurrent inference of Stochastic MuZero model, which is the rollout step of the Stochastic MuZero model.
            To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``,
            ``reward``, by the given current ``latent_state`` and ``action``.
            We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current
            ``latent_state``.
            Note that the meaning of ``state``/``option`` depends on ``afterstate``: when ``afterstate`` is True,
            ``state`` is an afterstate and ``option`` is a chance outcome; otherwise ``state`` is a latent state
            and ``option`` is an agent action.
        Arguments:
            - state (:obj:`torch.Tensor`): The encoding latent state of input state or the afterstate.
            - option (:obj:`torch.Tensor`): The action to rollout or the chance to predict next latent state.
            - afterstate (:obj:`bool`): Whether to use afterstate prediction network to predict next latent state.
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
            - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action.
            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
        """

        if afterstate:
            # ``state`` is an afterstate, ``option`` is a chance outcome: roll the
            # chance dynamics forward and evaluate the resulting latent state.
            next_latent_state, reward = self._dynamics(state, option)
            policy_logits, value = self._prediction(next_latent_state)
            return MZNetworkOutput(value, reward, policy_logits, next_latent_state)
        else:
            # ``state`` is a latent state, ``option`` is an agent action: roll the
            # afterstate dynamics forward and evaluate the afterstate.
            next_afterstate, reward = self._afterstate_dynamics(state, option)
            policy_logits, value = self._afterstate_prediction(next_afterstate)
            return MZNetworkOutput(value, reward, policy_logits, next_afterstate)

    def _representation(self, observation: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Use the representation network to encode the observations into latent state.
        Arguments:
            - obs (:obj:`torch.Tensor`): The 2D image observation data.
        Returns:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
        """
        latent_state = self.representation_network(observation)
        if self.state_norm:
            latent_state = renormalize(latent_state)
        return latent_state

    def chance_encode(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Encode the observation with the chance encoder. Returns the raw chance
            encoding and its straight-through one-hot argmax (see ``ChanceEncoder``).
        Arguments:
            - observation (:obj:`torch.Tensor`): Observation tensor fed to the chance encoder.
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor]`): ``(chance_encoding, chance_onehot)``.
        """
        output = self.chance_encoder(observation)
        return output

    def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the prediction network to predict ``policy_logits`` and ``value``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
        """
        return self.prediction_network(latent_state)

    def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the afterstate prediction network to predict chance ``policy_logits`` and ``value``
            from an afterstate.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): The afterstate produced by the afterstate dynamics network.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output logit over chance outcomes.
            - value (:obj:`torch.Tensor`): The output value of the afterstate.
        Shapes:
            - afterstate (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, chance_space_size)`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
        """
        return self.afterstate_prediction_network(afterstate)

    def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state``
            and ``reward``. Here ``action`` is the chance outcome (this method is called from the ``afterstate=True``
            branch of ``recurrent_inference``), hence the normalization by ``chance_space_size`` below.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - action (:obj:`torch.Tensor`): The predicted action (chance outcome) to rollout.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep.
            - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
        """

        # One extra spatial plane filled with ones, to be scaled by the (normalized)
        # option value and concatenated to the latent state as a broadcast encoding.
        action_encoding = (
            torch.ones((
                latent_state.shape[0],
                1,
                latent_state.shape[2],
                latent_state.shape[3],
            )).to(action.device).float()
        )
        if len(action.shape) == 2:
            # (B, 1) -> (B, 1, 1)
            action = action.unsqueeze(-1)
        elif len(action.shape) == 1:
            # (B,) -> (B, 1, 1)
            action = action.unsqueeze(-1).unsqueeze(-1)

        # Scale the plane by the option index, normalized by the chance space size.
        action_encoding = (action[:, 0, None, None] * action_encoding / self.chance_space_size)

        # state_action_encoding shape: (B, num_channels + 1, H, W)
        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)

        next_latent_state, reward = self.dynamics_network(state_action_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[
        torch.Tensor, torch.Tensor]:
        """
        Overview:
            Concatenate ``latent_state`` and ``action`` and use the afterstate dynamics network to predict the
            afterstate and ``reward``. Here ``action`` is an agent action (this method is called from the
            ``afterstate=False`` branch of ``recurrent_inference``), hence the normalization by ``action_space_size``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - action (:obj:`torch.Tensor`): The predicted action to rollout.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted afterstate of the next timestep.
            - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
        """

        # One extra spatial plane filled with ones, scaled below by the normalized action.
        action_encoding = (
            torch.ones((
                latent_state.shape[0],
                1,
                latent_state.shape[2],
                latent_state.shape[3],
            )).to(action.device).float()
        )
        if len(action.shape) == 2:
            # (B, 1) -> (B, 1, 1)
            action = action.unsqueeze(-1)
        elif len(action.shape) == 1:
            # (B,) -> (B, 1, 1)
            action = action.unsqueeze(-1).unsqueeze(-1)

        # Scale the plane by the action index, normalized by the action space size.
        action_encoding = (action[:, 0, None, None] * action_encoding / self.action_space_size)

        # state_action_encoding shape: (B, num_channels + 1, H, W)
        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)

        next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Project the latent state to a lower dimension to calculate the self-supervised loss, which is involved in
            in EfficientZero.
            For more details, please refer to paper ``Exploring Simple Siamese Representation Learning``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result.
        Returns:
            - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \
                latent state, W_ is the width of latent state.
            - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size.

        Examples:
            >>> latent_state = torch.randn(256, 64, 6, 6)
            >>> output = self.project(latent_state)
            >>> output.shape # (256, 1024)

        .. note::
            for Atari:
            observation_shape = (12, 96, 96), # original shape is (3,96,96), frame_stack_num=4
            if downsample is True, latent_state.shape: (batch_size, num_channel, obs_shape[1] / 16, obs_shape[2] / 16)
            i.e., (256, 64, 96 / 16, 96 / 16) = (256, 64, 6, 6)
            latent_state reshape: (256, 64, 6, 6) -> (256,64*6*6) = (256, 2304)
            # self.projection_input_dim = 64*6*6 = 2304
            # self.projection_output_dim = 1024
        """
        # Flatten the spatial feature map before the linear projection layers.
        latent_state = latent_state.reshape(latent_state.shape[0], -1)
        proj = self.projection(latent_state)

        if with_grad:
            # The online branch: projection followed by the prediction head.
            return self.prediction_head(proj)
        else:
            # The target branch: stop-gradient projection only (SimSiam-style).
            return proj.detach()

    def get_params_mean(self) -> float:
        """Return the mean of the model parameters via the shared helper."""
        return get_params_mean(self)
|
|
|
|
|
class DynamicsNetwork(nn.Module):
    """
    Overview:
        Dynamics network of Stochastic MuZero: given the concatenation of the current
        latent state and a one-plane option encoding, predict the next latent state
        and the reward of the transition.
    """

    def __init__(
        self,
        num_res_blocks: int,
        num_channels: int,
        reward_head_channels: int,
        fc_reward_layers: SequenceType,
        output_support_size: int,
        flatten_output_size_for_reward_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: Optional[nn.Module] = nn.ReLU(inplace=True),
    ):
        """
        Overview:
            Build the dynamics network used in Stochastic MuZero.
        Arguments:
            - num_res_blocks (:obj:`int`): The number of residual blocks in the trunk.
            - num_channels (:obj:`int`): The channels of the input, i.e. latent state plus the option encoding plane.
            - reward_head_channels (:obj:`int`): The channels of the reward head.
            - fc_reward_layers (:obj:`SequenceType`): The hidden layer sizes of the reward MLP head.
            - output_support_size (:obj:`int`): The size of the categorical reward output.
            - flatten_output_size_for_reward_head (:obj:`int`): The flattened input size of the reward head.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to zero-initialize the last layer of the reward MLP.
            - activation (:obj:`Optional[nn.Module]`): Activation function used throughout the network.
        """
        super().__init__()
        self.num_channels = num_channels
        self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head

        # The last input channel carries the option encoding, so the trunk maps the
        # input back down to ``num_channels - 1`` channels (the latent-state width).
        latent_channels = num_channels - 1
        self.conv = nn.Conv2d(num_channels, latent_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(latent_channels)
        trunk = []
        for _ in range(num_res_blocks):
            trunk.append(
                ResBlock(
                    in_channels=latent_channels, activation=activation, norm_type='BN', res_type='basic', bias=False
                )
            )
        self.resblocks = nn.ModuleList(trunk)

        # Reward head: 1x1 conv -> BN -> activation -> flatten -> MLP.
        self.conv1x1_reward = nn.Conv2d(latent_channels, reward_head_channels, 1)
        self.bn_reward = nn.BatchNorm2d(reward_head_channels)
        self.fc_reward_head = MLP(
            in_channels=self.flatten_output_size_for_reward_head,
            hidden_channels=fc_reward_layers[0],
            layer_num=len(fc_reward_layers) + 1,
            out_channels=output_support_size,
            activation=activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.activation = activation

    def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Predict the next latent state and the reward from a state-option encoding.
        Arguments:
            - state_action_encoding (:obj:`torch.Tensor`): Concatenation of the latent state and the option \
                encoding plane, shape (batch_size, num_channels, height, width).
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): shape (batch_size, num_channels - 1, height, width).
            - reward (:obj:`torch.Tensor`): shape (batch_size, output_support_size).
        """
        # Strip the trailing option-encoding channel; the remaining latent state is
        # re-added as a residual connection after the first convolution.
        latent = state_action_encoding[:, :-1, :, :]
        out = self.bn(self.conv(state_action_encoding))
        out = self.activation(out + latent)

        for res_block in self.resblocks:
            out = res_block(out)
        next_latent_state = out

        reward_feat = self.activation(self.bn_reward(self.conv1x1_reward(next_latent_state)))
        reward_feat = reward_feat.view(-1, self.flatten_output_size_for_reward_head)
        reward = self.fc_reward_head(reward_feat)

        return next_latent_state, reward

    def get_dynamic_mean(self) -> float:
        """Return the mean statistic of the dynamics trunk via the shared helper."""
        return get_dynamic_mean(self)

    def get_reward_mean(self) -> float:
        """Return the mean statistic of the reward head via the shared helper."""
        return get_reward_mean(self)
|
|
|
|
|
|
|
# The afterstate dynamics network shares the exact architecture of the dynamics
# network (latent state plus one option-encoding channel), so it is a plain alias.
AfterstateDynamicsNetwork = DynamicsNetwork
|
|
|
|
|
class AfterstatePredictionNetwork(nn.Module):
    """
    Overview:
        Prediction head applied to afterstates in Stochastic MuZero: maps an
        afterstate feature map to a value distribution and logits over the
        (chance) action space.
    """

    def __init__(
        self,
        action_space_size: int,
        num_res_blocks: int,
        num_channels: int,
        value_head_channels: int,
        policy_head_channels: int,
        fc_value_layers: int,
        fc_policy_layers: int,
        output_support_size: int,
        flatten_output_size_for_value_head: int,
        flatten_output_size_for_policy_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        """
        Overview:
            Build the afterstate value/policy prediction network.
        Arguments:
            - action_space_size: (:obj:`int`): Size of the output policy logits (chance space for afterstates).
            - num_res_blocks (:obj:`int`): Number of residual blocks in the shared trunk.
            - num_channels (:obj:`int`): Channels of the input afterstate feature map.
            - value_head_channels (:obj:`int`): Channels of the value head.
            - policy_head_channels (:obj:`int`): Channels of the policy head.
            - fc_value_layers (:obj:`SequenceType`): Hidden layer sizes of the value MLP head.
            - fc_policy_layers (:obj:`SequenceType`): Hidden layer sizes of the policy MLP head.
            - output_support_size (:obj:`int`): Size of the categorical value output.
            - flatten_output_size_for_value_head (:obj:`int`): Flattened input size of the value head.
            - flatten_output_size_for_policy_head (:obj:`int`): Flattened input size of the policy head.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to zero-initialize the last layer of each MLP head.
            - activation (:obj:`Optional[nn.Module]`): Activation function used throughout the network.
        """
        super(AfterstatePredictionNetwork, self).__init__()
        trunk = [
            ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
            for _ in range(num_res_blocks)
        ]
        self.resblocks = nn.ModuleList(trunk)

        # Separate 1x1-conv reductions feeding the value and policy MLP heads.
        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
        self.bn_value = nn.BatchNorm2d(value_head_channels)
        self.bn_policy = nn.BatchNorm2d(policy_head_channels)
        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
        self.activation = activation

        self.fc_value = MLP(
            in_channels=self.flatten_output_size_for_value_head,
            hidden_channels=fc_value_layers[0],
            layer_num=len(fc_value_layers) + 1,
            out_channels=output_support_size,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.fc_policy = MLP(
            in_channels=self.flatten_output_size_for_policy_head,
            hidden_channels=fc_policy_layers[0],
            layer_num=len(fc_policy_layers) + 1,
            out_channels=action_space_size,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

    def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Predict the afterstate value and the policy logits over chance outcomes.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): Afterstate feature map, shape (B, num_channels, H, W).
        Returns:
            - afterstate_policy_logits (:obj:`torch.Tensor`): shape (B, action_space_size).
            - afterstate_value (:obj:`torch.Tensor`): shape (B, output_support_size).
        """
        feat = afterstate
        for res_block in self.resblocks:
            feat = res_block(feat)

        value_feat = self.activation(self.bn_value(self.conv1x1_value(feat)))
        policy_feat = self.activation(self.bn_policy(self.conv1x1_policy(feat)))

        value_feat = value_feat.reshape(-1, self.flatten_output_size_for_value_head)
        policy_feat = policy_feat.reshape(-1, self.flatten_output_size_for_policy_head)

        afterstate_value = self.fc_value(value_feat)
        afterstate_policy_logits = self.fc_policy(policy_feat)
        return afterstate_policy_logits, afterstate_value
|
|
|
|
|
class ChanceEncoderBackbone(nn.Module):
    """
    Overview:
        Convolutional backbone of the chance encoder: encodes a channel-stacked
        pair of (image) observations into a ``chance_encoding_dim``-dimensional vector.
    Arguments:
        - input_dimensions (:obj:`tuple`): Shape (C, H, W) of a single observation; \
          the forward input is expected to carry ``2 * C`` channels.
        - chance_encoding_dim (:obj:`int`): The dimension of the chance encoding.
    """

    def __init__(self, input_dimensions, chance_encoding_dim=4):
        super(ChanceEncoderBackbone, self).__init__()
        obs_channels = input_dimensions[0]
        obs_height = input_dimensions[1]
        obs_width = input_dimensions[2]
        # Two consecutive observations are stacked channel-wise, hence ``* 2``.
        self.conv1 = nn.Conv2d(obs_channels * 2, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # 3x3 convs with padding 1 keep H and W, so the flattened size is 64 * H * W.
        self.fc1 = nn.Linear(64 * obs_height * obs_width, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, chance_encoding_dim)

    def forward(self, x):
        """Encode the stacked observation tensor into a chance-encoding vector."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        hidden = hidden.view(hidden.shape[0], -1)
        hidden = torch.relu(self.fc1(hidden))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
|
|
|
class ChanceEncoderBackboneMLP(nn.Module):
    """
    Overview:
        MLP backbone of the chance encoder: encodes a (vector) observation into a
        ``chance_encoding_dim``-dimensional vector.
    Arguments:
        - input_dimensions (:obj:`tuple`): Flattened size of the input observation vector.
        - chance_encoding_dim (:obj:`int`): The dimension of the chance encoding.
    """

    def __init__(self, input_dimensions, chance_encoding_dim=4):
        super(ChanceEncoderBackboneMLP, self).__init__()
        self.fc1 = nn.Linear(input_dimensions, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, chance_encoding_dim)

    def forward(self, x):
        """Encode the observation vector into a chance-encoding vector."""
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)
|
|
|
|
|
class ChanceEncoder(nn.Module):
    """
    Overview:
        Chance encoder of Stochastic MuZero: maps observations to a chance encoding
        and its one-hot argmax computed with a straight-through estimator, following
        'Planning in Stochastic Environments with a Learned Model' (ICLR 2022).
    """

    def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='conv'):
        super().__init__()
        self.action_space = action_dimension
        # Pick the backbone matching the observation modality (image vs. vector).
        if encoder_backbone_type == 'conv':
            self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension)
        elif encoder_backbone_type == 'mlp':
            self.encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension)
        else:
            raise ValueError('Encoder backbone type not supported')

        # Differentiable one-hot argmax: hard one-hot forward, identity backward.
        self.onehot_argmax = StraightThroughEstimator()

    def forward(self, observations):
        """
        Overview:
            Encode ``observations`` into a chance encoding, then apply the
            straight-through one-hot argmax so gradients still reach the encoder.
        References:
            Planning in Stochastic Environments with a Learned Model (ICLR 2022),
            page 5, Chance Outcomes section.
        Arguments:
            - observations (:obj:`torch.Tensor`): Observation tensor.
        Returns:
            - chance_encoding (:obj:`torch.Tensor`): Raw encoding of the observations.
            - chance_onehot (:obj:`torch.Tensor`): One-hot argmax of the encoding.
        """
        encoding = self.encoder(observations)
        onehot = self.onehot_argmax(encoding)
        return encoding, onehot
|
|
|
|
|
class StraightThroughEstimator(nn.Module):
    """
    Overview:
        Module wrapper around ``OnehotArgmax``: the forward pass emits a hard
        one-hot argmax, while the backward pass lets gradients flow through unchanged.
    """

    def forward(self, x):
        """
        Overview:
            Apply the straight-through one-hot argmax to ``x``.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - (:obj:`torch.Tensor`): One-hot tensor over the last dimension of ``x``.
        """
        return OnehotArgmax.apply(x)
|
|
|
|
|
class OnehotArgmax(torch.autograd.Function):
    """
    Overview:
        Straight-through one-hot argmax. The forward pass produces a one-hot tensor
        marking the maximum entry along the last dimension; the backward pass passes
        gradients through unchanged so upstream modules (the chance encoder) can
        still be trained.

    For more information, refer to: \
        https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    """

    @staticmethod
    def forward(ctx, input):
        """
        Overview:
            Convert ``input`` into a one-hot tensor along its last dimension.
        Arguments:
            - ctx (:obj:`context`): Autograd context object (unused here, kept for the API).
            - input (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - (:obj:`torch.Tensor`): One-hot tensor with the same shape and dtype as ``input``.
        """
        max_index = torch.argmax(input, dim=-1, keepdim=True)
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(-1, max_index, 1.)
        return one_hot

    @staticmethod
    def backward(ctx, grad_output):
        """
        Overview:
            Identity backward pass: the incoming gradient is returned untouched.
        Arguments:
            - ctx (:obj:`context`): Autograd context object from the forward pass.
            - grad_output (:obj:`torch.Tensor`): Gradient w.r.t. the forward output.
        Returns:
            - (:obj:`torch.Tensor`): Gradient w.r.t. the forward input (unchanged).
        """
        return grad_output
|
|