import pytest import numpy as np from easydict import EasyDict from dizoo.smac.envs import SMACEnv MOVE_EAST = 4 MOVE_WEST = 5 def automation(env, n_agents): actions = {"me": [], "opponent": []} for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False) avail_actions_ind = np.nonzero(avail_actions)[0] action = np.random.choice(avail_actions_ind) if avail_actions[0] != 0: action = 0 elif len(np.nonzero(avail_actions[6:])[0]) == 0: if avail_actions[MOVE_EAST] != 0: action = MOVE_EAST else: action = np.random.choice(avail_actions_ind) else: action = np.random.choice(avail_actions_ind) # if MOVE_EAST in avail_actions_ind: # action = MOVE_EAST # Let OPPONENT attack ME at the first place # if sum(avail_actions[6:]) > 0: # action = max(avail_actions_ind) # print("ME start attacking OP") # print("Available action for ME: ", avail_actions_ind) actions["me"].append(action) print('ava', avail_actions, action) for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True) avail_actions_ind = np.nonzero(avail_actions)[0] action = np.random.choice(avail_actions_ind) if MOVE_EAST in avail_actions_ind: action = MOVE_EAST # Let OPPONENT attack ME at the first place if sum(avail_actions[6:]) > 0: # print("OP start attacking ME") action = max(avail_actions_ind) actions["opponent"].append(action) return actions def random_policy(env, n_agents): actions = {"me": [], "opponent": []} for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False) avail_actions_ind = np.nonzero(avail_actions)[0] action = np.random.choice(avail_actions_ind) actions["me"].append(action) for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True) avail_actions_ind = np.nonzero(avail_actions)[0] # Move left to kill ME action = np.random.choice(avail_actions_ind) actions["opponent"].append(action) return actions def fix_policy(env, n_agents, me=0, opponent=0): actions = {"me": [], "opponent": []} for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False) avail_actions_ind = np.nonzero(avail_actions)[0] action = me if action not in avail_actions_ind: action = avail_actions_ind[0] actions["me"].append(action) for agent_id in range(n_agents): avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True) avail_actions_ind = np.nonzero(avail_actions)[0] action = opponent if action not in avail_actions_ind: action = avail_actions_ind[0] actions["opponent"].append(action) return actions def main(policy, map_name="3m", two_player=False): cfg = EasyDict({'two_player': two_player, 'map_name': map_name, 'save_replay_episodes': None, 'obs_alone': True}) env = SMACEnv(cfg) if map_name == "3s5z": n_agents = 8 elif map_name == "3m": n_agents = 3 elif map_name == "infestor_viper": n_agents = 2 else: raise ValueError(f"invalid type: {map_name}") n_episodes = 20 me_win = 0 draw = 0 op_win = 0 for e in range(n_episodes): print("Now reset the environment for {} episode.".format(e)) env.reset() print('reset over') terminated = False episode_return_me = 0 episode_return_op = 0 env_info = env.info() print('begin new episode') while not terminated: actions = policy(env, n_agents) if not two_player: actions = actions["me"] t = env.step(actions) obs, reward, terminated, infos = t.obs, t.reward, t.done, t.info assert set(obs.keys()) == set( ['agent_state', 'global_state', 'action_mask', 'agent_alone_state', 'agent_alone_padding_state'] ) assert isinstance(obs['agent_state'], np.ndarray) assert obs['agent_state'].shape == env_info.obs_space.shape['agent_state'] # n_agents, agent_state_dim assert isinstance(obs['agent_alone_state'], np.ndarray) assert obs['agent_alone_state'].shape == env_info.obs_space.shape['agent_alone_state'] assert isinstance(obs['global_state'], np.ndarray) assert obs['global_state'].shape == env_info.obs_space.shape['global_state'] # global_state_dim assert isinstance(reward, np.ndarray) assert reward.shape == (1, ) print('reward', reward) assert isinstance(terminated, bool) episode_return_me += reward["me"] if two_player else reward episode_return_op += reward["opponent"] if two_player else 0 terminated = terminated["me"] if two_player else terminated if two_player: me_win += int(infos["me"]["battle_won"]) op_win += int(infos["opponent"]["battle_won"]) draw += int(infos["draw"]) else: me_win += int(infos["battle_won"]) op_win += int(infos["battle_lost"]) draw += int(infos["draw"]) print( "Total return in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}." "".format(e, episode_return_me, episode_return_op, me_win, draw, op_win, e + 1) ) env.close() @pytest.mark.env_test def test_automation(): # main(automation, map_name="3m", two_player=False) main(automation, map_name="infestor_viper", two_player=False) if __name__ == "__main__": test_automation()