File size: 7,078 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from tabnanny import check
from typing import Any, Callable, List, Tuple
import numpy as np
from collections.abc import Sequence
from easydict import EasyDict

from ding.envs.env import BaseEnv, BaseEnvTimestep
from ding.envs.env.tests import DemoEnv
# from dizoo.atari.envs import AtariEnv


def check_space_dtype(env: BaseEnv) -> None:
    """Assert that the obs/act/rew spaces of ``env`` use DI-engine's canonical dtypes.

    ``env.reset()`` is called first. A space whose ``repr(dtype)`` contains ``'float'``
    must be exactly ``np.float32``, and one whose repr contains ``'int'`` must be exactly
    ``np.int64``. The match is a substring test, so unsigned dtypes such as ``uint8`` also
    fall under the int rule. Any other dtype (e.g. bool) is not checked.

    Raises:
        AssertionError: naming the offending dtype and space on the first violation.
    """
    print("== 0. Test obs/act/rew space's dtype")
    env.reset()
    # (substring of repr(dtype), the only dtype allowed for that family, failure message)
    dtype_rules = (
        ('float', np.float32, "If float, then must be np.float32, but get {} for {} space"),
        ('int', np.int64, "If int, then must be np.int64, but get {} for {} space"),
    )
    spaces_by_label = {
        'obs': env.observation_space,
        'act': env.action_space,
        'rew': env.reward_space,
    }
    for label, sp in spaces_by_label.items():
        dtype_text = repr(sp.dtype)
        for keyword, required, template in dtype_rules:
            if keyword in dtype_text:
                assert sp.dtype == required, template.format(sp.dtype, label)


# Util function
def check_array_space(ndarray, space, name) -> None:
    """Recursively assert that ``ndarray`` conforms to ``space``.

    Leaves must be ``np.ndarray``. Their dtype and shape must equal ``space.dtype`` and
    ``space.shape``, and every value must lie in ``[space.low, space.high]``. A
    ``Sequence`` is matched element-wise against ``space[i]``; a ``dict`` is matched
    key-wise against ``space[k]``.

    Arguments:
        ndarray: an ``np.ndarray``, or a (nested) sequence/dict of them.
        space: a Box-like object exposing ``dtype``, ``shape``, ``low`` and ``high``, or
            a container of such objects indexable like ``ndarray`` (presumably a gym
            ``Tuple``/``Dict`` space — confirm against callers).
        name: label used in assertion messages (e.g. ``'obs'``).

    Returns:
        None. The function only validates.

    Raises:
        AssertionError: on the first mismatch. For nested containers, the failing index
            or key is printed before the error is re-raised.
        TypeError: if a leaf is neither ``np.ndarray``, ``Sequence`` nor ``dict``.
    """
    if isinstance(ndarray, np.ndarray):
        assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
            name, ndarray.dtype, space.dtype
        )
        assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
            name, ndarray.shape, space.shape
        )
        in_range = (space.low <= ndarray).all() and (ndarray <= space.high).all()
        assert in_range, "{}'s value is {}, but requires in range ({},{})".format(
            name, ndarray, space.low, space.high
        )
    elif isinstance(ndarray, Sequence):
        for i, item in enumerate(ndarray):
            try:
                check_array_space(item, space[i], name)
            except AssertionError:
                # Point at the failing element, then re-raise the original error unchanged.
                print("The following  error happens at {}-th index".format(i))
                raise
    elif isinstance(ndarray, dict):
        for k, value in ndarray.items():
            try:
                check_array_space(value, space[k], name)
            except AssertionError:
                print("The following  error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
        )


def check_reset(env: BaseEnv) -> None:
    """Reset ``env`` and validate the returned observation against ``env.observation_space``.

    Raises:
        AssertionError / TypeError: propagated from ``check_array_space`` on a mismatch.
    """
    print('== 1. Test reset method')
    first_obs = env.reset()
    check_array_space(first_obs, space=env.observation_space, name='obs')


def check_step(env: BaseEnv) -> None:
    """Step ``env`` through three episodes, validating every transition.

    A single action is drawn once (``env.random_action()`` if available, otherwise
    ``env.action_space.sample()``) and reused for every step. After each step, ``obs`` is
    checked against ``env.observation_space`` and ``rew`` against ``env.reward_space``.
    When ``done`` is truthy, ``info`` must contain ``'eval_episode_return'`` and the env is
    reset.

    NOTE(review): there is no step limit, so an env that never reports ``done`` with this
    fixed action will loop forever.

    Raises:
        AssertionError: on a space mismatch or a missing ``'eval_episode_return'`` key.
    """
    finished_episodes = 0
    print('== 2. Test step method')
    env.reset()
    action = env.random_action() if hasattr(env, "random_action") else env.action_space.sample()
    while finished_episodes != 3:
        obs, rew, done, info = env.step(action)
        to_check = ((obs, env.observation_space, 'obs'), (rew, env.reward_space, 'rew'))
        for value, value_space, label in to_check:
            check_array_space(value, value_space, label)
        if not done:
            continue
        assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
        finished_episodes += 1
        env.reset()


# Util function
def check_different_memory(array1, array2, step_times) -> None:
    """Recursively assert that two consecutive observations share no ``np.ndarray`` object.

    ``array1`` and ``array2`` must have the same structure: identical container types,
    equal sequence lengths and equal dict keys. Every pair of ``np.ndarray`` leaves must be
    two distinct objects. Only object identity is compared, so two different arrays that
    are views of the same buffer are NOT detected.

    Arguments:
        array1: observation from the previous step (``obs_last_frame`` in messages).
        array2: observation from the current step (``obs_this_frame`` in messages).
        step_times: step counter, used only in assertion messages.

    Raises:
        AssertionError: on a structural mismatch or a shared ndarray. For nested
            containers, the failing index or key is printed before re-raising.
        TypeError: if a leaf is neither ``np.ndarray``, ``Sequence`` nor ``dict``.
    """
    # Exact type equality (not isinstance): both frames must use the very same container type.
    assert type(array1) is type(array2), (
        "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
            step_times, type(array1), type(array2)
        )
    )
    if isinstance(array1, np.ndarray):
        assert array1 is not array2, (
            "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(step_times)
        )
    elif isinstance(array1, Sequence):
        assert len(array1) == len(array2), (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
                step_times, len(array1), len(array2)
            )
        )
        for i, (item1, item2) in enumerate(zip(array1, array2)):
            try:
                check_different_memory(item1, item2, step_times)
            except AssertionError:
                print("The following error happens at {}-th index".format(i))
                raise
    elif isinstance(array1, dict):
        # Single-line literal: the previous backslash-continued string embedded the next
        # line's indentation into the error message.
        assert array1.keys() == array2.keys(), (
            "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different dict keys".format(
                step_times, array1.keys(), array2.keys()
            )
        )
        for k in array1:
            try:
                check_different_memory(array1[k], array2[k], step_times)
            except AssertionError:
                print("The following  error happens at key {}".format(k))
                raise
    else:
        raise TypeError(
            "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(
                type(array1), type(array2)
            )
        )


def check_obs_deepcopy(env: BaseEnv) -> None:
    """Run one episode and assert each observation is a fresh object, not the previous one.

    The observation from ``reset`` and from every subsequent ``step`` is compared with the
    one before it via ``check_different_memory``. A single action (``env.random_action()``
    if available, otherwise ``env.action_space.sample()``) is drawn once and reused for
    every step until ``done`` is truthy.

    Raises:
        AssertionError / TypeError: propagated from ``check_different_memory``.
    """
    step_times = 0
    print('== 3. Test observation deepcopy')
    previous_obs = env.reset()
    action = env.random_action() if hasattr(env, "random_action") else env.action_space.sample()
    done = False
    while not done:
        step_times += 1
        current_obs, _, done, _ = env.step(action)
        check_different_memory(previous_obs, current_obs, step_times)
        previous_obs = current_obs


def check_all(env: BaseEnv) -> None:
    """Run every env-implementation check on ``env`` in order: dtype, reset, step, deepcopy.

    The checks raise on failure, so the first failing check stops the run.
    """
    for check_fn in (check_space_dtype, check_reset, check_step, check_obs_deepcopy):
        check_fn(env)


def demonstrate_correct_procedure(env_fn: Callable) -> None:
    """Walk through the intended lifecycle of an env for three episodes.

    The env is built from an empty config and seeded before its first ``reset``. It is
    then stepped with ``random_action`` (standing in for a policy) and reset after every
    finished episode, until three episodes have ended.

    Arguments:
        env_fn: factory called as ``env_fn({})``. The returned env must expose ``seed``,
            ``_seed``, ``reset``, ``random_action`` and ``step``.

    Raises:
        AssertionError: if the env violates one of the lifecycle rules asserted below.
    """
    print('== 4. Demonstrate the correct procudures')
    finished_episodes = 0
    env = env_fn({})
    # Lazy init: the underlying env must not exist until `reset` is called.
    assert not hasattr(env, "_env")
    # Seeding has to happen before the first `reset`.
    env.seed(4)
    assert env._seed == 4
    obs = env.reset()  # the underlying env is created here
    while finished_episodes < 3:
        # A real rollout would feed `obs` to a policy; a random action stands in for it.
        obs, rew, done, info = env.step(env.random_action())
        if not done:
            continue
        assert 'eval_episode_return' in info
        finished_episodes += 1
        obs = env.reset()
        # Resetting must keep the seed; only another `seed` call may change it.
        assert env._seed == 4


if __name__ == "__main__":
    # The triple-quoted string below is an inert, disabled example of running the
    # check_* helpers against a real env (AtariEnv, whose import is commented out
    # at the top of the file).
    '''
    # Moethods `check_*` are for user to check whether their implemented env obeys DI-engine's rules.
    # You can replace `AtariEnv` with your own env.
    atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
    check_reset(atari_env)
    check_step(atari_env)
    check_obs_deepcopy(atari_env)
    '''
    # `demonstrate_correct_procedure` shows the intended way to drive an env when
    # generating trajectories: construct it, seed it, reset it, then step until done.
    # Compare your own env's design against `DemoEnv`.
    demonstrate_correct_procedure(DemoEnv)