# File-viewer page header captured with this file (not program code):
# zjowowen's picture · commit "init space" (079c32c) · raw · history · blame ·
# No virus · 1.88 kB
import torch
import logging
import math
from ding.torch_utils import to_list
from ding.utils.data import NaiveRLDataset
from torch.utils.data import DataLoader
# Configure the root logger at import time so the INFO-level accuracy reports
# emitted via ``logging.info`` in this module are actually shown.
logging.basicConfig(level=logging.INFO)
def test_accuracy_in_dataset(data_path, batch_size, policy):
    """
    Overview:
        Evaluate total accuracy and accuracy of each action in dataset from
        ``data_path`` using the ``policy`` for gfootball env. The results are
        only reported through ``logging.info``; nothing is returned.
    Arguments:
        - data_path: Passed unchanged to ``NaiveRLDataset`` to load the samples.
        - batch_size: Mini-batch size handed to the ``DataLoader``.
        - policy: Object whose ``_forward_eval(obs)`` returns a dict with an \
            ``'action'`` entry that can be compared element-wise with the \
            flattened ``minibatch['action']`` labels.
    Notes:
        - Total accuracy is ``matching predictions / compared predictions`` \
            pooled over the whole dataset, so a shorter final mini-batch is \
            not over-weighted. It is ``nan`` when the dataset yields no batch.
        - Per-action accuracy is the mean of per-batch values. In each batch, \
            for every action id present in the labels, the samples where the \
            policy *predicted* that id are selected and the fraction whose \
            label agrees is recorded (``0.0`` if the id was never predicted \
            in that batch). Ids that never occur in the labels report ``nan``.
        - Action ids ``0..18`` are hard-coded; a label outside that range \
            raises ``KeyError``.
    """
    dataset = NaiveRLDataset(data_path)
    dataloader = DataLoader(dataset, batch_size)

    # Pooled counters for the dataset-wide accuracy.
    correct_count = 0
    compared_count = 0
    action_accuracy_in_dataset = {k: [] for k in range(19)}

    for minibatch in dataloader:
        policy_output = policy._forward_eval(minibatch['obs'])
        pred_action = policy_output['action']
        # Flatten the labels once; they are reused for every comparison below.
        label_action = minibatch['action'].view(-1)

        matches = pred_action == label_action
        correct_count += int(matches.sum().item())
        compared_count += matches.numel()

        for action_unique in to_list(torch.unique(minibatch['action'])):
            # find the index where action is `action_unique` in `pred_action`
            action_index = (pred_action == action_unique).nonzero(as_tuple=True)[0]
            action_accuracy = (pred_action[action_index] == label_action[action_index]).float().mean().item()
            # An empty selection (action never predicted in this batch) has a
            # nan mean; count it as zero accuracy for this batch.
            if math.isnan(action_accuracy):
                action_accuracy = 0.0
            action_accuracy_in_dataset[action_unique].append(action_accuracy)

    total_accuracy = correct_count / compared_count if compared_count else float('nan')
    logging.info(f'total accuracy in dataset is: {total_accuracy}')

    # The mean of an empty list is nan, which marks actions absent from the labels.
    accuracy_per_action = {k: torch.tensor(action_accuracy_in_dataset[k]).mean().item() for k in range(19)}
    logging.info(
        f'accuracy of each action in dataset is (nan means the action does not appear in the dataset): '
        f'{accuracy_per_action}'
    )