File size: 1,666 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from collections import namedtuple
import numpy as np
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.utils import ENV_REGISTRY
# Record returned by each environment step: (obs, reward, done, info).
FakeSMACEnvTimestep = namedtuple('FakeSMACEnvTimestep', 'obs reward done info')
# Record with four environment-description fields (agent count and the three spaces).
FakeSMACEnvInfo = namedtuple('FakeSMACEnvInfo', 'agent_num obs_space act_space rew_space')
@ENV_REGISTRY.register('fake_smac')
class FakeSMACEnv(BaseEnv):
    """Stand-in multi-agent environment that emits random data.

    Registered in ``ENV_REGISTRY`` under the name ``'fake_smac'``. No real
    simulation is run: every observation, action mask and reward is drawn
    from ``np.random``, and an episode ends after a fixed number of steps.
    """

    def __init__(self, cfg=None):
        """Set the fixed dimensions of the fake environment.

        Args:
            cfg: Accepted for interface compatibility; it is not read.
        """
        self.agent_num = 8
        # The action dimension grows with the number of agents.
        self.action_dim = 6 + self.agent_num
        self.obs_dim = 248
        self.obs_alone_dim = 216
        self.global_obs_dim = 216
        # Initialised here so that ``step`` called before ``reset`` does not
        # fail with an AttributeError; ``reset`` sets it back to zero.
        self.step_count = 0

    def reset(self):
        """Restart the step counter and return a fresh random observation."""
        self.step_count = 0
        return self._get_obs()

    def _get_obs(self):
        """Build one random observation dict.

        Returns:
            dict with the following entries:
              * ``'agent_state'``: floats, shape ``(agent_num, obs_dim)``
              * ``'agent_alone_state'``: floats, shape
                ``(agent_num, obs_alone_dim)``
              * ``'agent_alone_padding_state'``: floats, shape
                ``(agent_num, obs_dim)``
              * ``'global_state'``: floats, shape ``(global_obs_dim, )``
              * ``'action_mask'``: integers in ``{0, 1}``, shape
                ``(agent_num, action_dim)``
        """
        return {
            'agent_state': np.random.random((self.agent_num, self.obs_dim)),
            'agent_alone_state': np.random.random((self.agent_num, self.obs_alone_dim)),
            'agent_alone_padding_state': np.random.random((self.agent_num, self.obs_dim)),
            'global_state': np.random.random((self.global_obs_dim, )),
            'action_mask': np.random.randint(0, 2, size=(self.agent_num, self.action_dim)),
        }

    def step(self, action):
        """Advance the fake environment by one step.

        Args:
            action: Array-like object whose ``shape`` must equal
                ``(agent_num, )``; its values are not inspected.

        Returns:
            FakeSMACEnvTimestep: ``(obs, reward, done, info)`` where ``obs``
            is a new random observation dict, ``reward`` is an integer array
            of shape ``(1, )`` drawn from ``[0, 10)``, ``done`` is a bool and
            ``info`` is a dict that carries ``'eval_episode_return'`` only
            on the terminating step.

        Raises:
            AssertionError: If ``action.shape`` is not ``(agent_num, )``.
        """
        assert action.shape == (self.agent_num, ), action.shape
        obs = self._get_obs()
        reward = np.random.randint(0, 10, size=(1, ))
        # The counter is compared before it is incremented, so the 315th call
        # after a reset is the first one that reports ``done``.
        done = self.step_count >= 314
        info = {}
        if done:
            # Fixed placeholder value reported at episode end.
            info['eval_episode_return'] = 0.71
        self.step_count += 1
        return FakeSMACEnvTimestep(obs, reward, done, info)

    def close(self):
        """Do nothing; there are no resources to release."""
        pass

    def seed(self, _seed):
        """Accept a seed and ignore it; the random generator is not reseeded."""
        pass

    def __repr__(self):
        """Return the constant string ``'FakeSMACEnv'``."""
        return 'FakeSMACEnv'
|