|
""" |
|
Overview:
    This is the behaviour cloning (BC) main entry for gfootball.
    We first collect demo data using a rule-based model, then train the BC model
    using the demo data whose return is larger than 0,
    and (optionally) test the accuracy of the trained BC model on the train dataset and the test dataset.
|
""" |
|
from copy import deepcopy |
|
import os |
|
import torch |
|
import logging |
|
import test_accuracy |
|
from ding.entry import serial_pipeline_bc, collect_episodic_demo_data, episode_to_transitions_filter, eval |
|
from ding.config import read_config, compile_config |
|
from ding.policy import create_policy |
|
from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config |
|
from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ |
|
from dizoo.gfootball.model.bots.rule_based_bot_model import FootballRuleBaseModel |
|
|
|
# Emit INFO-level log records for the whole run.
logging.basicConfig(level=logging.INFO)

# Absolute location of this script; the data file paths below are built
# relative to its directory.
path = os.path.abspath(__file__)
dir_path = os.path.dirname(path)

# Experiment settings shared by the phases below.
seed = 0
demo_episodes = 200
gfootball_bc_config.exp_name = 'gfootball_bc_rule_200ep_lt0_seed0'

# ``data_path_episode`` receives the collected episodes (phase 1);
# ``data_path_transitions_lt0`` receives the filtered transitions that the
# BC training phase reads.
data_path_episode = dir_path + f'/gfootball_rule_{demo_episodes}eps.pkl'
data_path_transitions_lt0 = dir_path + f'/gfootball_rule_{demo_episodes}eps_transitions_lt0.pkl'
|
""" |
|
phase 1: collect demo data utilizing rule model |
|
""" |
|
# Work on private deep copies so that the imported config objects are not
# mutated by compilation.
input_cfg = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
# ``input_cfg`` is always the two-element list built on the line above, so the
# former ``isinstance(input_cfg, str)`` / ``read_config`` branch could never
# run; the pair is unpacked directly.
cfg, create_cfg = input_cfg
cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

# The rule-based bot serves as the expert: it is wrapped in a policy built
# from the compiled config with the learn, collect and eval fields enabled.
football_rule_base_model = FootballRuleBaseModel()
expert_policy = create_policy(cfg.policy, model=football_rule_base_model, enable_field=['learn', 'collect', 'eval'])
|
|
|
|
|
# State of the expert policy's collect mode, handed to the collector below.
state_dict = expert_policy.collect_mode.state_dict()

# Fresh deep copies of the configs for the collection entry point.
# NOTE(review): ``eval_config`` is not referenced again in the visible part of
# this script; it is kept so the module namespace stays unchanged.
collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
eval_config = deepcopy(collect_config)

# Collect ``demo_episodes`` episodes with the rule-based model and store them
# at ``data_path_episode``.
collect_episodic_demo_data(
    collect_config,
    seed=seed,
    model=football_rule_base_model,
    state_dict=state_dict,
    collect_count=demo_episodes,
    expert_data_path=data_path_episode,
)

# Convert the stored episodes into 1-step transitions, filtered with
# ``min_episode_return=1``, and write them to ``data_path_transitions_lt0``.
episode_to_transitions_filter(
    data_path=data_path_episode,
    expert_data_path=data_path_transitions_lt0,
    nstep=1,
    min_episode_return=1,
)
|
""" |
|
phase 2: BC training |
|
""" |
|
# Network that will be trained by behaviour cloning.
football_naive_q = FootballNaiveQ()

# Separate deep copies of the configs for the training pipeline, with the
# number of training epochs overridden to 1000.
bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
bc_config[0].policy.learn.train_epoch = 1000

# Train on the filtered transitions produced in phase 1. The first returned
# value is discarded; ``converge_stop_flag`` is kept at module level.
_, converge_stop_flag = serial_pipeline_bc(
    bc_config,
    seed=seed,
    data_path=data_path_transitions_lt0,
    model=football_naive_q,
)
|
|
|
if bc_config[0].policy.show_train_test_accuracy:
    """
    phase 3: test accuracy in train dataset and test dataset
    """
    bc_model_path = bc_config[0].policy.bc_model_path

    # Batch size for the accuracy pass. The value is still written back to
    # ``bc_config`` as before, but it is also passed explicitly to
    # ``test_accuracy_in_dataset``: the previous code read
    # ``cfg.policy.learn.batch_size``, and ``cfg`` was compiled in phase 1 from
    # an independent deep copy, so the override set here never took effect.
    accuracy_batch_size = 3000
    bc_config[0].policy.learn.batch_size = accuracy_batch_size

    # Load the trained weights (stored under the 'model' key of the checkpoint)
    # into the network and wrap it in an eval-only policy.
    state_dict = torch.load(bc_model_path)
    football_naive_q.load_state_dict(state_dict['model'])
    policy = create_policy(cfg.policy, model=football_naive_q, enable_field=['eval'])

    # NOTE(review): the two dataset file names below are fixed (100 / 50
    # episodes) and are not written by the phases visible in this script --
    # presumably they are prepared separately; confirm before running.
    print('==' * 10)
    print('calculate accuracy in train dataset')
    print('==' * 10)
    train_data_path = dir_path + '/gfootball_rule_100eps_transitions_lt0_train.pkl'
    test_accuracy.test_accuracy_in_dataset(train_data_path, accuracy_batch_size, policy)

    print('==' * 10)
    print('calculate accuracy in test dataset')
    print('==' * 10)
    test_data_path = dir_path + '/gfootball_rule_50eps_transitions_lt0_test.pkl'
    test_accuracy.test_accuracy_in_dataset(test_data_path, accuracy_batch_size, policy)
|
|