import torch.nn as nn
from ding.utils import MODEL_REGISTRY
from .qmix import QMix


@MODEL_REGISTRY.register('madqn')
class MADQN(nn.Module):
    """Wrapper module holding two ``QMix`` networks: ``current`` and ``cooperation``.

    Both networks are built with the same ``agent_num``, ``action_shape``,
    ``hidden_size_list``, ``global_obs_shape``, ``mixer``, ``lstm_type`` and
    ``dueling`` settings. They differ only in ``obs_shape``: ``current`` always
    uses ``obs_shape``, while ``cooperation`` uses ``global_obs_shape`` when
    ``global_cooperation`` is true and ``obs_shape`` otherwise.

    The class is registered in ``MODEL_REGISTRY`` under the name ``'madqn'``.
    """

    def __init__(
            self,
            agent_num: int,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: list,
            global_obs_shape: int = None,
            mixer: bool = False,
            global_cooperation: bool = True,
            lstm_type: str = 'gru',
            dueling: bool = False
    ) -> None:
        """Build the ``current`` and ``cooperation`` sub-networks.

        Args:
            agent_num: Forwarded to both ``QMix`` networks.
            obs_shape: Observation shape used by ``current`` (and by
                ``cooperation`` when ``global_cooperation`` is false).
            action_shape: Forwarded to both ``QMix`` networks.
            hidden_size_list: Forwarded to both ``QMix`` networks.
            global_obs_shape: Forwarded to both ``QMix`` networks; also used as
                the ``obs_shape`` of ``cooperation`` when ``global_cooperation``
                is true.
            mixer: Forwarded to both ``QMix`` networks.
            global_cooperation: If true, ``cooperation`` is built on
                ``global_obs_shape`` and ``forward`` feeds it the global state.
            lstm_type: Forwarded to both ``QMix`` networks.
            dueling: Forwarded to both ``QMix`` networks.
        """
        super().__init__()
        self.current = QMix(
            agent_num=agent_num,
            obs_shape=obs_shape,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling
        )
        self.global_cooperation = global_cooperation
        # NOTE(review): with global_cooperation=True and global_obs_shape=None the
        # cooperation network is built with obs_shape=None — confirm QMix accepts that.
        if self.global_cooperation:
            cooperation_obs_shape = global_obs_shape
        else:
            cooperation_obs_shape = obs_shape
        self.cooperation = QMix(
            agent_num=agent_num,
            obs_shape=cooperation_obs_shape,
            action_shape=action_shape,
            hidden_size_list=hidden_size_list,
            global_obs_shape=global_obs_shape,
            mixer=mixer,
            lstm_type=lstm_type,
            dueling=dueling
        )

    def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
        """Run either the ``current`` or the ``cooperation`` network on ``data``.

        Args:
            data: Input passed to the selected ``QMix`` network. When both
                ``cooperation`` and ``self.global_cooperation`` are true it must
                provide ``data['obs']['global_state']``.
            cooperation: If true, use ``self.cooperation``; otherwise use
                ``self.current``.
            single_step: Forwarded unchanged to the selected network.

        Returns:
            Whatever the selected ``QMix`` network returns.
        """
        if not cooperation:
            return self.current(data, single_step=single_step)

        if self.global_cooperation:
            # Feed the global state in place of the per-agent state, but do it on
            # shallow copies so the caller's ``data`` / ``data['obs']`` are left
            # untouched and can still be passed to ``self.current`` afterwards.
            # Only the two dict containers are copied; the values are shared.
            obs = dict(data['obs'])
            obs['agent_state'] = obs['global_state']
            data = dict(data)
            data['obs'] = obs
        return self.cooperation(data, single_step=single_step)