# Source: gomoku/DI-engine/dizoo/gfootball/entry/gfootball_bc_kaggle5th_main.py (commit 079c32c, "init space")
"""
Overview:
    Main entry for behaviour cloning (BC) on gfootball.
    We first collect demo data using ``FootballKaggle5thPlaceModel``, then train the BC model
    on the collected demo data,
    and (optionally) test the accuracy of the trained BC model on the training and test datasets.
"""
from copy import deepcopy
import os
from ding.entry import serial_pipeline_bc, collect_demo_data
from ding.config import read_config, compile_config
from ding.policy import create_policy
from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config
from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
from dizoo.gfootball.model.bots.kaggle_5th_place_model import FootballKaggle5thPlaceModel
# Absolute location of this script; the demo data path below is built relative to it.
path = os.path.abspath(__file__)
dir_path = os.path.dirname(path)

# In the gfootball env one episode is 3000 transitions, so 3e5 transitions
# correspond to 100 episodes. That amount of data needs about 180G of memory.
seed = 0
# Derive the experiment name from ``seed`` so the two cannot drift apart.
gfootball_bc_config.exp_name = f'gfootball_bc_kaggle5th_seed{seed}'
demo_transitions = int(3e5)  # key hyper-parameter
# File that stores the collected expert transitions; the transition count is part of the name.
data_path_transitions = os.path.join(
    dir_path, f'gfootball_kaggle5th_{demo_transitions}-demo-transitions.pkl'
)
"""
phase 1: collect demo data utilizing ``FootballKaggle5thPlaceModel``
"""
# The configs are built in memory as a [main config, create config] pair, so they are
# unpacked directly; there is no config file path to read here.
train_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
cfg, create_cfg = train_config
# Append the '_command' suffix to the policy type; the policy below is created with the
# 'command' field enabled.
create_cfg.policy.type += '_command'
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Build the expert policy around the Kaggle 5th place model.
football_kaggle_5th_place_model = FootballKaggle5thPlaceModel()
expert_policy = create_policy(
    cfg.policy, model=football_kaggle_5th_place_model, enable_field=['learn', 'collect', 'eval', 'command']
)

# collect expert demo data
state_dict = expert_policy.collect_mode.state_dict()
# Use fresh copies of the configs: the create config held by ``train_config`` had its
# policy type changed above.
collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
collect_demo_data(
    collect_config,
    seed=seed,
    expert_data_path=data_path_transitions,
    collect_count=demo_transitions,
    model=football_kaggle_5th_place_model,
    state_dict=state_dict,
)
"""
phase 2: BC training
"""
# Start from a fresh copy of the main config and override the number of training epochs.
bc_main_config = deepcopy(gfootball_bc_config)
bc_main_config.policy.learn.train_epoch = 1000  # key hyper-parameter
bc_config = [bc_main_config, deepcopy(gfootball_bc_create_config)]

# Model passed to the BC pipeline, trained on the demo data stored at ``data_path_transitions``.
football_naive_q = FootballNaiveQ()
_, converge_stop_flag = serial_pipeline_bc(
    bc_config,
    seed=seed,
    data_path=data_path_transitions,
    model=football_naive_q,
)