File size: 2,648 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import random
import numpy as np

from dizoo.gfootball.envs.obs.gfootball_obs import PlayerObs, MatchObs
from ding.utils.data import default_collate


def generate_data(player_obs: dict) -> np.ndarray:
    """Generate one random value matching an observation template entry.

    Args:
        player_obs: A template entry with keys ``'dim'`` and ``'value'``;
            ``player_obs['value']`` must provide ``'dinfo'``, and also
            ``'min'`` / ``'max'`` (indexable per dimension) when
            ``dinfo == 'float'``.

    Returns:
        - For ``dinfo`` in ``('one-hot', 'boolean vector')``: a float32 vector
          of length ``dim`` with exactly one randomly chosen element set to 1.
        - For ``dinfo == 'float'``: a float64 vector of length ``dim`` whose
          i-th element is drawn uniformly from ``[min[i], max[i])``.
        - ``None`` for any other ``dinfo`` value.
    """
    dim = player_obs['dim']
    value_info = player_obs['value']
    dinfo = value_info['dinfo']
    if dinfo in ('one-hot', 'boolean vector'):
        data = np.zeros((dim, ), dtype=np.float32)
        # One hot element is used for 'boolean vector' too; good enough for fake data.
        data[random.randint(0, dim - 1)] = 1
        return data
    if dinfo == 'float':
        # Named low/high to avoid shadowing the builtins min/max; the scaling is
        # vectorised but performs the same per-element arithmetic as a loop would.
        low = np.asarray(value_info['min'])
        high = np.asarray(value_info['max'])
        return low + (high - low) * np.random.rand(dim)
    # NOTE(review): unknown dinfo values silently yield None (kept for backward
    # compatibility); collating such a result downstream will fail.
    return None


class FakeGfootballDataset:
    """Random fake data shaped like GFootball observations, for testing.

    Observation layouts come from the ``template`` lists of ``MatchObs`` and
    ``PlayerObs``; every template entry is turned into random data by
    ``generate_data`` and stored under the entry's ``'ret_key'``.
    """

    # Number of per-player observation dicts generated for each observation
    # (presumably 11 players per team; TODO confirm against the real env).
    _NUM_PLAYERS = 22

    def __init__(self):
        match_obs = MatchObs({})
        player_obs = PlayerObs({})
        self.match_obs_info = match_obs.template
        self.player_obs_info = player_obs.template
        self.action_dim = 19
        self.batch_size = 4
        # Only the templates are kept; the obs objects themselves are not needed.
        del match_obs, player_obs

    def __len__(self) -> int:
        """Return the nominal batch size."""
        return self.batch_size

    def get_random_action(self) -> np.ndarray:
        """Return a shape ``(1,)`` integer array holding an action in ``[0, action_dim)``."""
        # np.random.randint excludes ``high``; passing ``action_dim - 1`` would
        # make the last action impossible to sample.
        return np.random.randint(0, self.action_dim, size=(1, ))

    def get_random_obs(self) -> dict:
        """Return one random observation.

        The dict maps each match template ``'ret_key'`` to its random data, plus
        ``'players'``: a list of ``_NUM_PLAYERS`` dicts built the same way from
        the player template.
        """
        inputs = {info['ret_key']: generate_data(info) for info in self.match_obs_info}
        inputs['players'] = [
            {info['ret_key']: generate_data(info) for info in self.player_obs_info}
            for _ in range(self._NUM_PLAYERS)
        ]
        return inputs

    def get_batched_obs(self, bs: int) -> dict:
        """Return ``bs`` random observations merged with ``default_collate``."""
        return default_collate([self.get_random_obs() for _ in range(bs)])

    def get_random_reward(self) -> np.ndarray:
        """Return a shape ``(1,)`` float array with a reward in ``[-0.5, 0.5)``."""
        return np.array([random.random() - 0.5])

    def get_random_terminals(self) -> int:
        """Return 1 (done) with probability about 1%, otherwise 0."""
        return 1 if random.random() > 0.99 else 0

    def get_batch_sample(self, bs: int) -> list:
        """Return ``bs`` uncollated transition dicts.

        Each dict has keys ``'obs'``, ``'next_obs'``, ``'action'``, ``'done'``
        and ``'reward'``.
        """
        # Dict literals evaluate in order, so the random draws happen in the
        # same sequence as filling the keys one by one.
        return [
            {
                'obs': self.get_random_obs(),
                'next_obs': self.get_random_obs(),
                'action': self.get_random_action(),
                'done': self.get_random_terminals(),
                'reward': self.get_random_reward(),
            } for _ in range(bs)
        ]