File size: 1,805 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import shutil
from easydict import EasyDict
from ding.league import BaseLeague, ActivePlayer


class DemoLeague(BaseLeague):
    """Demo league that pairs an active player with a single opponent.

    Every job produced by this league involves exactly two agents: the
    launching active player and the opponent selected by ``player.get_job``.
    """

    def __init__(self, cfg):
        """Initialise the base league, then record the reset-checkpoint path.

        Args:
            cfg: League configuration, forwarded unchanged to ``BaseLeague``.
        """
        super().__init__(cfg)
        # ``self.path_policy`` is expected to be provided by the base class
        # initialiser; the reset checkpoint lives directly beneath it.
        self.reset_checkpoint_path = os.path.join(self.path_policy, 'reset_ckpt.pth')

    # override
    def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
        """Describe the two-agent job that ``player`` should launch.

        Args:
            player: The active player requesting a job.
            eval_flag: Passed straight through to ``player.get_job``.

        Returns:
            A dict with the agent count, the launching player's id, and
            per-participant lists (launcher first, opponent second) of ids,
            checkpoint paths and "is an ActivePlayer" flags.
        """
        assert isinstance(player, ActivePlayer), player.__class__
        job = EasyDict(player.get_job(eval_flag))
        rival = job.opponent
        participants = [player, rival]
        return {
            'agent_num': 2,
            'launch_player': player.player_id,
            'player_id': [member.player_id for member in participants],
            'checkpoint_path': [member.checkpoint_path for member in participants],
            'player_active_flag': [isinstance(member, ActivePlayer) for member in participants],
        }

    # override
    def _mutate_player(self, player: ActivePlayer):
        """Give every active player the chance to mutate.

        NOTE: the ``player`` argument is not consulted; the loop always walks
        ``self.active_players``.

        Args:
            player: Unused by this implementation.
        """
        for candidate in self.active_players:
            mutated_ckpt = candidate.mutate({'reset_checkpoint_path': self.reset_checkpoint_path})
            if mutated_ckpt is None:
                # This player did not mutate; nothing to refresh.
                continue
            # A mutated player starts over with a fresh rating ...
            candidate.rating = self.metric_env.create_rating()
            # ... is reloaded from the returned checkpoint
            # (load_checkpoint is set by the caller of league) ...
            self.load_checkpoint(candidate.player_id, mutated_ckpt)
            # ... and has that checkpoint copied over its own checkpoint path.
            self.save_checkpoint(mutated_ckpt, candidate.checkpoint_path)

    # override
    def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
        """Copy the reported learner step onto ``player``, if one was given.

        Args:
            player: The active player to update.
            player_info: Reported info; only the ``'learner_step'`` key is read.
        """
        assert isinstance(player, ActivePlayer)
        if 'learner_step' not in player_info:
            return
        player.total_agent_step = player_info['learner_step']

    # override
    @staticmethod
    def save_checkpoint(src_checkpoint_path: str, dst_checkpoint_path: str) -> None:
        """Copy the checkpoint file at ``src_checkpoint_path`` to ``dst_checkpoint_path``.

        Args:
            src_checkpoint_path: Path of the existing checkpoint file.
            dst_checkpoint_path: Destination path handed to ``shutil.copy``.
        """
        shutil.copy(src_checkpoint_path, dst_checkpoint_path)