import copy
import os
import random
import shutil

import pytest
import torch
from easydict import EasyDict
from ding.league import create_league


one_vs_one_league_default_config = dict(
league=dict(
league_type='one_vs_one',
import_names=["ding.league"],
# ---player----
# "player_category" is just a name. Depends on the env.
# For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
player_category=['default'],
# Support different types of active players for solo and battle league.
# For solo league, supports ['solo_active_player'].
# For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
active_players=dict(
naive_sp_player=1, # {player_type: player_num}
),
naive_sp_player=dict(
# There should be keys ['one_phase_step', 'branch_probs', 'strong_win_rate'].
# Specifically for 'main_exploiter' of StarCraft, there should be an additional key ['min_valid_win_rate'].
one_phase_step=10,
branch_probs=dict(
pfsp=0.5,
sp=0.5,
),
strong_win_rate=0.7,
),
# "use_pretrain" means whether to use pretrain model to initialize active player.
use_pretrain=False,
# "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
# "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
# "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
# Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
use_pretrain_init_historical=False,
pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
# ---payoff---
payoff=dict(
# Supports ['battle']
type='battle',
decay=0.99,
min_win_rate_games=8,
),
path_policy='./league',
),
)
one_vs_one_league_default_config = EasyDict(one_vs_one_league_default_config)
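

# Sample a random episode result; the distribution is intentionally non-uniform:
# P(wins) = 1/3, P(losses) = 1/6, P(draws) = 1/2.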
def get_random_result():
ran = random.random()
if ran < 1. / 3:
return "wins"
elif ran < 1. / 2:
return "losses"
else:
return "draws"


@pytest.mark.unittest
class TestOneVsOneLeague:

    def test_naive(self):
league = create_league(one_vs_one_league_default_config.league)
assert (len(league.active_players) == 1)
assert (len(league.historical_players) == 0)
active_player_ids = [p.player_id for p in league.active_players]
assert set(active_player_ids) == set(league.active_players_ids)
active_player_id = active_player_ids[0]
active_player_ckpt = league.active_players[0].checkpoint_path
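        # Save a dummy tensor as the active player's checkpoint so that later
        # snapshot operations copy a real file.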
tmp = torch.tensor([1, 2, 3])
path_policy = one_vs_one_league_default_config.league.path_policy
torch.save(tmp, active_player_ckpt)
# judge_snapshot & update_active_player
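        # A snapshot is only taken after the player has trained for at least
        # one_phase_step iterations, so the first call should return False.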
assert not league.judge_snapshot(active_player_id)
player_update_dict = {
'player_id': active_player_id,
'train_iteration': one_vs_one_league_default_config.league.naive_sp_player.one_phase_step * 2,
}
league.update_active_player(player_update_dict)
assert league.judge_snapshot(active_player_id)
historical_player_ids = [p.player_id for p in league.historical_players]
assert len(historical_player_ids) == 1
historical_player_id = historical_player_ids[0]
# get_job_info, eval_flag=False
vs_active = False
vs_historical = False
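        # Keep requesting collect jobs until both opponent branches have been
        # sampled: self-play against the active player (sp) and a match against
        # the snapshotted historical player (pfsp).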
while True:
collect_job_info = league.get_job_info(active_player_id, eval_flag=False)
assert collect_job_info['agent_num'] == 2
assert len(collect_job_info['checkpoint_path']) == 2
assert collect_job_info['launch_player'] == active_player_id
assert collect_job_info['player_id'][0] == active_player_id
if collect_job_info['player_active_flag'][1]:
assert collect_job_info['player_id'][1] == collect_job_info['player_id'][0]
vs_active = True
else:
assert collect_job_info['player_id'][1] == historical_player_id
vs_historical = True
if vs_active and vs_historical:
break
        # get_job_info, eval_flag=True
eval_job_info = league.get_job_info(active_player_id, eval_flag=True)
assert eval_job_info['agent_num'] == 1
assert len(eval_job_info['checkpoint_path']) == 1
assert eval_job_info['launch_player'] == active_player_id
assert eval_job_info['player_id'][0] == active_player_id
assert len(eval_job_info['player_id']) == 1
assert len(eval_job_info['player_active_flag']) == 1
assert eval_job_info['eval_opponent'] in league.active_players[0]._eval_opponent_difficulty
# finish_job
episode_num = 5
env_num = 8
player_id = [active_player_id, historical_player_id]
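        # result is an episode_num x env_num matrix of per-env episode outcomes.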
        result = [[get_random_result() for __ in range(env_num)] for _ in range(episode_num)]
payoff_update_info = {
'launch_player': active_player_id,
'player_id': player_id,
'episode_num': episode_num,
'env_num': env_num,
'result': result,
}
league.finish_job(payoff_update_info)
        wins = 0
        games = episode_num * env_num
        for episode in result:
            for r in episode:
                if r == 'wins':
                    wins += 1
        # The payoff stores a decayed win rate (draws count as half a win), so it
        # only approximately tracks the raw wins / games ratio; sanity-check the
        # range instead of asserting exact equality.
        win_rate = league.payoff[league.active_players[0], league.historical_players[0]]
        assert 0. <= win_rate <= 1.
        print('raw win rate: {:.3f}, payoff win rate: {:.3f}'.format(wins / games, win_rate))
        shutil.rmtree(path_policy, ignore_errors=True)
print("Finish!")

    def test_league_info(self):
cfg = copy.deepcopy(one_vs_one_league_default_config.league)
cfg.path_policy = 'test_league_info'
league = create_league(cfg)
        active_player_id = league.active_players[0].player_id
        active_player_ckpt = league.active_players[0].checkpoint_path
tmp = torch.tensor([1, 2, 3])
torch.save(tmp, active_player_ckpt)
assert (len(league.active_players) == 1)
assert (len(league.historical_players) == 0)
print('\n')
print(repr(league.payoff))
print(league.player_rank(string=True))
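        # Force a snapshot so that a historical player exists and non-self-play
        # jobs can be sampled in the loop below.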
league.judge_snapshot(active_player_id, force=True)
        for _ in range(10):
job = league.get_job_info(active_player_id, eval_flag=False)
payoff_update_info = {
'launch_player': active_player_id,
'player_id': job['player_id'],
'episode_num': 2,
'env_num': 4,
'result': [[get_random_result() for __ in range(4)] for _ in range(2)]
}
league.finish_job(payoff_update_info)
# if not self-play
if job['player_id'][0] != job['player_id'][1]:
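                # Flatten the episode x env result matrix into one list of
                # outcomes for the 1v1 rating update.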
win_loss_result = sum(payoff_update_info['result'], [])
home = league.get_player_by_id(job['player_id'][0])
away = league.get_player_by_id(job['player_id'][1])
home.rating, away.rating = league.metric_env.rate_1vs1(home.rating, away.rating, win_loss_result)
print(repr(league.payoff))
print(league.player_rank(string=True))
        shutil.rmtree(cfg.path_policy, ignore_errors=True)


if __name__ == '__main__':
pytest.main(["-sv", os.path.basename(__file__)])