File size: 13,828 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
from typing import Callable, Optional, List
from collections import namedtuple
import numpy as np
from easydict import EasyDict

from ding.utils import import_module, PLAYER_REGISTRY
from .algorithm import pfsp


class Player:
    """
    Overview:
        Base player class. A player is the basic member of a league and bundles its config, category,
        the payoff shared by all players, its checkpoint path, its id, its accumulated agent steps and its rating.
    Interfaces:
        __init__
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step (settable), rating (settable)
    """
    _name = "BasePlayer"  # override this variable for sub-class player

    def __init__(
            self,
            cfg: EasyDict,
            category: str,
            init_payoff: 'BattleSharedPayoff',  # noqa
            checkpoint_path: str,
            player_id: str,
            total_agent_step: int,
            rating: 'PlayerRating',  # noqa
    ) -> None:
        """
        Overview:
            Store the base player metadata.
        Arguments:
            - cfg (:obj:`EasyDict`): Player config dict.
            - category (:obj:`str`): Player category, depending on the game, \
                e.g. StarCraft has 3 races ['terran', 'protoss', 'zerg'].
            - init_payoff (:obj:`Union[BattleSharedPayoff, SoloSharedPayoff]`): Payoff shared by all players.
            - checkpoint_path (:obj:`str`): The path to load player checkpoint.
            - player_id (:obj:`str`): Player id in string format; must be a ``str``.
            - total_agent_step (:obj:`int`): Must be an ``int``. For an active player it should be 0; \
                for a historical player it should be the parent player's ``_total_agent_step`` at ``snapshot`` time.
            - rating (:obj:`PlayerRating`): Player rating information in the whole league.
        """
        # Type checks on the two identifiers come first, before anything is stored on the instance.
        assert isinstance(player_id, str)
        assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step))
        self._cfg = cfg
        self._category = category
        self._payoff = init_payoff
        self._checkpoint_path = checkpoint_path
        self._player_id = player_id
        self._total_agent_step = total_agent_step
        self._rating = rating

    @property
    def category(self) -> str:
        """Player category (e.g. a StarCraft race)."""
        return self._category

    @property
    def payoff(self) -> 'BattleSharedPayoff':  # noqa
        """Payoff object shared by all players of the league."""
        return self._payoff

    @property
    def checkpoint_path(self) -> str:
        """Path of this player's checkpoint."""
        return self._checkpoint_path

    @property
    def player_id(self) -> str:
        """Player id string."""
        return self._player_id

    @property
    def total_agent_step(self) -> int:
        """Number of agent steps accumulated by this player."""
        return self._total_agent_step

    @total_agent_step.setter
    def total_agent_step(self, value: int) -> None:
        # NOTE(review): unlike ``__init__``, no int check is done here.
        self._total_agent_step = value

    @property
    def rating(self) -> 'PlayerRating':  # noqa
        """Rating of this player in the league."""
        return self._rating

    @rating.setter
    def rating(self, new_rating: 'PlayerRating') -> None:  # noqa
        self._rating = new_rating


@PLAYER_REGISTRY.register('historical_player')
class HistoricalPlayer(Player):
    """
    Overview:
        Historical player snapshotted from an active player; it stays fixed to its checkpoint.
        Compared with ``Player`` it additionally records ``parent_id``.
    Property:
        category, payoff, checkpoint_path, player_id, total_agent_step, rating, parent_id
    """
    _name = "HistoricalPlayer"

    def __init__(self, *args, parent_id: str) -> None:
        """
        Overview:
            Forward the positional base-player arguments to ``Player.__init__`` and remember ``parent_id``.
        Arguments:
            - args: Positional arguments of ``Player.__init__`` (cfg, category, init_payoff, checkpoint_path, \
                player_id, total_agent_step, rating).
            - parent_id (:obj:`str`): Keyword-only id of this historical player's parent, \
                which should be an active player.
        """
        super().__init__(*args)
        self._parent_id = parent_id

    @property
    def parent_id(self) -> str:
        """Id of the player this historical player was snapshotted from."""
        return self._parent_id


class ActivePlayer(Player):
    """
    Overview:
        Active player can be updated, or snapshotted to a historical player in the league training.
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job, increment_eval_difficulty
    Property:
        category, payoff, checkpoint_path (settable), player_id, total_agent_step, rating
    """
    _name = "ActivePlayer"
    BRANCH = namedtuple("BRANCH", ['name', 'prob'])

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize player metadata, depending on the game
        Arguments:
            - args: Positional arguments forwarded to ``Player.__init__``.
            - kwargs: Accepted for signature compatibility with sub-classes, but NOT forwarded to \
                ``Player.__init__`` (they are ignored here).
        Note:
            - one_phase_step (:obj:`int`): An active player will be considered trained enough for snapshot \
                after two phase steps.
            - last_enough_step (:obj:`int`): Player's last step number that satisfies ``_is_trained_enough``.
            - strong_win_rate (:obj:`float`): If win rates between this player and all the opponents are greater \
                than this value, this player can be regarded as strong enough to these opponents. \
                If also already trained for one phase step, this player can be regarded as trained enough for snapshot.
            - branch_probs (:obj:`list`): A list of ``BRANCH(name, prob)`` namedtuples built from the \
                ``cfg.branch_probs`` dict; ``name`` selects the ``_<name>_branch`` method used to pick a \
                collect opponent, ``prob`` is that branch's selection probability.
        """
        super().__init__(*args)
        self._one_phase_step = int(float(self._cfg.one_phase_step))  # ``one_phase_step`` is like 1e9
        self._last_enough_step = 0
        self._strong_win_rate = self._cfg.strong_win_rate
        assert isinstance(self._cfg.branch_probs, dict)
        self._branch_probs = [self.BRANCH(k, v) for k, v in self._cfg.branch_probs.items()]
        # Evaluation opponents ordered from easiest to hardest (see ``increment_eval_difficulty``);
        # only a rule-based bot is used at the moment.
        self._eval_opponent_difficulty = ["RULE_BASED"]
        self._eval_opponent_index = 0

    def is_trained_enough(self, select_fn: Optional[Callable] = None) -> bool:
        """
        Overview:
            Judge whether this player is trained enough for further operations (e.g. snapshot, mutate...)
            according to past step count and overall win rates against opponents.
            If yes, set ``self._last_enough_step`` to ``self._total_agent_step`` and return True;
            otherwise return False.
        Arguments:
            - select_fn (:obj:`function`): The function to select opponent players. \
                Defaults to selecting every ``HistoricalPlayer``.
        Returns:
            - flag (:obj:`bool`): Whether this player is trained enough
        """
        if select_fn is None:
            select_fn = lambda x: isinstance(x, HistoricalPlayer)  # noqa
        step_passed = self._total_agent_step - self._last_enough_step
        if step_passed < self._one_phase_step:
            return False
        elif step_passed >= 2 * self._one_phase_step:
            # ``step_passed`` is at least 2 times ``self._one_phase_step``, regarded as trained enough
            self._last_enough_step = self._total_agent_step
            return True
        else:
            # Get payoff against specific opponents (Different players have different type of opponent players)
            # If min win rate is larger than ``self._strong_win_rate``, then is judged trained enough
            selected_players = self._get_players(select_fn)
            if len(selected_players) == 0:  # No such player, therefore no past game
                return False
            win_rates = self._payoff[self, selected_players]
            if win_rates.min() > self._strong_win_rate:
                self._last_enough_step = self._total_agent_step
                return True
            else:
                return False

    def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer:  # noqa
        """
        Overview:
            Generate a snapshot historical player from the current player, called in league's ``_snapshot``.
            The snapshot's checkpoint path is this player's path with ``_<total_agent_step>`` inserted before the
            last ``.pth`` (e.g. ``ckpt/main.pth`` -> ``ckpt/main_100.pth``), and its id is
            ``<player_id>_<total_agent_step>_historical``.
        Argument:
            - metric_env (:obj:`LeagueMetricEnv`): player rating environment, one league one env
        Returns:
            - snapshot_player (:obj:`HistoricalPlayer`): new instantiated historical player

        .. note::
            This method only generates a historical player object, but without saving the checkpoint, which should be
            done by league.
        """
        # Split on the LAST ``.pth`` only, so a ``.pth`` appearing earlier in the path (e.g. in a directory
        # name) does not truncate the directory part of the snapshot path.
        path = self.checkpoint_path.rsplit('.pth', 1)[0] + '_{}'.format(self._total_agent_step) + '.pth'
        return HistoricalPlayer(
            self._cfg,
            self.category,
            self.payoff,
            path,
            self.player_id + '_{}_historical'.format(int(self._total_agent_step)),
            self._total_agent_step,
            metric_env.create_rating(mu=self.rating.mu),
            parent_id=self.player_id
        )

    def mutate(self, info: dict) -> Optional[str]:
        """
        Overview:
            Mutate the current player, called in league's ``_mutate_player``.
            The base implementation does nothing and returns None.
        Arguments:
            - info (:obj:`dict`): related information for the mutation
        Returns:
            - mutation_result (:obj:`str`): if the player does the mutation operation then returns the
                corresponding model path, otherwise returns None
        """
        pass

    def get_job(self, eval_flag: bool = False) -> dict:
        """
        Overview:
            Get a dict containing some info about the job to be launched, e.g. the selected opponent.
        Arguments:
            - eval_flag (:obj:`bool`): Whether to select an opponent for evaluator task.
        Returns:
            - ret (:obj:`dict`): The returned dict. Should contain key ['opponent'].
        """
        if eval_flag:
            # eval opponent is a str.
            opponent = self._eval_opponent_difficulty[self._eval_opponent_index]
        else:
            # collect opponent is a Player.
            opponent = self._get_collect_opponent()
        return {
            'opponent': opponent,
        }

    def _get_collect_opponent(self) -> Player:
        """
        Overview:
            Select an opponent according to the player's ``branch_probs``: draw one uniform sample in [0, 1),
            pick the branch whose cumulative-probability bucket contains it, and return the result of that
            branch's ``_<name>_branch`` method.
        Returns:
            - opponent (:obj:`Player`): Selected opponent.
        Raises:
            - ValueError: If ``branch_probs`` is empty, or its probabilities do not sum to (about) 1 and the \
                sample falls beyond the last bucket.
        """
        if not self._branch_probs:
            raise ValueError("ActivePlayer has no branch_probs to select a collect opponent from")
        p = np.random.uniform()
        cum_p = np.cumsum([branch.prob for branch in self._branch_probs])
        # ``side='right'`` counts the buckets whose upper edge is <= p, i.e. it returns the index i with
        # ``cum_p[i - 1] <= p < cum_p[i]`` (taking cum_p[-1] as 0 for i == 0), the same rule as a linear scan.
        idx = int(np.searchsorted(cum_p, p, side='right'))
        if idx == len(cum_p):
            # Float round-off can leave the total slightly below 1 (e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999);
            # fold that sliver into the last branch, but still reject configs that genuinely do not sum to 1.
            if not np.isclose(cum_p[-1], 1.0):
                raise ValueError("branch_probs must sum to 1, got {}".format(cum_p[-1]))
            idx -= 1
        branch_name = '_{}_branch'.format(self._branch_probs[idx].name)
        return getattr(self, branch_name)()

    def _get_players(self, select_fn: Callable) -> List[Player]:
        """
        Overview:
            Get a list of players in the league (shared_payoff), selected by ``select_fn`` .
        Arguments:
            - select_fn (:obj:`function`): players in the returned list must satisfy this function
        Returns:
            - players (:obj:`list`): a list of players that satisfies ``select_fn``
        """
        return [player for player in self._payoff.players if select_fn(player)]

    def _get_opponent(self, players: list, p: Optional[np.ndarray] = None) -> Player:
        """
        Overview:
            Get one opponent player from list ``players`` according to probability ``p``.
        Arguments:
            - players (:obj:`list`): a list of players that can select opponent from
            - p (:obj:`np.ndarray`): the selection probability of each player, should have the same size as \
                ``players``. If you don't need it and set None, it would select uniformly by default.
        Returns:
            - opponent_player (:obj:`Player`): a random chosen opponent player according to probability
        """
        idx = np.random.choice(len(players), p=p)
        return players[idx]

    def increment_eval_difficulty(self) -> bool:
        """
        Overview:
            When evaluating, active player will choose a specific builtin opponent difficulty.
            This method is used to increment the difficulty.
            It is usually called after the easier builtin bot is already been beaten by this player.
        Returns:
            - increment_or_not (:obj:`bool`): True means difficulty is incremented; \
                False means difficulty is already the hardest.
        """
        if self._eval_opponent_index < len(self._eval_opponent_difficulty) - 1:
            self._eval_opponent_index += 1
            return True
        else:
            return False

    @property
    def checkpoint_path(self) -> str:
        return self._checkpoint_path

    @checkpoint_path.setter
    def checkpoint_path(self, path: str) -> None:
        self._checkpoint_path = path


@PLAYER_REGISTRY.register('naive_sp_player')
class NaiveSpPlayer(ActivePlayer):
    """
    Overview:
        Self-play active player offering two collect-opponent branches, ``pfsp`` (``_pfsp_branch``) and
        ``sp`` (``_sp_branch``), selected through ``branch_probs`` when a collect job is requested.
    """

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Prioritized fictitious self-play: sample a historical player, weighting each candidate by
            ``pfsp(win_rates, weighting='squared')``. Falls back to ``self`` when the payoff returns no
            win rates (no historical players yet).
        Returns:
            - player (:obj:`HistoricalPlayer`): The selected historical player, or ``self`` if there is none.
        """
        past_players = self._get_players(lambda candidate: isinstance(candidate, HistoricalPlayer))
        win_rates = self._payoff[self, past_players]
        if win_rates.shape != (0, ):
            selection_probs = pfsp(win_rates, weighting='squared')
            return self._get_opponent(past_players, selection_probs)
        # No historical opponent available: degrade to plain self-play.
        return self

    def _sp_branch(self) -> ActivePlayer:
        """
        Overview:
            Plain self-play: the opponent is this very player.
        """
        return self


def create_player(cfg: EasyDict, player_type: str, *args, **kwargs) -> Player:
    """
    Overview:
        Create a player registered in ``PLAYER_REGISTRY`` under ``player_type``. The modules listed in
        ``cfg.import_names`` are imported first (presumably so that player classes defined there get
        registered), then ``PLAYER_REGISTRY.build`` instantiates the player. A derived player must therefore
        be registered before ``create_player`` can return an instance of it.
    Arguments:
        - cfg (:obj:`EasyDict`): Player config. Only the optional key ``import_names`` is read here \
            (defaults to an empty list); ``cfg`` itself is not passed to the player constructor.
        - player_type (:obj:`str`): The registered name of the player type to be created.
        - args, kwargs: Forwarded unchanged to the player constructor.
    Returns:
        - player (:obj:`Player`): The newly created player instance.

    .. note::
        Behavior for an unregistered ``player_type`` is defined by ``PLAYER_REGISTRY.build`` \
        (expected to raise, e.g. ``KeyError`` - TODO confirm against the registry implementation).
    """
    modules_to_import = cfg.get('import_names', [])
    import_module(modules_to_import)
    return PLAYER_REGISTRY.build(player_type, *args, **kwargs)