|
import os |
|
|
|
import numpy as np |
|
import pytest |
|
from easydict import EasyDict |
|
|
|
from ding.league.player import Player, HistoricalPlayer, ActivePlayer, create_player |
|
from ding.league.shared_payoff import create_payoff |
|
from ding.league.starcraft_player import MainPlayer, MainExploiter, LeagueExploiter |
|
from ding.league.tests.league_test_default_config import league_test_config |
|
from ding.league.metric import LeagueMetricEnv |
|
|
|
# Step-count unit the tests use when setting and checking ``total_agent_step``.
# NOTE(review): presumably mirrors the phase length configured in ``league_test_config`` -- confirm.
ONE_PHASE_STEP = 2000

# Module-level metric environment shared by all tests: it supplies ``create_rating()`` when
# players are created and is passed to ``snapshot``.
env = LeagueMetricEnv()
|
|
|
|
|
@pytest.fixture(scope='function')
def setup_payoff():
    """Provide a newly created payoff object for each test function.

    The payoff is produced by ``create_payoff`` from a config with ``type`` ``'battle'``
    and ``decay`` ``0.99``.
    """
    return create_payoff(EasyDict({
        'type': 'battle',
        'decay': 0.99,
    }))
|
|
|
|
|
@pytest.fixture(scope='function')
def setup_league(setup_payoff):
    """Build the list of league players used by the tests and register them in the payoff.

    For each category in ``zerg``, ``terran``, ``protoss`` (in that order) five players are
    appended: one ``main_player``, one ``main_exploiter``, two ``league_exploiter`` entries
    (both created with the same name and checkpoint path) and one ``historical_player`` whose
    ``parent_id`` is the category's main player name. The tests index into this list, so the
    creation order matters.

    Args:
        setup_payoff: payoff fixture; it is passed to every ``create_player`` call and every
            created player is added to it.

    Returns:
        list: the created players, in creation order.
    """
    league_cfg = league_test_config.league
    roster = []

    def enlist(player_type, player_cfg, race, ckpt_path, name, **extra):
        # The literal 0 is presumably the initial agent step -- confirm against ``create_player``.
        # A new rating is requested from the shared metric env for every player.
        roster.append(
            create_player(
                league_cfg, player_type, player_cfg, race, setup_payoff, ckpt_path, name, 0,
                env.create_rating(), **extra
            )
        )

    for race in ('zerg', 'terran', 'protoss'):
        main_name = '{}_{}'.format('MainPlayer', race)
        enlist('main_player', league_cfg.main_player, race, 'ckpt_{}.pth'.format(main_name), main_name)

        main_exploiter_name = '{}_{}'.format('MainExploiter', race)
        enlist(
            'main_exploiter', league_cfg.main_exploiter, race, 'ckpt_{}.pth'.format(main_exploiter_name),
            main_exploiter_name
        )

        # Two league exploiters per category, deliberately sharing name and checkpoint path.
        league_exploiter_name = '{}_{}'.format('LeagueExploiter', race)
        for _ in range(2):
            enlist(
                'league_exploiter', league_cfg.league_exploiter, race,
                'ckpt_{}.pth'.format(league_exploiter_name), league_exploiter_name
            )

        # Initial historical player, linked to this category's main player via ``parent_id``.
        sl_name = '{}_{}_sl'.format('MainPlayer', race)
        enlist(
            'historical_player', EasyDict(), race, 'ckpt_sl_{}'.format(sl_name), sl_name, parent_id=main_name
        )

    for member in roster:
        setup_payoff.add_player(member)
    return roster
|
|
|
|
|
@pytest.mark.unittest
class TestMainPlayer:
    """Tests for ``MainPlayer`` using the league and payoff built by the fixtures."""

    def test_get_job(self, setup_league, setup_payoff):
        """Check which opponents ``MainPlayer.get_job`` hands out.

        First verifies that each job is a dict whose ``'opponent'`` is a ``Player`` contained
        in the league. Then every active player is snapshotted, and jobs are drawn until the
        three opponent kinds distinguished below have all been observed.
        """
        N = 10

        # Phase 1: before any new snapshot exists, every opponent must be a league member.
        for p in setup_league:
            if isinstance(p, MainPlayer):
                for _ in range(N):
                    job_dict = p.get_job()
                    assert isinstance(job_dict, dict)
                    opponent = job_dict['opponent']
                    assert isinstance(opponent, Player)
                    assert opponent in setup_league

        # Phase 2: snapshot every active player and register the snapshot in the payoff.
        hp_list = []
        for p in setup_league:
            if isinstance(p, ActivePlayer):
                p.total_agent_step = 2 * ONE_PHASE_STEP
                hp = p.snapshot(env)
                hp_list.append(hp)
                setup_payoff.add_player(hp)
        # Extend only after the loop so the list is not modified while it is being iterated.
        setup_league += hp_list

        # Phase 3: classify sampled opponents.
        #   veri - HistoricalPlayer whose parent_id contains 'MainPlayer'
        #   pfsp - any other HistoricalPlayer
        #   sp   - an (active) MainPlayer
        # The flags are shared across all main players, so once all three kinds were seen,
        # every later main player leaves its loop after a single draw.
        pfsp, sp, veri = False, False, False
        for p in setup_league:
            if isinstance(p, MainPlayer):
                # NOTE: there is no iteration cap; the loop ends only when all kinds were seen.
                while True:
                    job_dict = p.get_job()
                    opponent = job_dict['opponent']
                    if isinstance(opponent, HistoricalPlayer) and 'MainPlayer' in opponent.parent_id:
                        veri = True
                    elif isinstance(opponent, HistoricalPlayer):
                        pfsp = True
                    elif isinstance(opponent, MainPlayer):
                        sp = True
                    else:
                        # Fill the placeholder; previously the type was passed as a second
                        # exception argument and the message was never formatted.
                        raise Exception("Main Player selects a wrong opponent {}".format(type(opponent)))
                    if veri and pfsp and sp:
                        break

    def test_snapshot(self, setup_league, setup_payoff):
        """Check that repeated snapshots of an active player keep payoff and parent linkage."""
        N = 10
        for p in setup_league:
            for _ in range(N):
                if isinstance(p, ActivePlayer):
                    hp = p.snapshot(env)
                    assert isinstance(hp, HistoricalPlayer)
                    # The snapshot must share the very same payoff object as its parent.
                    assert hp.payoff is p.payoff
                    assert hp.parent_id == p.player_id

    def test_is_trained_enough(self, setup_league, setup_payoff):
        """Check ``is_trained_enough`` and its effect on ``_last_enough_step``."""
        for p in setup_league:
            if isinstance(p, ActivePlayer):
                # Untouched player: not trained enough.
                assert not p.is_trained_enough()
                assert p._last_enough_step == 0

                # Slightly below one phase: still not enough.
                p.total_agent_step = ONE_PHASE_STEP * 0.99
                assert not p.is_trained_enough()
                assert p._last_enough_step == 0

                # Just past one phase: still not enough (presumably because no results have
                # been recorded yet -- confirm against the player implementation).
                p.total_agent_step = ONE_PHASE_STEP + 1
                assert not p.is_trained_enough()
                assert p._last_enough_step == 0

        # Snapshot every main player so that additional historical opponents exist.
        hp_list = []
        for p in setup_league:
            if isinstance(p, MainPlayer):
                hp = p.snapshot(env)
                setup_payoff.add_player(hp)
                hp_list.append(hp)
        setup_league += hp_list

        # Record N wins of the first main player against every historical player; afterwards
        # it must report trained-enough and update ``_last_enough_step``.
        N = 10
        assert isinstance(setup_league[0], MainPlayer)
        historical_players = [p for p in setup_league if isinstance(p, HistoricalPlayer)]
        for _ in range(N):
            for hp in historical_players:
                match_info = {
                    'player_id': [setup_league[0].player_id, hp.player_id],
                    'result': [['wins']],
                }
                result = setup_payoff.update(match_info)
                assert result
        assert setup_league[0]._total_agent_step > ONE_PHASE_STEP
        assert setup_league[0]._last_enough_step == 0
        assert setup_league[0]._last_enough_step != setup_league[0]._total_agent_step
        assert setup_league[0].is_trained_enough()
        assert setup_league[0]._last_enough_step == setup_league[0]._total_agent_step

        # Record N draws of another main player against the new snapshots only: not enough
        # until its step count is raised to two phases.
        assert isinstance(setup_league[5], MainPlayer)
        for _ in range(N):
            for hp in hp_list:
                match_info = {
                    'player_id': [setup_league[5].player_id, hp.player_id],
                    'result': [['draws']],
                }
                result = setup_payoff.update(match_info)
                assert result
        assert setup_league[5]._total_agent_step > ONE_PHASE_STEP
        assert not setup_league[5].is_trained_enough()
        setup_league[5].total_agent_step = 2 * ONE_PHASE_STEP
        assert setup_league[5].is_trained_enough()

    def test_mutate(self, setup_league, setup_payoff):
        """Check that ``MainPlayer.mutate`` returns ``None`` on every call."""
        assert isinstance(setup_league[0], MainPlayer)
        for _ in range(10):
            assert setup_league[0].mutate({}) is None

    def test_sp_historical(self, setup_league, setup_payoff):
        """Check the opponents returned by ``_sp_branch`` after a one-sided record.

        ``main1`` (index 0) is given N wins over ``main2`` (index 5). Every opponent that
        ``main2._sp_branch()`` then returns must either have category ``terran``/``protoss``,
        or be a ``zerg`` ``HistoricalPlayer`` whose parent is ``main1``.
        """
        N = 10
        main1 = setup_league[0]
        main2 = setup_league[5]
        assert isinstance(main1, MainPlayer)
        assert isinstance(main2, MainPlayer)
        for _ in range(N):
            match_info = {
                'player_id': [main1.player_id, main2.player_id],
                'result': [['wins']],
            }
            result = setup_payoff.update(match_info)
            assert result
        for _ in range(200):
            opponent = main2._sp_branch()
            condition1 = opponent.category == 'terran' or opponent.category == 'protoss'
            # A zerg opponent is accepted only as a historical copy of main1 (presumably because
            # main1 itself is now too strong for main2 -- confirm against ``_sp_branch``).
            condition2 = opponent.category == 'zerg' and isinstance(
                opponent, HistoricalPlayer
            ) and opponent.parent_id == main1.player_id
            assert condition1 or condition2, (condition1, condition2)
|
|
|
|
|
@pytest.mark.unittest
class TestMainExploiter:
    """Tests for ``MainExploiter`` using the league and payoff built by the fixtures."""

    def test_get_job(self, setup_league, random_job_result, setup_payoff):
        """Check that a main exploiter is only ever matched against the main-player lineage."""

        def from_main_player(candidate):
            # Accept historical players descending from a main player, or a main player itself.
            if isinstance(candidate, HistoricalPlayer) and 'MainPlayer' in candidate.parent_id:
                return True
            return isinstance(candidate, MainPlayer)

        exploiter = setup_league[1]
        assert isinstance(exploiter, MainExploiter)
        first_job = exploiter.get_job()
        assert isinstance(first_job['opponent'], MainPlayer)

        # The league list is not extended until after the snapshot rounds below, so the set of
        # main players can be collected once.
        main_players = [member for member in setup_league if isinstance(member, MainPlayer)]

        # Record ten losses of the exploiter against every main player.
        for _ in range(10):
            for main in main_players:
                loss_info = {
                    'player_id': [exploiter.player_id, main.player_id],
                    'result': [['losses']],
                }
                assert setup_payoff.update(loss_info)

        assert from_main_player(exploiter.get_job()['opponent'])

        # Three snapshot rounds of every main player, at increasing step counts.
        snapshots = []
        for round_no in range(1, 4):
            for main in main_players:
                main.total_agent_step = round_no * 2 * ONE_PHASE_STEP
                snap = main.snapshot(env)
                setup_payoff.add_player(snap)
                snapshots.append(snap)
        setup_league += snapshots

        # Feed random results between players that are not active main players.
        candidates = [member for member in setup_league if not isinstance(member, MainPlayer)]
        for _ in range(10000):
            home = np.random.choice(candidates)
            away = np.random.choice(candidates)
            outcome = random_job_result()
            random_info = {
                'player_id': [home.player_id, away.player_id],
                'result': [[outcome]],
            }
            assert setup_payoff.update(random_info)

        for _ in range(10):
            assert from_main_player(exploiter.get_job()['opponent'])

    def test_is_trained_enough(self, setup_league):
        """Placeholder: no assertions are made here."""

    def test_mutate(self, setup_league):
        """Check that ``mutate`` returns the reset checkpoint path on every call."""
        exploiter = setup_league[1]
        assert isinstance(exploiter, MainExploiter)
        info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
        attempts = 10
        while attempts > 0:
            assert exploiter.mutate(info) == info['reset_checkpoint_path']
            attempts -= 1
|
|
|
|
|
@pytest.mark.unittest
class TestLeagueExploiter:
    """Tests for ``LeagueExploiter`` using the league built by ``setup_league``."""

    def test_get_job(self, setup_league):
        """Check that both league exploiters at indices 2 and 3 draw a historical opponent."""
        for index in (2, 3):
            exploiter = setup_league[index]
            assert isinstance(exploiter, LeagueExploiter)
            job = exploiter.get_job()
            assert isinstance(job['opponent'], HistoricalPlayer)

    def test_is_trained_enough(self, setup_league):
        """Placeholder: no assertions are made here."""

    def test_mutate(self, setup_league):
        """Check that ``mutate`` returns a truthy value in 20%-30% of 1000 calls."""
        exploiter = setup_league[2]
        assert isinstance(exploiter, LeagueExploiter)
        info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
        outcomes = [exploiter.mutate(info) for _ in range(1000)]
        truthy_count = sum(1 for outcome in outcomes if outcome)
        freq = float(truthy_count) / len(outcomes)
        assert 0.2 <= freq <= 0.3
|
|
|
|
|
if __name__ == '__main__':
    # Run only this file's tests when executed as a script.
    this_file = os.path.basename(__file__)
    pytest.main(["-sv", this_file])
|
|