File size: 1,992 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np

import gym
from gym.envs.mujoco.ant import AntEnv
from gym.envs.mujoco.humanoid import HumanoidEnv


def gym_env_register(id, max_episode_steps=1000):
    """
    Overview:
        Build a class decorator that registers the decorated environment class with gym \
        under the given id.
    Arguments:
        - id (:obj:`str`): Environment id passed to ``gym.register``.
        - max_episode_steps (:obj:`int`): Value forwarded to ``gym.register`` as \
            ``max_episode_steps``; defaults to 1000.
    Returns:
        - decorator (:obj:`Callable`): Decorator that performs the registration and returns \
            the class it was applied to, unchanged.
    .. note::
        The entry point is always built as ``dizoo.mujoco.envs.mujoco_gym_env:<ClassName>``, so the \
        decorated class is presumably expected to be importable from that module.
    """

    def _register_and_return(gym_env):
        # Only the class name is taken from the decorated object; the module path is fixed.
        entry_point = f'dizoo.mujoco.envs.mujoco_gym_env:{gym_env.__name__}'
        gym.register(
            id=id,
            entry_point=entry_point,
            max_episode_steps=max_episode_steps,
        )
        # Hand the class back untouched so the decorated name still refers to it.
        return gym_env

    return _register_and_return


@gym_env_register('AntTruncatedObs-v2')
class AntTruncatedObsEnv(AntEnv):
    """
    Overview:
        Ant variant whose observation is truncated to 27 dimensions, as used in MBPO (arXiv: 1906.08253).
        Registered with gym under the id ``AntTruncatedObs-v2``.
    .. note::
        The external contact forces (``sim.data.cfrc_ext``) are left out of the observation.
        Everything else is inherited from ``AntEnv``; see
        <https://github.com/openai/gym/blob/master/gym/envs/mujoco/ant.py>.
    """

    def _get_obs(self):
        """
        Overview:
            Assemble the truncated observation from the simulator state.
        Returns:
            - obs (:obj:`np.ndarray`): ``qpos`` without its first two entries, followed by all of ``qvel``.
        """
        sim_data = self.sim.data
        # The first two qpos entries are skipped.
        positions = sim_data.qpos.flat[2:]
        velocities = sim_data.qvel.flat
        # Deliberately omitted here: np.clip(sim_data.cfrc_ext, -1, 1).flat
        return np.concatenate([positions, velocities])


@gym_env_register('HumanoidTruncatedObs-v2')
class HumanoidTruncatedObsEnv(HumanoidEnv):
    """
    Overview:
        Humanoid variant whose observation is truncated to 45 dimensions, as used in MBPO (arXiv: 1906.08253).
        Registered with gym under the id ``HumanoidTruncatedObs-v2``.
    .. note::
        COM inertia (``cinert``), COM velocity (``cvel``), actuator forces (``qfrc_actuator``)
        and external forces (``cfrc_ext``) are left out of the observation.
        Everything else is inherited from ``HumanoidEnv``; see
        <https://github.com/openai/gym/blob/master/gym/envs/mujoco/humanoid.py>.
    """

    def _get_obs(self):
        """
        Overview:
            Assemble the truncated observation from the simulator state.
        Returns:
            - obs (:obj:`np.ndarray`): ``qpos`` without its first two entries, followed by all of ``qvel``.
        """
        # The first two qpos entries are skipped.
        positions = self.sim.data.qpos.flat[2:]
        velocities = self.sim.data.qvel.flat
        # Deliberately omitted here: cinert, cvel, qfrc_actuator and cfrc_ext.
        return np.concatenate([positions, velocities])