gomoku / DI-engine /ding /league /starcraft_player.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
9.52 kB
from typing import Optional, Union
import numpy as np
from ding.utils import PLAYER_REGISTRY
from .player import ActivePlayer, HistoricalPlayer
from .algorithm import pfsp
@PLAYER_REGISTRY.register('main_player')
class MainPlayer(ActivePlayer):
    """
    Overview:
        Main player in league training.
        Default branch (0.5 pfsp, 0.35 sp, 0.15 veri).
        Default snapshot every 2e9 steps.
        Default mutate prob = 0 (never mutate).
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "MainPlayer"

    def _historical_snapshots_of(self, parent):
        """
        Overview:
            Collect every historical player whose ``parent_id`` equals ``parent.player_id``.
        Arguments:
            - parent: the player whose historical snapshots are wanted
        Returns:
            - players: whatever ``self._get_players`` returns for this filter
        """
        return self._get_players(lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == parent.player_id)

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent, should be a historical player.
        Returns:
            - player (:obj:`HistoricalPlayer`): the selected historical player
        """
        historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
        win_rates = self._payoff[self, historical]
        # 'squared' weighting is forwarded to ``pfsp`` to obtain the selection probabilities
        p = pfsp(win_rates, weighting='squared')
        return self._get_opponent(historical, p)

    def _sp_branch(self):
        """
        Overview:
            Select normal self-play opponent. A main player is picked first; it is returned directly when the
            win rate against it is above ``1 - self._strong_win_rate``, otherwise one of its historical
            snapshots is selected with pfsp ('variance' weighting).
        Returns:
            - player: the selected main player, or one of its historical snapshots
        """
        main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
        main_opponent = self._get_opponent(main_players)

        # TODO(nyz) if only one main_player, self-play win_rates are constantly equal to 0.5
        # main_opponent is not too strong
        if self._payoff[self, main_opponent] > 1 - self._strong_win_rate:
            return main_opponent

        # if the main_opponent is too strong, select a past alternative
        historical = self._historical_snapshots_of(main_opponent)
        win_rates = self._payoff[self, historical]
        p = pfsp(win_rates, weighting='variance')
        return self._get_opponent(historical, p)

    def _verification_branch(self):
        """
        Overview:
            Verify no strong historical main exploiter and no forgotten historical past main player.
            Falls back to ``_sp_branch`` when neither check selects an opponent.
        Returns:
            - player: the selected opponent
        """
        # check exploitation
        main_exploiters = self._get_players(lambda p: isinstance(p, MainExploiter))
        # gather the ids once so that the filter below does not rebuild a list for every candidate
        exploiter_ids = [m.player_id for m in main_exploiters]
        exp_historical = self._get_players(
            lambda p: isinstance(p, HistoricalPlayer) and any(p.parent_id == pid for pid in exploiter_ids)
        )
        win_rates = self._payoff[self, exp_historical]
        # TODO(nyz) why min win_rates 0.3
        if len(win_rates) and win_rates.min() < 1 - self._strong_win_rate:
            p = pfsp(win_rates, weighting='squared')
            return self._get_opponent(exp_historical, p)

        # check forgotten
        main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
        main_opponent = self._get_opponent(main_players)  # only one main player
        main_historical = self._historical_snapshots_of(main_opponent)
        win_rates = self._payoff[self, main_historical]
        # TODO(nyz) whether the method `_get_players` should return players with some sequence(such as step)
        # win_rates, historical = self._remove_monotonic_suffix(win_rates, historical)
        if len(win_rates) and win_rates.min() < self._strong_win_rate:
            p = pfsp(win_rates, weighting='squared')
            return self._get_opponent(main_historical, p)

        # no forgotten main players or strong main exploiters, use self-play instead
        return self._sp_branch()

    # def _remove_monotonic_suffix(self, win_rates, players):
    #     if not len(win_rates):
    #         return win_rates, players
    #     for i in range(len(win_rates) - 1, 0, -1):
    #         if win_rates[i - 1] < win_rates[i]:
    #             return win_rates[:i + 1], players[:i + 1]
    #     return np.array([]), []

    # override
    def is_trained_enough(self) -> bool:
        """
        Overview:
            Delegate to the parent implementation, restricting the check to historical players.
        Returns:
            - flag (:obj:`bool`): result of the parent ``is_trained_enough`` with the historical-player filter
        """
        # ``_pfsp_branch`` and ``_verification_branch`` are played against historical player
        return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))

    # override
    def mutate(self, info: dict) -> None:
        """
        Overview:
            MainPlayer does not mutate
        Arguments:
            - info (:obj:`dict`): unused
        """
        return None
@PLAYER_REGISTRY.register('main_exploiter')
class MainExploiter(ActivePlayer):
    """
    Overview:
        Main exploiter in league training. Can identify weaknesses of main agents, and consequently make them
        more robust.
        Default branch (1.0 main_players).
        Default snapshot when defeating all 3 main players in the league in more than 70% of games,
        or timeout of 4e9 steps.
        Default mutate prob = 1 (must mutate).
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "MainExploiter"

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Forward all arguments to the parent constructor, then cache ``min_valid_win_rate`` from the config.
        Note:
            - min_valid_win_rate (:obj:`float`): the active main player is chosen as opponent only when the \
                win rate against it is at least this value; otherwise a historical snapshot is used
        """
        super().__init__(*args, **kwargs)
        self._min_valid_win_rate = self._cfg.min_valid_win_rate

    def _main_players_branch(self):
        """
        Overview:
            Pick the opponent for the main-players branch: either the main player itself or one of the
            historical snapshots derived from it.
        Returns:
            - player (:obj:`Player`): the selected main player (active/historical)
        """

        def _is_main(player) -> bool:
            return isinstance(player, MainPlayer)

        # there is expected to be a single main player
        target = self._get_opponent(self._get_players(_is_main))

        # the main player yields valid training signals: play against it directly
        if self._payoff[self, target] >= self._min_valid_win_rate:
            return target

        def _is_snapshot_of_target(player) -> bool:
            return isinstance(player, HistoricalPlayer) and player.parent_id == target.player_id

        # curriculum learning: fall back to a past version of the main player
        snapshots = self._get_players(_is_snapshot_of_target)
        probs = pfsp(self._payoff[self, snapshots], weighting='variance')
        return self._get_opponent(snapshots, probs)

    # override
    def is_trained_enough(self):
        """
        Overview:
            Delegate to the parent implementation, restricting the check to main players.
        """
        # opponents are the main player, or its historical versions when the main player is too strong
        def _is_main(player) -> bool:
            return isinstance(player, MainPlayer)

        return super().is_trained_enough(select_fn=_is_main)

    # override
    def mutate(self, info: dict) -> str:
        """
        Overview:
            Main exploiter always mutates (resets) to the supervised learning player.
        Arguments:
            - info (:obj:`dict`): must contain the key ``reset_checkpoint_path``
        Returns:
            - mutate_ckpt_path (:obj:`str`): mutation target checkpoint path
        """
        reset_path = info['reset_checkpoint_path']
        return reset_path
@PLAYER_REGISTRY.register('league_exploiter')
class LeagueExploiter(ActivePlayer):
    """
    Overview:
        League exploiter in league training. Can identify global blind spots in the league (strategies that no player
        in the league can beat, but that are not necessarily robust themselves).
        Default branch (1.0 pfsp).
        Default snapshot when defeating all players in the league in more than 70% of games, or timeout of 2e9 steps.
        Default mutate prob = 0.25.
    Interface:
        __init__, is_trained_enough, snapshot, mutate, get_job
    Property:
        race, payoff, checkpoint_path, player_id, train_iteration
    """
    _name = "LeagueExploiter"

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize ``mutate_prob`` additionally
        Note:
            - mutate_prob (:obj:`float`): the mutation probability of league exploiter. should be in [0, 1]
        Raises:
            - ValueError: if the configured ``mutate_prob`` is outside [0, 1]
        """
        super(LeagueExploiter, self).__init__(*args, **kwargs)
        mutate_prob = self._cfg.mutate_prob
        # Validate explicitly rather than with ``assert`` so the check still runs under ``python -O``.
        if not 0 <= mutate_prob <= 1:
            raise ValueError("mutate_prob should be in [0, 1], but got {}".format(mutate_prob))
        self.mutate_prob = mutate_prob

    def _pfsp_branch(self) -> HistoricalPlayer:
        """
        Overview:
            Select prioritized fictitious self-play opponent
        Returns:
            - player (:obj:`HistoricalPlayer`): the selected historical player
        Note:
            This branch is the same as the pfsp branch in MainPlayer
        """
        historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
        win_rates = self._payoff[self, historical]
        p = pfsp(win_rates, weighting='squared')
        return self._get_opponent(historical, p)

    # override
    def is_trained_enough(self) -> bool:
        """
        Overview:
            Delegate to the parent implementation, restricting the check to historical players.
        Returns:
            - flag (:obj:`bool`): result of the parent ``is_trained_enough`` with the historical-player filter
        """
        # will only play against historical player
        return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))

    # override
    def mutate(self, info: dict) -> Union[str, None]:
        """
        Overview:
            League exploiter can mutate to the supervised learning player with ``mutate_prob`` probability
        Arguments:
            - info (:obj:`dict`): must contain the key ``reset_checkpoint_path`` when mutation happens
        Returns:
            - ckpt_path (:obj:`Union[str, None]`): with ``mutate_prob`` prob returns the pretrained model's ckpt path, \
                with left 1 - ``mutate_prob`` prob returns None, which means no mutation
        """
        p = np.random.uniform()
        if p < self.mutate_prob:
            return info['reset_checkpoint_path']
        return None