File size: 6,050 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import pytest
import numpy as np
from easydict import EasyDict

from dizoo.smac.envs import SMACEnv

# Indices into the per-agent action-availability mask used by the scripted
# policies in this file.
# NOTE(review): presumably these follow the SMAC discrete action layout, in
# which indices 6 and above are attack actions -- confirm against SMACEnv.
MOVE_EAST = 4
MOVE_WEST = 5  # not referenced in the visible part of this file


def automation(env, n_agents):
    """Build one step of scripted actions for both sides.

    Rules for every "me" agent, in priority order:

    1. If action 0 is available, take it (presumably the no-op that is the
       only choice for dead units -- TODO confirm against SMACEnv).
    2. If no action at index 6 or above is available and ``MOVE_EAST`` is,
       move east.
    3. Otherwise pick uniformly at random among the available actions.

    Rules for every "opponent" agent, in priority order:

    1. If any action at index 6 or above is available, take the
       highest-indexed available action.
    2. Else if ``MOVE_EAST`` is available, move east.
    3. Otherwise pick uniformly at random among the available actions.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id, is_opponent=...)``
            that returns a per-action availability mask for one agent.
        n_agents: Number of agents to query on each side.

    Returns:
        Dict with keys ``"me"`` and ``"opponent"``, each a list holding one
        action per agent.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when an agent has no
            available action at all.
    """
    actions = {"me": [], "opponent": []}
    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False)
        avail_actions_ind = np.nonzero(avail_actions)[0]
        no_high_action = len(np.nonzero(avail_actions[6:])[0]) == 0
        if avail_actions[0] != 0:
            action = 0
        elif no_high_action and avail_actions[MOVE_EAST] != 0:
            action = MOVE_EAST
        else:
            # Sample only when no scripted rule applies, so no random draw is
            # made just to be thrown away.
            action = np.random.choice(avail_actions_ind)
        actions["me"].append(action)
        print('ava', avail_actions, action)
    for agent_id in range(n_agents):
        avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True)
        avail_actions_ind = np.nonzero(avail_actions)[0]
        if sum(avail_actions[6:]) > 0:
            # Let OPPONENT attack ME as soon as any high-index action opens up.
            action = max(avail_actions_ind)
        elif MOVE_EAST in avail_actions_ind:
            action = MOVE_EAST
        else:
            action = np.random.choice(avail_actions_ind)
        actions["opponent"].append(action)
    return actions


def random_policy(env, n_agents):
    """Pick a uniformly random available action for every agent on both sides.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id, is_opponent=...)``
            that returns a per-action availability mask for one agent.
        n_agents: Number of agents to query on each side.

    Returns:
        Dict with keys ``"me"`` and ``"opponent"``, each a list holding one
        sampled action per agent.
    """
    chosen = {}
    # "me" is sampled before "opponent", agent by agent, so the order of
    # random draws stays fixed.
    for side, opponent_flag in (("me", False), ("opponent", True)):
        picks = []
        for idx in range(n_agents):
            mask = env.get_avail_agent_actions(idx, is_opponent=opponent_flag)
            candidates = np.nonzero(mask)[0]
            picks.append(np.random.choice(candidates))
        chosen[side] = picks
    return chosen


def fix_policy(env, n_agents, me=0, opponent=0):
    """Give every agent a fixed action, falling back when it is unavailable.

    Args:
        env: Object exposing ``get_avail_agent_actions(agent_id, is_opponent=...)``
            that returns a per-action availability mask for one agent.
        n_agents: Number of agents to query on each side.
        me: Action requested for every "me" agent.
        opponent: Action requested for every "opponent" agent.

    Returns:
        Dict with keys ``"me"`` and ``"opponent"``, each a list holding one
        action per agent: the requested action when it is available,
        otherwise the lowest-indexed available action.
    """
    result = {}
    for side, wanted, opponent_flag in (("me", me, False), ("opponent", opponent, True)):
        row = []
        for idx in range(n_agents):
            mask = env.get_avail_agent_actions(idx, is_opponent=opponent_flag)
            legal = np.nonzero(mask)[0]
            # np.nonzero yields ascending indices, so legal[0] is the lowest
            # available action.
            row.append(wanted if wanted in legal else legal[0])
        result[side] = row
    return result


def main(policy, map_name="3m", two_player=False):
    """Run 20 episodes of ``SMACEnv`` driven by ``policy`` and check every step.

    After each ``env.step`` the observation keys, the observation array
    shapes (compared against ``env.info().obs_space.shape``), and the reward
    and done types are asserted. Per-episode returns and running
    win/draw/loss counters are printed.

    Args:
        policy: Callable ``policy(env, n_agents)`` returning a dict with
            ``"me"`` and ``"opponent"`` action lists.
        map_name: One of ``"3s5z"``, ``"3m"`` or ``"infestor_viper"``.
        two_player: When False, only the ``"me"`` actions are passed to
            ``env.step``; when True the whole dict is passed.

    Raises:
        ValueError: If ``map_name`` is not supported. Raised before the
            environment is created, so nothing is left running.
        AssertionError: If a step result does not have the expected layout.
    """
    agents_per_map = {"3s5z": 8, "3m": 3, "infestor_viper": 2}
    # Validate first: rejecting a bad map after the env is constructed would
    # leave that env unclosed.
    if map_name not in agents_per_map:
        raise ValueError(f"invalid type: {map_name}")
    n_agents = agents_per_map[map_name]

    cfg = EasyDict({'two_player': two_player, 'map_name': map_name, 'save_replay_episodes': None, 'obs_alone': True})
    env = SMACEnv(cfg)
    n_episodes = 20
    me_win = 0
    draw = 0
    op_win = 0

    try:
        for e in range(n_episodes):
            print("Now reset the environment for {} episode.".format(e))
            env.reset()
            print('reset over')
            terminated = False
            episode_return_me = 0
            episode_return_op = 0

            env_info = env.info()
            print('begin new episode')
            while not terminated:
                actions = policy(env, n_agents)
                if not two_player:
                    actions = actions["me"]
                t = env.step(actions)
                obs, reward, terminated, infos = t.obs, t.reward, t.done, t.info
                assert set(obs.keys()) == set(
                    ['agent_state', 'global_state', 'action_mask', 'agent_alone_state', 'agent_alone_padding_state']
                )
                assert isinstance(obs['agent_state'], np.ndarray)
                assert obs['agent_state'].shape == env_info.obs_space.shape['agent_state']  # n_agents, agent_state_dim
                assert isinstance(obs['agent_alone_state'], np.ndarray)
                assert obs['agent_alone_state'].shape == env_info.obs_space.shape['agent_alone_state']
                assert isinstance(obs['global_state'], np.ndarray)
                assert obs['global_state'].shape == env_info.obs_space.shape['global_state']  # global_state_dim
                # NOTE(review): the reward/terminated type asserts below match
                # the single-player branch only; with two_player=True the code
                # further down indexes reward["me"] and terminated["me"], which
                # an ndarray of shape (1,) / a bool would not support. Confirm
                # whether two_player mode is expected to pass here.
                assert isinstance(reward, np.ndarray)
                assert reward.shape == (1, )
                print('reward', reward)
                assert isinstance(terminated, bool)
                episode_return_me += reward["me"] if two_player else reward
                episode_return_op += reward["opponent"] if two_player else 0
                terminated = terminated["me"] if two_player else terminated

            if two_player:
                me_win += int(infos["me"]["battle_won"])
                op_win += int(infos["opponent"]["battle_won"])
                draw += int(infos["draw"])
            else:
                me_win += int(infos["battle_won"])
                op_win += int(infos["battle_lost"])
                draw += int(infos["draw"])

            print(
                "Total return in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}."
                "".format(e, episode_return_me, episode_return_op, me_win, draw, op_win, e + 1)
            )
    finally:
        # Always release the environment, even when an assertion or a step
        # fails part-way through an episode.
        env.close()


@pytest.mark.env_test
def test_automation():
    """Run the scripted ``automation`` policy on "infestor_viper", single-player."""
    # Alternative scenario kept for reference:
    #   main(automation, map_name="3m", two_player=False)
    main(policy=automation, map_name="infestor_viper", two_player=False)


# Allow running this environment check directly as a script, without pytest.
if __name__ == "__main__":
    test_automation()