""" Overview: Here is the behaviour cloning (BC) main entry for gfootball. We first collect demo data using rule model, then train the BC model using the demo data whose return is larger than 0, and (optional) test accuracy in train dataset and test dataset of the trained BC model """ from copy import deepcopy import os import torch import logging import test_accuracy from ding.entry import serial_pipeline_bc, collect_episodic_demo_data, episode_to_transitions_filter, eval from ding.config import read_config, compile_config from ding.policy import create_policy from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ from dizoo.gfootball.model.bots.rule_based_bot_model import FootballRuleBaseModel path = os.path.abspath(__file__) dir_path = os.path.dirname(path) logging.basicConfig(level=logging.INFO) # Note: in gfootball env, 3000 transitions = one episode, # 3e5 transitions = 200 episode, the memory needs about 350G. seed = 0 gfootball_bc_config.exp_name = 'gfootball_bc_rule_200ep_lt0_seed0' demo_episodes = 200 # key hyper-parameter data_path_episode = dir_path + f'/gfootball_rule_{demo_episodes}eps.pkl' data_path_transitions_lt0 = dir_path + f'/gfootball_rule_{demo_episodes}eps_transitions_lt0.pkl' """ phase 1: collect demo data utilizing rule model """ input_cfg = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)] if isinstance(input_cfg, str): cfg, create_cfg = read_config(input_cfg) else: cfg, create_cfg = input_cfg cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) football_rule_base_model = FootballRuleBaseModel() expert_policy = create_policy(cfg.policy, model=football_rule_base_model, enable_field=['learn', 'collect', 'eval']) # collect rule/expert demo data state_dict = expert_policy.collect_mode.state_dict() collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)] eval_config = deepcopy(collect_config) # eval demo model # if save replay # eval(eval_config, seed=seed, model=football_rule_base_model, replay_path=dir_path + f'/gfootball_rule_replay/') # if not save replay # eval(eval_config, seed=seed, model=football_rule_base_model, state_dict=state_dict) # collect demo data collect_episodic_demo_data( collect_config, seed=seed, expert_data_path=data_path_episode, collect_count=demo_episodes, model=football_rule_base_model, state_dict=state_dict ) # Note: only use the episode whose return is larger than 0 as demo data episode_to_transitions_filter( data_path=data_path_episode, expert_data_path=data_path_transitions_lt0, nstep=1, min_episode_return=1 ) """ phase 2: BC training """ bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)] bc_config[0].policy.learn.train_epoch = 1000 # key hyper-parameter football_naive_q = FootballNaiveQ() _, converge_stop_flag = serial_pipeline_bc( bc_config, seed=seed, data_path=data_path_transitions_lt0, model=football_naive_q ) if bc_config[0].policy.show_train_test_accuracy: """ phase 3: test accuracy in train dataset and test dataset """ bc_model_path = bc_config[0].policy.bc_model_path # load trained model bc_config[0].policy.learn.batch_size = int(3000) # the total dataset state_dict = torch.load(bc_model_path) football_naive_q.load_state_dict(state_dict['model']) policy = create_policy(cfg.policy, model=football_naive_q, enable_field=['eval']) # calculate accuracy in train dataset print('==' * 10) print('calculate accuracy in train dataset') print('==' * 10) # Users should add their own bc train_data_path here. Absolute path is recommended. train_data_path = dir_path + f'/gfootball_rule_100eps_transitions_lt0_train.pkl' test_accuracy.test_accuracy_in_dataset(train_data_path, cfg.policy.learn.batch_size, policy) # calculate accuracy in test dataset print('==' * 10) print('calculate accuracy in test dataset') print('==' * 10) # Users should add their own bc test_data_path here. Absolute path is recommended. test_data_path = dir_path + f'/gfootball_rule_50eps_transitions_lt0_test.pkl' test_accuracy.test_accuracy_in_dataset(test_data_path, cfg.policy.learn.batch_size, policy)