File size: 2,226 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pytest
import numpy as np
from easydict import EasyDict
from torch import rand
from dizoo.classic_control.pendulum.envs import PendulumEnv


@pytest.mark.envtest
class TestPendulumEnv:
    """Smoke tests for ``PendulumEnv`` in continuous and discrete action modes."""

    @staticmethod
    def _check_timestep(env, timestep):
        """Assert that one ``env.step`` result has the expected shapes and reward bounds.

        Checks that the observation has shape ``(3, )``, the reward has shape ``(1, )``,
        and the reward lies within ``[env.reward_space.low, env.reward_space.high]``.
        """
        assert timestep.obs.shape == (3, )
        assert timestep.reward.shape == (1, )
        assert timestep.reward >= env.reward_space.low
        assert timestep.reward <= env.reward_space.high

    def test_naive(self):
        """Continuous-action mode: seed, reset, then step 10 times with random actions."""
        env = PendulumEnv(EasyDict({'act_scale': True}))
        try:
            env.seed(314)
            assert env._seed == 314
            obs = env.reset()
            assert obs.shape == (3, )
            for i in range(10):
                # Both ``env.random_action()`` and ``np.random`` can generate a legal random
                # action; exercise each source for half of the steps.
                if i < 5:
                    # tanh of a uniform [0, 1) sample lies in [0, ~0.76); presumably inside the
                    # scaled action range when ``act_scale`` is on -- TODO confirm.
                    random_action = np.tanh(np.random.random(1))
                else:
                    random_action = env.random_action()
                timestep = env.step(random_action)
                self._check_timestep(env, timestep)
            print(env.observation_space, env.action_space, env.reward_space)
        finally:
            # Release the environment even if an assertion above fails.
            env.close()

    def test_discrete(self):
        """Discrete-action mode: seed, reset, then step 10 times with sampled actions."""
        env = PendulumEnv(EasyDict({'act_scale': True, 'continuous': False}))
        try:
            env.seed(314)
            assert env._seed == 314
            obs = env.reset()
            assert obs.shape == (3, )
            for i in range(10):
                # Both ``env.random_action()`` and sampling ``env.action_space`` can generate a
                # legal random action; exercise each source for half of the steps.
                if i < 5:
                    # Wrap the sample in a 1-element array so it matches the other action source.
                    random_action = np.array([env.action_space.sample()])
                else:
                    random_action = env.random_action()
                timestep = env.step(random_action)
                print(timestep.obs, timestep.reward)
                self._check_timestep(env, timestep)
            print(env.observation_space, env.action_space, env.reward_space)
        finally:
            # Release the environment even if an assertion above fails.
            env.close()