File size: 3,342 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pytest
import numpy as np
import pprint
from easydict import EasyDict

# gfootball is treated as an optional dependency: if the import fails, only a
# warning is printed and `GfootballAcademyEnv` stays undefined in this module.
try:
    from dizoo.gfootball.envs.gfootball_academy_env import GfootballAcademyEnv
except ModuleNotFoundError:
    # Only a missing module is tolerated; any other import-time error propagates.
    print("[WARNING] no gfootball env, if you want to use gfootball, please install it, otherwise, ignore it.")

# Config for the 3-vs-1-with-keeper academy scenario. The tests in this file
# check observation shapes against `agent_num` (rows) and `obs_dim` (columns).
cfg_keeper = EasyDict({
    'env_name': 'academy_3_vs_1_with_keeper',
    'agent_num': 3,
    'obs_dim': 26,
})

# Config for the counterattack-hard academy scenario. The tests in this file
# check observation shapes against `agent_num` (rows) and `obs_dim` (columns).
cfg_counter = EasyDict({
    'env_name': 'academy_counterattack_hard',
    'agent_num': 4,
    'obs_dim': 34,
})


# True when the optional gfootball import at the top of this module did not bind
# `GfootballAcademyEnv` (or bound it to None). Used to skip the tests below instead
# of letting them crash with a NameError.
_GFOOTBALL_MISSING = globals().get('GfootballAcademyEnv') is None


@pytest.mark.envtest
@pytest.mark.skipif(_GFOOTBALL_MISSING, reason='gfootball academy env is not installed')
class TestGfootballAcademyEnv:
    """Smoke tests that drive ``GfootballAcademyEnv`` with random actions.

    Each test builds the env from one of the module-level configs, plays two
    episodes to completion and checks the observation shapes on every step as
    well as the reward seen on the final step.
    """

    def get_random_action(self, min_value: int, max_value: int) -> np.ndarray:
        """Return a shape-``(1,)`` integer array drawn from ``[min_value, max_value]``.

        ``np.random.randint`` excludes its upper bound, hence the ``+ 1``.
        """
        action = np.random.randint(min_value, max_value + 1, (1, ))
        return action

    def _run_random_episodes(self, cfg, episode_num: int = 2) -> None:
        """Play ``episode_num`` random-action episodes on the env described by ``cfg``.

        Arguments:
            - cfg: config passed to ``GfootballAcademyEnv``; ``cfg.agent_num`` and \
                ``cfg.obs_dim`` are the expected observation dimensions.
            - episode_num: number of episodes to play until ``done``.

        Raises:
            - AssertionError: if an observation has an unexpected shape or an \
                episode ends with a reward other than -1 or 100.
        """
        env = GfootballAcademyEnv(cfg)
        try:
            print(env.observation_space, env._action_space, env.reward_space)
            for i in range(episode_num):
                eps_len = 0
                # env.enable_save_replay(replay_path='./video')
                env.reset()
                while True:
                    eps_len += 1
                    # random_action() returns a sequence whose first item maps each
                    # agent to its action; the env is stepped with a plain list of ints.
                    agent_actions = env.random_action()[0]
                    action = [int(agent_action) for agent_action in agent_actions.values()]
                    timestep = env.step(action)
                    obs, reward, done = timestep.obs, timestep.reward, timestep.done

                    assert obs['agent_state'].shape == (cfg.agent_num, cfg.obs_dim)
                    assert obs['global_state'].shape == (cfg.agent_num, cfg.obs_dim * 2)
                    # 19 is presumably the size of the gfootball action set -- TODO confirm.
                    assert obs['action_mask'].shape == (cfg.agent_num, 19)

                    print('step {}, action: {}, reward: {}'.format(eps_len, action, reward))
                    if done:
                        break
                # Checked on the reward of the final step only.
                assert reward == -1 or reward == 100
                print(f'Episode {i} done! The episode length is {eps_len}. The last reward is {reward}.')
            print('End')
        finally:
            # Always release the underlying environment, even when an assertion fails.
            env.close()

    def test_academy_3_vs_1_with_keeper(self):
        """Random rollout on the ``academy_3_vs_1_with_keeper`` scenario."""
        self._run_random_episodes(cfg_keeper)

    def test_academy_counterattack_hard(self):
        """Random rollout on the ``academy_counterattack_hard`` scenario."""
        self._run_random_episodes(cfg_counter)