File size: 1,467 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from time import time
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.bsuite.envs import BSuiteEnv
@pytest.mark.envtest
class TestBSuiteEnv:
    """Smoke tests that run one random-action episode of individual bsuite env ids."""

    @staticmethod
    def _run_random_episode(env_id, obs_shape):
        """Play one episode of ``env_id`` with random actions and check its outputs.

        Args:
            env_id: bsuite environment id passed to ``BSuiteEnv`` as ``cfg.env_id``
                (e.g. ``'memory_len/0'``).
            obs_shape: expected shape of every observation, including the one
                returned by ``reset()``.

        Every step's reward must have shape ``(1, )``, and the terminal
        timestep's ``info`` must contain ``'eval_episode_return'``. The
        environment is seeded with 0 and is closed even if an assertion fails
        partway through the episode.
        """
        env = BSuiteEnv(EasyDict({'env_id': env_id}))
        try:
            env.seed(0)
            obs = env.reset()
            assert obs.shape == obs_shape
            while True:
                random_action = env.random_action()
                timestep = env.step(random_action)
                assert timestep.obs.shape == obs_shape
                assert timestep.reward.shape == (1, )
                if timestep.done:
                    # Only the terminal timestep is required to carry this key.
                    assert 'eval_episode_return' in timestep.info, timestep.info
                    break
        finally:
            # Release the environment regardless of whether the checks passed.
            env.close()

    def test_memory_len(self):
        """``memory_len/0`` yields 3-element observations."""
        self._run_random_episode('memory_len/0', (3, ))

    def test_cartpole_swingup(self):
        """``cartpole_swingup/0`` yields 8-element observations."""
        self._run_random_episode('cartpole_swingup/0', (8, ))
|