from copy import deepcopy
from typing import Union, Optional, List, Any, Tuple

import torch

from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.entry import serial_pipeline_bc, collect_episodic_demo_data, episode_to_transitions
from dizoo.mujoco.config.halfcheetah_td3_config import main_config, create_config


def load_policy(
        input_cfg: Union[str, Tuple[dict, dict]],
        load_path: str,
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
) -> 'Policy':  # noqa
    # Build a policy from a config (a path or a (cfg, create_cfg) tuple) and
    # load a pretrained checkpoint into its collect mode.
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    sd = torch.load(load_path, map_location='cpu')
    policy.collect_mode.load_state_dict(sd)
    return policy


def main():
    half_td3_config, half_td3_create_config = main_config, create_config

    # Load the expert TD3 policy from a pretrained checkpoint.
    train_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    exp_path = 'DI-engine/halfcheetah_td3_seed0/ckpt/ckpt_best.pth.tar'
    expert_policy = load_policy(train_config, load_path=exp_path, seed=0)

    # Collect expert demonstration data (collect_count episodes), then flatten
    # the episodes into per-step transitions for BC training.
    collect_count = 100
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    collect_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    collect_episodic_demo_data(
        deepcopy(collect_config),
        seed=0,
        state_dict=state_dict,
        expert_data_path=expert_data_path,
        collect_count=collect_count
    )
    episode_to_transitions(expert_data_path, expert_data_path, nstep=1)

    # Imitation-learning (behavior cloning) training on the expert transitions.
    il_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    il_config[0].policy.learn.train_epoch = 1000000
    il_config[0].policy.type = 'bc'
    il_config[0].policy.continuous = True
    il_config[0].exp_name = "continuous_bc_seed0"
    il_config[0].env.stop_value = 50000
    il_config[0].multi_agent = False
    bc_policy, converge_stop_flag = serial_pipeline_bc(
        il_config, seed=314, data_path=expert_data_path, max_iter=4e6
    )
    return bc_policy


if __name__ == '__main__':
    policy = main()
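
# A minimal sanity-check sketch (not part of the pipeline above): it assumes
# `episode_to_transitions` leaves a pickled list of transition dicts at
# `expert_data.pkl`; the exact keys (e.g. obs/action/reward) are whatever the
# collector stored, not a guaranteed schema.
#
#   import pickle
#   with open('expert_data.pkl', 'rb') as f:
#       transitions = pickle.load(f)
#   print(len(transitions), 'transitions; sample keys:', list(transitions[0].keys()))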