zjowowen's picture
init space
079c32c
raw
history blame
No virus
1.79 kB
import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .qmix import QMix
@MODEL_REGISTRY.register('madqn')
class MADQN(nn.Module):
    """
    Overview:
        Container for two separately parameterised ``QMix`` networks, ``current`` and
        ``cooperation``. Both are built from the same hyper-parameters; the only
        difference is the observation shape given to ``cooperation``, which is
        ``global_obs_shape`` when ``global_cooperation`` is True and ``obs_shape``
        otherwise. ``forward`` selects which of the two networks processes the input.
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: list,
            global_obs_shape: int = None,
            mixer: bool = False,
            global_cooperation: bool = True,
            lstm_type: str = 'gru',
            dueling: bool = False
    ) -> None:
        """
        Overview:
            Build the ``current`` and ``cooperation`` ``QMix`` sub-networks.
        Arguments:
            - agent_num (:obj:`int`): Forwarded unchanged to both ``QMix`` networks.
            - obs_shape (:obj:`int`): Observation shape of ``current``; also used for \
                ``cooperation`` when ``global_cooperation`` is False.
            - action_shape (:obj:`int`): Forwarded unchanged to both ``QMix`` networks.
            - hidden_size_list (:obj:`list`): Forwarded unchanged to both ``QMix`` networks.
            - global_obs_shape (:obj:`int`): Forwarded to both networks as ``global_obs_shape``; \
                additionally used as the observation shape of ``cooperation`` when \
                ``global_cooperation`` is True.
            - mixer (:obj:`bool`): Forwarded unchanged to both ``QMix`` networks.
            - global_cooperation (:obj:`bool`): Whether ``cooperation`` consumes the global state \
                instead of the per-agent state.
            - lstm_type (:obj:`str`): Forwarded unchanged to both ``QMix`` networks.
            - dueling (:obj:`bool`): Forwarded unchanged to both ``QMix`` networks.
        """
        super().__init__()
        self.global_cooperation = global_cooperation

        # Everything except the observation shape is shared by the two networks.
        shared_kwargs = dict(
            agent_num=agent_num,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling,
        )

        # Registered first so the sub-module order (current, cooperation) is unchanged.
        self.current = QMix(obs_shape=obs_shape, **shared_kwargs)

        # NOTE(review): with the defaults (global_cooperation=True, global_obs_shape=None)
        # the cooperation network is built with obs_shape=None. Confirm that callers always
        # supply global_obs_shape in that mode, or that QMix reports this clearly.
        cooperation_obs_shape = global_obs_shape if global_cooperation else obs_shape
        self.cooperation = QMix(obs_shape=cooperation_obs_shape, **shared_kwargs)

    def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
        """
        Overview:
            Run ``data`` through the ``cooperation`` network when ``cooperation`` is True,
            otherwise through the ``current`` network. The caller's ``data`` is never modified.
        Arguments:
            - data (:obj:`dict`): Input passed to the selected ``QMix`` network. When the \
                cooperation network runs in global mode, ``data['obs']['global_state']`` must exist.
            - cooperation (:obj:`bool`): Select the ``cooperation`` network instead of ``current``.
            - single_step (:obj:`bool`): Forwarded unchanged to the selected ``QMix`` network.
        Returns:
            - output (:obj:`dict`): Whatever the selected ``QMix`` network returns.
        """
        if not cooperation:
            return self.current(data, single_step=single_step)

        if self.global_cooperation:
            # The cooperation network reads the global state through the ``agent_state`` key.
            # Override it on shallow copies of the two dict levels so the caller's batch keeps
            # its real per-agent observation; writing in place would leak the global state
            # into any later non-cooperative forward on the same ``data``. Values (tensors)
            # are shared, not cloned.
            obs = dict(data['obs'])
            obs['agent_state'] = obs['global_state']
            data = dict(data)
            data['obs'] = obs
        return self.cooperation(data, single_step=single_step)