from collections import namedtuple

import numpy as np

from ding.envs import BaseEnv, BaseEnvTimestep
from ding.utils import ENV_REGISTRY

# Timestep returned by ``FakeSMACEnv.step`` — mirrors the real SMAC env's
# (obs, reward, done, info) contract so downstream code can be exercised
# without StarCraft II installed.
FakeSMACEnvTimestep = namedtuple('FakeSMACEnvTimestep', ['obs', 'reward', 'done', 'info'])
# Env-info tuple kept for interface parity with the real SMAC env.
# NOTE(review): currently unused within this file — presumably consumed by
# callers elsewhere; verify before removing.
FakeSMACEnvInfo = namedtuple('FakeSMACEnvInfo', ['agent_num', 'obs_space', 'act_space', 'rew_space'])


@ENV_REGISTRY.register('fake_smac')
class FakeSMACEnv(BaseEnv):
    """
    Overview:
        A lightweight stand-in for the SMAC environment that emits random
        observations and rewards with the same shapes as the real env.
        Intended for unit tests and pipeline smoke tests.
    """

    def __init__(self, cfg=None):
        """
        Overview:
            Set up the fixed dimensions of the fake env.
        Arguments:
            - cfg: accepted for interface compatibility; ignored here.
        """
        # Dimensions chosen to match a typical SMAC map configuration.
        self.agent_num = 8
        # 6 base actions (no-op/stop/4 moves) + one attack action per agent.
        self.action_dim = 6 + self.agent_num
        self.obs_dim = 248
        self.obs_alone_dim = 216
        self.global_obs_dim = 216

    def reset(self):
        """
        Overview:
            Reset the episode step counter and return an initial observation.
        Returns:
            - obs (:obj:`dict`): a randomly generated observation dict.
        """
        self.step_count = 0
        return self._get_obs()

    def _get_obs(self):
        # Random tensors shaped like the real SMAC observation dict.
        return {
            'agent_state': np.random.random((self.agent_num, self.obs_dim)),
            'agent_alone_state': np.random.random((self.agent_num, self.obs_alone_dim)),
            'agent_alone_padding_state': np.random.random((self.agent_num, self.obs_dim)),
            'global_state': np.random.random((self.global_obs_dim, )),
            # Binary availability mask: 1 = action is selectable.
            'action_mask': np.random.randint(0, 2, size=(self.agent_num, self.action_dim)),
        }

    def step(self, action):
        """
        Overview:
            Advance the fake episode by one step.
        Arguments:
            - action (:obj:`np.ndarray`): joint action, shape ``(agent_num, )``.
        Returns:
            - timestep (:obj:`FakeSMACEnvTimestep`): random obs/reward, plus
              ``done`` after a fixed number of steps.
        """
        assert action.shape == (self.agent_num, ), action.shape
        obs = self._get_obs()
        reward = np.random.randint(0, 10, size=(1, ))
        # Fixed horizon: ``done`` is checked before incrementing, so the
        # episode terminates on the 315th call after reset.
        done = self.step_count >= 314
        info = {}
        if done:
            info['eval_episode_return'] = 0.71
        self.step_count += 1
        return FakeSMACEnvTimestep(obs, reward, done, info)

    def close(self):
        """No resources to release in the fake env."""
        pass

    def seed(self, _seed):
        """
        Overview:
            Seed NumPy's global RNG so the fake observations/rewards are
            reproducible (the previous no-op silently dropped the seed).
        Arguments:
            - _seed (:obj:`int`): the random seed.
        """
        np.random.seed(_seed)

    def __repr__(self):
        return 'FakeSMACEnv'