from time import time import pytest import numpy as np from easydict import EasyDict from dizoo.bsuite.envs import BSuiteEnv @pytest.mark.envtest class TestBSuiteEnv: def test_memory_len(self): cfg = {'env_id': 'memory_len/0'} cfg = EasyDict(cfg) memory_len_env = BSuiteEnv(cfg) memory_len_env.seed(0) obs = memory_len_env.reset() assert obs.shape == (3, ) while True: random_action = memory_len_env.random_action() timestep = memory_len_env.step(random_action) assert timestep.obs.shape == (3, ) assert timestep.reward.shape == (1, ) if timestep.done: assert 'eval_episode_return' in timestep.info, timestep.info break memory_len_env.close() def test_cartpole_swingup(self): cfg = {'env_id': 'cartpole_swingup/0'} cfg = EasyDict(cfg) bandit_noise_env = BSuiteEnv(cfg) bandit_noise_env.seed(0) obs = bandit_noise_env.reset() assert obs.shape == (8, ) while True: random_action = bandit_noise_env.random_action() timestep = bandit_noise_env.step(random_action) assert timestep.obs.shape == (8, ) assert timestep.reward.shape == (1, ) if timestep.done: assert 'eval_episode_return' in timestep.info, timestep.info break bandit_noise_env.close()