# Mirrored from the "gomoku" space: LightZero/lzero/model/stochastic_muzero_model.py
# Uploaded by zjowowen in commit 079c32c ("init space"); original file size 42 kB.
import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
from ding.torch_utils import MLP, ResBlock
from ding.utils import MODEL_REGISTRY, SequenceType

from .common import MZNetworkOutput, PredictionNetwork, RepresentationNetwork
from .utils import get_dynamic_mean, get_params_mean, get_reward_mean, renormalize
# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
@MODEL_REGISTRY.register('StochasticMuZeroModel')
class StochasticMuZeroModel(nn.Module):
    """
    Overview:
        Stochastic MuZero model: representation network, chance encoder, (chance-conditioned) dynamics network,
        prediction network, afterstate dynamics network and afterstate prediction network.
    """

    def __init__(
            self,
            observation_shape: SequenceType = (12, 96, 96),
            action_space_size: int = 6,
            chance_space_size: int = 2,
            num_res_blocks: int = 1,
            num_channels: int = 64,
            reward_head_channels: int = 16,
            value_head_channels: int = 16,
            policy_head_channels: int = 16,
            fc_reward_layers: SequenceType = [32],
            fc_value_layers: SequenceType = [32],
            fc_policy_layers: SequenceType = [32],
            reward_support_size: int = 601,
            value_support_size: int = 601,
            proj_hid: int = 1024,
            proj_out: int = 1024,
            pred_hid: int = 512,
            pred_out: int = 1024,
            self_supervised_learning_loss: bool = False,
            categorical_distribution: bool = True,
            activation: nn.Module = nn.ReLU(inplace=True),
            last_linear_layer_init_zero: bool = True,
            state_norm: bool = False,
            downsample: bool = False,
            *args,
            **kwargs
    ):
        """
        Overview:
            The definition of the neural network model used in Stochastic MuZero,
            which is proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1.
            The model consists of a representation network, a chance encoder, a dynamics network, a prediction
            network, an afterstate dynamics network and an afterstate prediction network. The networks are built
            on convolutional residual blocks and fully connected layers.
        Arguments:
            - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari.
            - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
            - chance_space_size: (:obj:`int`): Number of discrete chance outcomes, i.e. the "action" space of the \
                chance nodes. The afterstate prediction network outputs logits of this size.
            - num_res_blocks (:obj:`int`): The number of residual blocks in each convolutional sub-network.
            - num_channels (:obj:`int`): The channels of hidden (latent) states.
            - reward_head_channels (:obj:`int`): The channels of reward head.
            - value_head_channels (:obj:`int`): The channels of value head.
            - policy_head_channels (:obj:`int`): The channels of policy head.
            - fc_reward_layers (:obj:`SequenceType`): Hidden layer sizes of the reward MLP head.
            - fc_value_layers (:obj:`SequenceType`): Hidden layer sizes of the value MLP head.
            - fc_policy_layers (:obj:`SequenceType`): Hidden layer sizes of the policy MLP head.
            - reward_support_size (:obj:`int`): The size of categorical reward output.
            - value_support_size (:obj:`int`): The size of categorical value output.
            - proj_hid (:obj:`int`): The size of projection hidden layer.
            - proj_out (:obj:`int`): The size of projection output layer.
            - pred_hid (:obj:`int`): The size of prediction hidden layer.
            - pred_out (:obj:`int`): The size of prediction output layer.
            - self_supervised_learning_loss (:obj:`bool`): Whether to build the self-supervised projection / \
                prediction heads used by ``project``. Defaults to False.
            - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \
                distribution for value and reward. When False, value and reward heads output a single scalar.
            - activation (:obj:`nn.Module`): Activation function. In this class it is only used by the \
                self-supervised projection / prediction heads; the sub-networks use their own defaults.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \
                dynamics/prediction mlp, default set it to True.
            - state_norm (:obj:`bool`): Whether to renormalize latent states / afterstates, default set it to False.
            - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
                defaults to False. When True, the latent spatial size is ``ceil(H / 16) x ceil(W / 16)``. This \
                option is often used in video games like Atari; board games like go don't need it.
        """
        super(StochasticMuZeroModel, self).__init__()
        self.categorical_distribution = categorical_distribution
        if self.categorical_distribution:
            self.reward_support_size = reward_support_size
            self.value_support_size = value_support_size
        else:
            self.reward_support_size = 1
            self.value_support_size = 1

        self.action_space_size = action_space_size
        self.chance_space_size = chance_space_size
        self.proj_hid = proj_hid
        self.proj_out = proj_out
        self.pred_hid = pred_hid
        self.pred_out = pred_out
        self.self_supervised_learning_loss = self_supervised_learning_loss
        self.last_linear_layer_init_zero = last_linear_layer_init_zero
        self.state_norm = state_norm
        self.downsample = downsample

        # Number of spatial positions (H_ * W_) of the latent state; every flattened head input is
        # ``head_channels * latent_spatial_size``.
        if downsample:
            latent_spatial_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)
        else:
            latent_spatial_size = observation_shape[1] * observation_shape[2]
        flatten_output_size_for_reward_head = reward_head_channels * latent_spatial_size
        flatten_output_size_for_value_head = value_head_channels * latent_spatial_size
        flatten_output_size_for_policy_head = policy_head_channels * latent_spatial_size

        # NOTE: sub-modules are created in a fixed order so that seeded parameter initialization is reproducible.
        self.representation_network = RepresentationNetwork(
            observation_shape,
            num_res_blocks,
            num_channels,
            downsample,
        )

        self.chance_encoder = ChanceEncoder(
            observation_shape, chance_space_size
        )
        # ``num_channels + 1``: the latent state plus one plane encoding the chance outcome.
        self.dynamics_network = DynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        self.prediction_network = PredictionNetwork(
            observation_shape,
            action_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )

        # ``num_channels + 1``: the latent state plus one plane encoding the action.
        self.afterstate_dynamics_network = AfterstateDynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        # The afterstate policy is a distribution over chance outcomes, hence ``chance_space_size`` logits.
        self.afterstate_prediction_network = AfterstatePredictionNetwork(
            chance_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )

        if self.self_supervised_learning_loss:
            # projection used in EfficientZero
            # e.g. Atari with observation_shape (12, 96, 96) and downsample: the latent state is (64, 6, 6), so
            # projection_input_dim = 64 * 6 * 6 = 2304.
            self.projection_input_dim = num_channels * latent_spatial_size
            self.projection = nn.Sequential(
                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
            )
            self.prediction_head = nn.Sequential(
                nn.Linear(self.proj_out, self.pred_hid),
                nn.BatchNorm1d(self.pred_hid),
                activation,
                nn.Linear(self.pred_hid, self.pred_out),
            )

    def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
        """
        Overview:
            Initial inference of Stochastic MuZero model, which is the first step of the Stochastic MuZero model.
            The representation network encodes ``obs`` into ``latent_state``; the prediction network then predicts
            ``value`` and ``policy_logits`` of that ``latent_state``.
        Arguments:
            - obs (:obj:`torch.Tensor`): The 2D image observation data.
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
            - reward (:obj:`list`): A list of ``B`` zeros (no reward is predicted in initial inference).
            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`, where H_ / W_ are the height \
                and width of the latent state.
        """
        batch_size = obs.size(0)
        latent_state = self._representation(obs)
        policy_logits, value = self._prediction(latent_state)
        return MZNetworkOutput(
            value,
            [0.] * batch_size,
            policy_logits,
            latent_state,
        )

    def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor,
                            afterstate: bool = False) -> MZNetworkOutput:
        """
        Overview:
            Recurrent inference of Stochastic MuZero model, i.e. one rollout step. The step alternates between
            two kinds of transition depending on ``afterstate``:

            - ``afterstate=False``: ``state`` is a latent state and ``option`` an action. The afterstate dynamics \
                network predicts the next afterstate and reward; the afterstate prediction network predicts the \
                value and the chance-outcome logits of that afterstate.
            - ``afterstate=True``: ``state`` is an afterstate and ``option`` a chance outcome. The dynamics network \
                predicts the next latent state and reward; the prediction network predicts its value and action \
                logits.
        Arguments:
            - state (:obj:`torch.Tensor`): The latent state or the afterstate (see above).
            - option (:obj:`torch.Tensor`): The action or the chance outcome (see above).
            - afterstate (:obj:`bool`): Whether ``state`` is an afterstate (and ``option`` a chance outcome).
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The predicted value of the produced state.
            - reward (:obj:`torch.Tensor`): The predicted reward of the transition.
            - policy_logits (:obj:`torch.Tensor`): Action logits (``afterstate=True``) or chance logits \
                (``afterstate=False``).
            - next_latent_state (:obj:`torch.Tensor`): The produced latent state or afterstate.
        Shapes:
            - state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`, where B is batch_size.
            - option (:obj:`torch.Tensor`): :math:`(B, )` or :math:`(B, 1)`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_space_size)` or :math:`(B, chance_space_size)`.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
        """
        if afterstate:
            # state is afterstate, option is chance
            next_latent_state, reward = self._dynamics(state, option)
            policy_logits, value = self._prediction(next_latent_state)
            return MZNetworkOutput(value, reward, policy_logits, next_latent_state)
        # state is latent_state, option is action
        next_afterstate, reward = self._afterstate_dynamics(state, option)
        policy_logits, value = self._afterstate_prediction(next_afterstate)
        return MZNetworkOutput(value, reward, policy_logits, next_afterstate)

    def _representation(self, observation: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Use the representation network to encode the observations into latent state.
        Arguments:
            - observation (:obj:`torch.Tensor`): The 2D image observation data.
        Returns:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state, renormalized when \
                ``state_norm`` is True.
        Shapes:
            - observation (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
        """
        latent_state = self.representation_network(observation)
        if self.state_norm:
            latent_state = renormalize(latent_state)
        return latent_state

    def chance_encode(self, observation: torch.Tensor):
        """
        Overview:
            Encode ``observation`` with the chance encoder.
        Returns:
            - The output of ``ChanceEncoder``, i.e. the tuple ``(chance_encoding, chance_onehot)``.
        """
        output = self.chance_encoder(observation)
        return output

    def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the prediction network to predict ``policy_logits`` and ``value`` of a latent state.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
            - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_space_size)`.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`.
        """
        return self.prediction_network(latent_state)

    def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the afterstate prediction network to predict chance-outcome ``policy_logits`` and ``value`` of an
            afterstate.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): The afterstate.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): Logits over chance outcomes.
            - value (:obj:`torch.Tensor`): The predicted value of the afterstate.
        Shapes:
            - afterstate (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, chance_space_size)`.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`.
        """
        return self.afterstate_prediction_network(afterstate)

    @staticmethod
    def _encode_option(latent_state: torch.Tensor, option: torch.Tensor,
                       option_space_size: int) -> torch.Tensor:
        """
        Overview:
            Append one constant plane encoding a discrete option (an action or a chance outcome) to
            ``latent_state``. Every element of the plane equals ``option / option_space_size``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): Shape :math:`(B, C, H_, W_)`.
            - option (:obj:`torch.Tensor`): Option indices with shape :math:`(B, )` or :math:`(B, K)`; only the \
                first column is used.
            - option_space_size (:obj:`int`): Normalizer for the option index.
        Returns:
            - state_option_encoding (:obj:`torch.Tensor`): Shape :math:`(B, C + 1, H_, W_)`.
        """
        # NOTE: the discrete action encoding type is important for some environments.
        if option.dim() == 2:
            # (B, K) -> (B, K, 1)
            option = option.unsqueeze(-1)
        elif option.dim() == 1:
            # (B,) -> (B, 1, 1)
            option = option.unsqueeze(-1).unsqueeze(-1)
        plane = torch.ones(
            (latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3]),
            device=option.device,
            dtype=torch.float32,
        )
        # option[:, 0, None, None] has shape (B, 1, 1, 1) and broadcasts over the (B, 1, H_, W_) plane.
        option_encoding = option[:, 0, None, None] * plane / option_space_size
        return torch.cat((latent_state, option_encoding), dim=1)

    def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Chance transition: concatenate the afterstate with an encoding of the chance outcome (normalized by
            ``chance_space_size``) and use the dynamics network to predict ``next_latent_state`` and ``reward``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The afterstate.
            - action (:obj:`torch.Tensor`): The chance outcome index.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep.
            - reward (:obj:`torch.Tensor`): The predicted reward.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - action (:obj:`torch.Tensor`): :math:`(B, )` or :math:`(B, 1)`.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`.
        """
        state_chance_encoding = self._encode_option(latent_state, action, self.chance_space_size)
        next_latent_state, reward = self.dynamics_network(state_chance_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[
        torch.Tensor, torch.Tensor]:
        """
        Overview:
            Action transition: concatenate the latent state with an encoding of the action (normalized by
            ``action_space_size``) and use the afterstate dynamics network to predict the next afterstate and
            ``reward``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - action (:obj:`torch.Tensor`): The action to rollout.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted afterstate.
            - reward (:obj:`torch.Tensor`): The predicted reward.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - action (:obj:`torch.Tensor`): :math:`(B, )` or :math:`(B, 1)`.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`.
        """
        state_action_encoding = self._encode_option(latent_state, action, self.action_space_size)
        next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Project the latent state to a lower dimension to calculate the self-supervised loss, which is involved
            in EfficientZero. For more details, please refer to paper ``Exploring Simple Siamese Representation
            Learning``. Requires the model to be built with ``self_supervised_learning_loss=True``; otherwise
            ``self.projection`` does not exist.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
            - with_grad (:obj:`bool`): If True, pass the projection through ``prediction_head`` and keep the \
                graph; if False, return the detached projection.
        Returns:
            - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H_, W_)`.
            - proj (:obj:`torch.Tensor`): :math:`(B, pred_out)` if ``with_grad`` else :math:`(B, proj_out)`.
        Examples:
            >>> latent_state = torch.randn(256, 64, 6, 6)
            >>> output = self.project(latent_state)
            >>> output.shape # (256, 1024)
        """
        # (B, C, H_, W_) -> (B, C * H_ * W_), e.g. (256, 64, 6, 6) -> (256, 2304)
        latent_state = latent_state.reshape(latent_state.shape[0], -1)
        proj = self.projection(latent_state)
        if with_grad:
            # with grad, use prediction_head
            return self.prediction_head(proj)
        return proj.detach()

    def get_params_mean(self) -> float:
        """Return the mean of the model parameters (delegates to ``utils.get_params_mean``)."""
        return get_params_mean(self)
class DynamicsNetwork(nn.Module):
    """
    Overview:
        Dynamics network of Stochastic MuZero. It consumes a latent state with one extra encoding plane
        appended (action or chance outcome) and predicts the next latent state and a reward.
    """

    def __init__(
        self,
        num_res_blocks: int,
        num_channels: int,
        reward_head_channels: int,
        fc_reward_layers: SequenceType,
        output_support_size: int,
        flatten_output_size_for_reward_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: Optional[nn.Module] = nn.ReLU(inplace=True),
    ):
        """
        Overview:
            Build the dynamics network.
        Arguments:
            - num_res_blocks (:obj:`int`): Number of residual blocks applied after the input convolution.
            - num_channels (:obj:`int`): Number of input channels, i.e. latent channels plus the one \
                encoding plane. The produced latent state has ``num_channels - 1`` channels.
            - reward_head_channels (:obj:`int`): Channels of the 1x1 convolution feeding the reward head.
            - fc_reward_layers (:obj:`SequenceType`): Reward MLP hidden sizes; its first entry is the hidden \
                width and its length sets the number of hidden layers.
            - output_support_size (:obj:`int`): Size of the reward output.
            - flatten_output_size_for_reward_head (:obj:`int`): Flattened input size of the reward MLP.
            - last_linear_layer_init_zero (:obj:`bool`): Whether the last linear layer of the reward MLP is \
                zero-initialized.
            - activation (:obj:`Optional[nn.Module]`): Activation module, typically an in-place ReLU.
        """
        super().__init__()
        # The last input plane is the action / chance encoding; everything else is the latent state.
        latent_channels = num_channels - 1

        self.num_channels = num_channels
        self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head

        self.conv = nn.Conv2d(num_channels, latent_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(latent_channels)
        self.resblocks = nn.ModuleList(
            ResBlock(in_channels=latent_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
            for _ in range(num_res_blocks)
        )

        self.conv1x1_reward = nn.Conv2d(latent_channels, reward_head_channels, 1)
        self.bn_reward = nn.BatchNorm2d(reward_head_channels)
        self.fc_reward_head = MLP(
            self.flatten_output_size_for_reward_head,
            hidden_channels=fc_reward_layers[0],
            layer_num=len(fc_reward_layers) + 1,
            out_channels=output_support_size,
            activation=activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.activation = activation

    def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Predict the next latent state and the reward.
        Arguments:
            - state_action_encoding (:obj:`torch.Tensor`): Latent state concatenated with the encoding plane, \
                shape (batch_size, num_channels, height, width).
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): Shape (batch_size, num_channels - 1, height, width).
            - reward (:obj:`torch.Tensor`): Shape (batch_size, output_support_size).
        """
        # Strip the trailing encoding plane to recover the bare latent state for the skip connection.
        skip = state_action_encoding[:, :-1, :, :]
        hidden = self.bn(self.conv(state_action_encoding))
        hidden += skip
        hidden = self.activation(hidden)
        for res_block in self.resblocks:
            hidden = res_block(hidden)

        # The reward head starts from a fresh conv output, so the in-place activation below never
        # touches ``hidden`` (the returned latent state).
        reward_features = self.activation(self.bn_reward(self.conv1x1_reward(hidden)))
        reward_features = reward_features.view(-1, self.flatten_output_size_for_reward_head)
        return hidden, self.fc_reward_head(reward_features)

    def get_dynamic_mean(self) -> float:
        """Return the mean of the dynamics parameters (delegates to ``utils.get_dynamic_mean``)."""
        return get_dynamic_mean(self)

    def get_reward_mean(self) -> float:
        """Return the mean of the reward-head parameters (delegates to ``utils.get_reward_mean``)."""
        return get_reward_mean(self)
# TODO(pu): customize different afterstate dynamics network
# For now the afterstate (action) dynamics reuse the DynamicsNetwork architecture unchanged; the model builds a
# separate instance, with its own weights, for each of the two transitions.
AfterstateDynamicsNetwork = DynamicsNetwork
class AfterstatePredictionNetwork(nn.Module):
    """
    Overview:
        Predicts the value of an afterstate and the logits over its possible chance outcomes.
    """

    def __init__(
        self,
        action_space_size: int,
        num_res_blocks: int,
        num_channels: int,
        value_head_channels: int,
        policy_head_channels: int,
        fc_value_layers: Sequence[int],
        fc_policy_layers: Sequence[int],
        output_support_size: int,
        flatten_output_size_for_value_head: int,
        flatten_output_size_for_policy_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        """
        Overview:
            The definition of afterstate policy and value prediction network, which is used to predict value and
            policy by the given afterstate.
        Arguments:
            - action_space_size: (:obj:`int`): Size of the policy output. In Stochastic MuZero this is the number \
                of chance outcomes.
            - num_res_blocks (:obj:`int`): The number of residual blocks applied to the afterstate.
            - num_channels (:obj:`int`): The channels of hidden states.
            - value_head_channels (:obj:`int`): The channels of value head.
            - policy_head_channels (:obj:`int`): The channels of policy head.
            - fc_value_layers (:obj:`Sequence[int]`): Hidden sizes of the value MLP head; the first entry is the \
                hidden width and the length sets the number of hidden layers.
            - fc_policy_layers (:obj:`Sequence[int]`): Hidden sizes of the policy MLP head, same convention.
            - output_support_size (:obj:`int`): The size of categorical value output.
            - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
                of the value head.
            - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
                of the policy head.
            - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \
                the value/policy mlp, default set it to True.
            - activation (:obj:`nn.Module`): Activation function used in network, which often use in-place \
                operation to speedup, e.g. ReLU(inplace=True).
        """
        super(AfterstatePredictionNetwork, self).__init__()
        self.resblocks = nn.ModuleList(
            [
                ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
                for _ in range(num_res_blocks)
            ]
        )
        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
        self.bn_value = nn.BatchNorm2d(value_head_channels)
        self.bn_policy = nn.BatchNorm2d(policy_head_channels)
        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
        self.activation = activation

        self.fc_value = MLP(
            in_channels=self.flatten_output_size_for_value_head,
            hidden_channels=fc_value_layers[0],
            out_channels=output_support_size,
            layer_num=len(fc_value_layers) + 1,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.fc_policy = MLP(
            in_channels=self.flatten_output_size_for_policy_head,
            hidden_channels=fc_policy_layers[0],
            out_channels=action_space_size,
            layer_num=len(fc_policy_layers) + 1,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

    def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the afterstate prediction network.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): Afterstate feature map with shape (B, num_channels, H_, W_).
        Returns:
            - afterstate_policy_logits (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
            - afterstate_value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
        """
        for res_block in self.resblocks:
            afterstate = res_block(afterstate)

        value = self.conv1x1_value(afterstate)
        value = self.bn_value(value)
        value = self.activation(value)

        policy = self.conv1x1_policy(afterstate)
        policy = self.bn_policy(policy)
        policy = self.activation(policy)

        value = value.reshape(-1, self.flatten_output_size_for_value_head)
        policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)

        afterstate_value = self.fc_value(value)
        afterstate_policy_logits = self.fc_policy(policy)
        return afterstate_policy_logits, afterstate_value
class ChanceEncoderBackbone(nn.Module):
    """
    Overview:
        Convolutional chance-encoder backbone: two 3x3 convolutions (spatial size preserved by padding)
        followed by a three-layer MLP that outputs ``chance_encoding_dim`` logits.
    Arguments:
        - input_dimensions (:obj:`tuple`): Observation shape (C, H, W). The first convolution expects \
            ``2 * C`` input channels (presumably two stacked observations — confirm with the caller).
        - chance_encoding_dim (:obj:`int`): The dimension of chance encoding (number of output logits).
    """

    def __init__(self, input_dimensions, chance_encoding_dim=4):
        super().__init__()
        in_channels = input_dimensions[0] * 2
        flat_features = 64 * input_dimensions[1] * input_dimensions[2]
        # Layers are created in this exact order to keep seeded initialization reproducible.
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, chance_encoding_dim)

    def forward(self, x):
        features = torch.relu(self.conv2(torch.relu(self.conv1(x))))
        features = features.view(features.shape[0], -1)
        for hidden_layer in (self.fc1, self.fc2):
            features = torch.relu(hidden_layer(features))
        return self.fc3(features)
class ChanceEncoderBackboneMLP(nn.Module):
    """
    Overview:
        MLP chance-encoder backbone for vector observations: two ReLU hidden layers (128, 64) followed by
        a linear output of ``chance_encoding_dim`` logits.
    Arguments:
        - input_dimensions (:obj:`int`): Size of the flat input vector (passed directly to ``nn.Linear``).
        - chance_encoding_dim (:obj:`int`): The dimension of chance encoding (number of output logits).
    """

    def __init__(self, input_dimensions, chance_encoding_dim=4):
        super().__init__()
        self.fc1 = nn.Linear(input_dimensions, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, chance_encoding_dim)

    def forward(self, x):
        hidden = x
        for hidden_layer in (self.fc1, self.fc2):
            hidden = torch.relu(hidden_layer(hidden))
        return self.fc3(hidden)
class ChanceEncoder(nn.Module):
    """
    Overview:
        Encodes observations into chance-outcome logits and a one-hot chance code obtained with a
        straight-through argmax.
    """

    def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='conv'):
        """
        Arguments:
            - input_dimensions: For ``'conv'``, the observation shape (C, H, W); the conv backbone's first layer \
                expects ``2 * C`` input channels. For ``'mlp'``, the flat input size.
            - action_dimension (:obj:`int`): Number of chance outcomes (size of the encoding and the one-hot code).
            - encoder_backbone_type (:obj:`str`): ``'conv'`` or ``'mlp'``.
        Raises:
            - ValueError: If ``encoder_backbone_type`` is neither ``'conv'`` nor ``'mlp'``.
        """
        super().__init__()
        # Specify the action space for the model
        self.action_space = action_dimension
        if encoder_backbone_type == 'conv':
            # Define the encoder, which transforms observations into a latent space
            self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension)
        elif encoder_backbone_type == 'mlp':
            self.encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension)
        else:
            raise ValueError(
                f"Encoder backbone type not supported: {encoder_backbone_type!r} (expected 'conv' or 'mlp')"
            )
        # Using the Straight Through Estimator method for backpropagation
        self.onehot_argmax = StraightThroughEstimator()

    def forward(self, observations):
        """
        Overview:
            Forward method for the ChanceEncoder. This method takes an observation \
            and applies the encoder to transform it to a latent space. Then applies the \
            StraightThroughEstimator to this encoding. \
            References: Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5,
            Chance Outcomes section.
        Arguments:
            - observations (:obj:`torch.Tensor`): Observation tensor.
        Returns:
            - chance_encoding (:obj:`torch.Tensor`): Raw encoder output (logits over chance outcomes).
            - chance_onehot (:obj:`torch.Tensor`): One-hot argmax of ``chance_encoding``; gradients pass \
                straight through to ``chance_encoding``.
        """
        # Apply the encoder to the observation
        chance_encoding = self.encoder(observations)
        # Apply one-hot argmax to the encoding
        chance_onehot = self.onehot_argmax(chance_encoding)
        return chance_encoding, chance_onehot
class StraightThroughEstimator(nn.Module):
    """
    Overview:
        Parameter-free module wrapping ``OnehotArgmax``: the forward pass yields a one-hot argmax, while the
        backward pass hands the incoming gradient through unchanged.
    """

    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        """
        Overview:
            Convert ``x`` into a one-hot tensor marking the argmax along its last dimension.
        Arguments:
            - x (:obj:`torch.Tensor`): Input logits.
        Returns:
            - (:obj:`torch.Tensor`): One-hot tensor with the same shape as ``x``.
        """
        return OnehotArgmax.apply(x)
class OnehotArgmax(torch.autograd.Function):
    """
    Overview:
        Straight-through one-hot argmax. The forward pass returns a tensor that is 1 at the position of the
        maximum along the last dimension and 0 elsewhere. The backward pass returns the incoming gradient
        unchanged, so gradients reach whatever produced the input.
        For more information, refer to: \
            https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    """

    @staticmethod
    def forward(ctx, input):
        """
        Overview:
            Build the one-hot tensor of the argmax along the last dimension (ties go to the first maximum).
        Arguments:
            - ctx (:obj:`context`): Autograd context (unused; nothing needs to be saved).
            - input (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - (:obj:`torch.Tensor`): One-hot tensor with the same shape, dtype and device as ``input``.
        """
        winner = input.argmax(dim=-1, keepdim=True)
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(-1, winner, 1.)
        return one_hot

    @staticmethod
    def backward(ctx, grad_output):
        """
        Overview:
            Straight-through gradient: pass ``grad_output`` back to the input untouched.
        Arguments:
            - ctx (:obj:`context`): Autograd context.
            - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor.
        Returns:
            - (:obj:`torch.Tensor`): The gradient of the input tensor.
        """
        return grad_output