File size: 1,884 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import logging
import math
from ding.torch_utils import to_list
from ding.utils.data import NaiveRLDataset
from torch.utils.data import DataLoader

# Configure the root logger at INFO level so the ``logging.info`` summaries emitted by this
# module are shown (the root logger defaults to WARNING). Note that ``basicConfig`` does nothing
# if the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)


def test_accuracy_in_dataset(data_path, batch_size, policy, action_dim=19):
    """
    Overview:
        Evaluate the total accuracy and the accuracy of each action over the dataset loaded
        from ``data_path``, using ``policy`` to predict actions, for gfootball env.
        The results are only logged at INFO level; nothing is returned.

        Accuracies are sample-weighted over the whole dataset: correct/total counts are
        accumulated across mini-batches and divided once at the end. The accuracy of an
        action is the fraction of samples *labelled* with that action in the dataset for
        which the policy predicted the same action.
    Arguments:
        - data_path: Forwarded unchanged to ``NaiveRLDataset`` to load the samples.
        - batch_size: Mini-batch size forwarded to ``DataLoader``.
        - policy: Object whose ``_forward_eval(obs)`` returns a mapping with an ``'action'``
          entry. Assumed to be a 1-D tensor with one predicted action per sample
          (TODO confirm against the policy implementation).
        - action_dim (:obj:`int`): Number of discrete actions to report, with ids
          ``0 .. action_dim - 1``. Defaults to 19, the value previously hard-coded here.
    Raises:
        - KeyError: If the dataset contains an action label outside ``range(action_dim)``.
    """
    dataset = NaiveRLDataset(data_path)
    dataloader = DataLoader(dataset, batch_size)

    # Accumulate raw counts instead of per-minibatch means: mini-batches can differ in size
    # (the last one is usually smaller) and every action occurs a different number of times
    # per batch, so averaging batch-level means would weight samples unequally.
    total_correct = 0
    total_samples = 0
    action_correct = dict.fromkeys(range(action_dim), 0)
    action_samples = dict.fromkeys(range(action_dim), 0)

    for minibatch in dataloader:
        policy_output = policy._forward_eval(minibatch['obs'])
        pred_action = policy_output['action']
        true_action = minibatch['action'].view(-1)

        is_correct = pred_action == true_action
        total_correct += int(is_correct.sum())
        total_samples += is_correct.numel()

        for action_unique in to_list(torch.unique(true_action)):
            # Select samples by their dataset label, not by the prediction: every label
            # returned by ``unique`` has at least one sample, so the ratio is always defined
            # and an action the policy never predicts correctly naturally scores 0.
            label_mask = true_action == action_unique
            action_samples[action_unique] += int(label_mask.sum())
            action_correct[action_unique] += int((is_correct & label_mask).sum())

    nan = float('nan')
    # An empty dataset has no defined accuracy; report nan rather than dividing by zero.
    total_accuracy = total_correct / total_samples if total_samples else nan
    action_accuracy = {
        k: action_correct[k] / action_samples[k] if action_samples[k] else nan
        for k in range(action_dim)
    }

    logging.info(f'total accuracy in dataset is: {total_accuracy}')
    logging.info(
        f'accuracy of each action in dataset is (nan means the action does not appear in the dataset): '
        f'{action_accuracy}'
    )