import cv2
import gym
import os.path as osp
import numpy as np
from typing import Union, Optional
from collections import deque
from competitive_rl.pong.builtin_policies import (
    get_builtin_agent_names, single_obs_space, single_act_space,
    get_random_policy, get_rule_based_policy
)
from competitive_rl.utils.policy_serving import Policy


def get_compute_action_function_ours(agent_name, num_envs=1):
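    """Return a callable policy for the given built-in agent name.

    Checkpoints for the learned opponents are loaded from the bundled
    ``resources/pong`` directory; ``RANDOM`` and ``RULE_BASED`` do not need one.
    """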
    resource_dir = osp.join(osp.dirname(__file__), "resources", "pong")
    if agent_name == "STRONG":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-strong.pkl"),
            use_light_model=False
        )
    if agent_name == "MEDIUM":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-medium.pkl"),
            use_light_model=True
        )
    if agent_name == "ALPHA_PONG":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-alphapong.pkl"),
            use_light_model=False
        )
    if agent_name == "WEAK":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-weak.pkl"),
            use_light_model=True
        )
    if agent_name == "RANDOM":
        return get_random_policy(num_envs)
    if agent_name == "RULE_BASED":
        return get_rule_based_policy(num_envs)
    raise ValueError("Unknown agent name: {}".format(agent_name))


class BuiltinOpponentWrapper(gym.Wrapper):
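    """Turn the two-player environment into a single-agent environment.

    The second player is controlled by one of the built-in policies. The
    wrapper exposes only the first player's observation and action spaces,
    and feeds the opponent its own observation from the previous step.
    """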

    def __init__(self, env: 'gym.Env', num_envs: int = 1) -> None:  # noqa
        super().__init__(env)
        self.agents = {
            agent_name: get_compute_action_function_ours(agent_name, num_envs)
            for agent_name in get_builtin_agent_names()
        }
        self.agent_names = list(self.agents)
        self.prev_opponent_obs = None
        self.current_opponent_name = "RULE_BASED"
        self.current_opponent = self.agents[self.current_opponent_name]
        self.observation_space = env.observation_space[0]
        self.action_space = env.action_space[0]
        self.num_envs = num_envs

    def reset_opponent(self, agent_name: str) -> None:
        assert agent_name in self.agent_names, (agent_name, self.agent_names)
        self.current_opponent_name = agent_name
        self.current_opponent = self.agents[self.current_opponent_name]

    def step(self, action):
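        # The underlying env expects a (player_0_action, player_1_action) tuple;
        # the opponent acts on the observation it received at the previous step.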
        tuple_action = (action.item(), self.current_opponent(self.prev_opponent_obs))
        obs, rew, done, info = self.env.step(tuple_action)
        self.prev_opponent_obs = obs[1]
        # if done.ndim == 2:
        #     done = done[:, 0]
        # return obs[0], rew[:, 0].reshape(-1, 1), done.reshape(-1, 1), info
        return obs[0], rew[0], done, info

    def reset(self):
        obs = self.env.reset()
        self.prev_opponent_obs = obs[1]
        return obs[0]

    def seed(self, s):
        self.env.seed(s)


def wrap_env(env_id, builtin_wrap, opponent, frame_stack=4, warp_frame=True, only_info=False):
    """Configure environment for DeepMind-style Atari. The observation is
    channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the atari environment id.
    :param bool episode_life: wrap the episode life wrapper.
    :param bool clip_rewards: wrap the reward clipping wrapper.
    :param int frame_stack: wrap the frame stacking wrapper.
    :param bool scale: wrap the scaling observation wrapper.
    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
    :return: the wrapped atari environment.
    """
    if not only_info:
        env = gym.make(env_id)
        if builtin_wrap:
            env = BuiltinOpponentWrapper(env)
            env.reset_opponent(opponent)

        if warp_frame:
            env = WarpFrameWrapperCompetitveRl(env, builtin_wrap)
        if frame_stack:
            env = FrameStackWrapperCompetitiveRl(env, frame_stack, builtin_wrap)
        return env
    else:
        wrapper_info = ''
        if builtin_wrap:
            wrapper_info += BuiltinOpponentWrapper.__name__ + '\n'
        if warp_frame:
            wrapper_info += WarpFrameWrapperCompetitveRl.__name__ + '\n'
        if frame_stack:
            wrapper_info += FrameStackWrapperCompetitiveRl.__name__ + '\n'
        return wrapper_info


class WarpFrameWrapperCompetitveRl(gym.ObservationWrapper):
    """Warp frames to 84x84 as done in the Nature paper and later work.

    :param gym.Env env: the environment to wrap.
    """

    def __init__(self, env, builtin_wrap):
        super().__init__(env)
        self.size = 84
        obs_space = env.observation_space
        self.builtin_wrap = builtin_wrap
        if builtin_wrap:
            # single player
            self.observation_space = gym.spaces.Box(
                low=np.min(obs_space.low),
                high=np.max(obs_space.high),
                shape=(self.size, self.size),
                dtype=obs_space.dtype
            )
        else:
            # double player
            self.observation_space = gym.spaces.tuple.Tuple(
                [
                    gym.spaces.Box(
                        low=np.min(obs_space[0].low),
                        high=np.max(obs_space[0].high),
                        shape=(self.size, self.size),
                        dtype=obs_space[0].dtype
                    ) for _ in range(len(obs_space))
                ]
            )

    def observation(self, frame):
        """returns the current observation from a frame"""
        if self.builtin_wrap:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
        else:
            frames = []
            for one_frame in frame:
                one_frame = cv2.cvtColor(one_frame, cv2.COLOR_RGB2GRAY)
                one_frame = cv2.resize(one_frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
                frames.append(one_frame)
            return frames


class FrameStackWrapperCompetitiveRl(gym.Wrapper):
    """Stack n_frames last frames.

    :param gym.Env env: the environment to wrap.
    :param int n_frames: the number of frames to stack.
    """

    def __init__(self, env, n_frames, builtin_wrap):
        super().__init__(env)
        self.n_frames = n_frames

        self.builtin_wrap = builtin_wrap
        obs_space = env.observation_space
        if self.builtin_wrap:
            self.frames = deque([], maxlen=n_frames)
            shape = (n_frames, ) + obs_space.shape
            self.observation_space = gym.spaces.Box(
                low=np.min(obs_space.low), high=np.max(obs_space.high), shape=shape, dtype=obs_space.dtype
            )
        else:
            self.frames = [deque([], maxlen=n_frames) for _ in range(len(obs_space))]
            shape = (n_frames, ) + obs_space[0].shape
            self.observation_space = gym.spaces.tuple.Tuple(
                [
                    gym.spaces.Box(
                        low=np.min(obs_space[0].low),
                        high=np.max(obs_space[0].high),
                        shape=shape,
                        dtype=obs_space[0].dtype
                    ) for _ in range(len(obs_space))
                ]
            )

    def reset(self):
        if self.builtin_wrap:
            obs = self.env.reset()
            for _ in range(self.n_frames):
                self.frames.append(obs)
            return self._get_ob(self.frames)
        else:
            obs = self.env.reset()
            for i, one_obs in enumerate(obs):
                for _ in range(self.n_frames):
                    self.frames[i].append(one_obs)
            return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))])

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.builtin_wrap:
            self.frames.append(obs)
            return self._get_ob(self.frames), reward, done, info
        else:
            for i, one_obs in enumerate(obs):
                self.frames[i].append(one_obs)
            return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))], axis=0), reward, done, info

    @staticmethod
    def _get_ob(frames):
        # The original wrapper uses `LazyFrames`, but since we stack into a
        # numpy buffer anyway it has no effect here.
        return np.stack(frames, axis=0)
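

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The env id
    # "cPongDouble-v0" is an assumption about what competitive_rl registers
    # for the two-player Pong game and may differ in your installation.
    demo_env = wrap_env("cPongDouble-v0", builtin_wrap=True, opponent="RULE_BASED")
    first_obs = demo_env.reset()
    print("stacked observation shape:", np.asarray(first_obs).shape)  # expected (4, 84, 84)
    stacked_obs, reward, is_done, step_info = demo_env.step(np.array(0))
    print("reward after one step:", reward)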