File size: 1,467 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from time import time
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.bsuite.envs import BSuiteEnv


@pytest.mark.envtest
class TestBSuiteEnv:
    """Smoke tests for ``BSuiteEnv``: roll out one random-action episode per env id."""

    @staticmethod
    def _run_random_episode(env_id, obs_shape):
        """Roll out a single episode of ``env_id`` using random actions.

        Asserts that the observation returned by ``reset`` and by every ``step``
        has shape ``obs_shape``, that every step reward has shape ``(1, )``, and
        that the terminal timestep's ``info`` contains ``'eval_episode_return'``.

        Args:
            env_id: BSuite environment id, passed to ``BSuiteEnv`` as ``cfg.env_id``.
            obs_shape: Expected shape of every observation array.
        """
        env = BSuiteEnv(EasyDict({'env_id': env_id}))
        # Close the env even when an assertion fails mid-episode, so a failing
        # test does not leak the environment.
        try:
            env.seed(0)
            obs = env.reset()
            assert obs.shape == obs_shape
            # No step cap: relies on the env eventually reporting ``done``.
            while True:
                timestep = env.step(env.random_action())
                assert timestep.obs.shape == obs_shape
                assert timestep.reward.shape == (1, )
                if timestep.done:
                    assert 'eval_episode_return' in timestep.info, timestep.info
                    break
        finally:
            env.close()

    def test_memory_len(self):
        """Random episode on ``memory_len/0`` yields shape-(3,) observations."""
        self._run_random_episode('memory_len/0', (3, ))

    def test_cartpole_swingup(self):
        """Random episode on ``cartpole_swingup/0`` yields shape-(8,) observations."""
        self._run_random_episode('cartpole_swingup/0', (8, ))