File size: 2,449 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from easydict import EasyDict
from copy import deepcopy
# Decision Transformer (DT) experiment config for the D4RL HalfCheetah
# "random" dataset (see `policy.learn.dataset_path`). Wrapped in EasyDict
# below so fields can be read with attribute access.
halfcheetah_dt_config = dict(
    exp_name='halfcheetah_random_dt_seed0',
    env=dict(
        env_id='HalfCheetah-v3',
        norm_obs=dict(use_norm=False, ),
        norm_reward=dict(use_norm=False, ),
        collector_env_num=1,
        evaluator_env_num=8,
        use_act_scale=True,
        n_evaluator_episode=8,
        stop_value=6000,
    ),
    policy=dict(
        stop_value=6000,
        cuda=True,
        env_name='HalfCheetah-v3',
        rtg_target=6000,  # max target return-to-go
        max_eval_ep_len=1000,  # max length of one episode
        # NOTE(review): differs from env.n_evaluator_episode (8); confirm
        # which of the two the DT pipeline actually reads.
        num_eval_ep=10,  # number of evaluation episodes
        batch_size=64,
        wt_decay=1e-4,
        warmup_steps=10000,
        num_updates_per_iter=100,
        # These transformer hyperparameters duplicate values in `model`
        # below (n_blocks, embed_dim <-> h_dim, context_len, n_heads,
        # dropout_p <-> drop_p); keep both copies in sync when tuning.
        context_len=20,
        n_blocks=3,
        embed_dim=128,
        n_heads=1,
        dropout_p=0.1,
        # NOTE(review): machine-specific absolute path; adjust for your setup.
        log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_random_dt_log',
        model=dict(
            state_dim=17,
            act_dim=6,
            n_blocks=3,
            h_dim=128,
            context_len=20,
            n_heads=1,
            drop_p=0.1,
            continuous=True,
        ),
        # NOTE(review): discount_factor / nstep, plus learn.target_update_freq,
        # learn.kappa, learn.min_q_weight and `other` (eps schedule, replay
        # buffer) look carried over from Q-learning-style configs; they may be
        # unused by DT -- verify before relying on them.
        discount_factor=0.999,
        nstep=3,
        learn=dict(
            # NOTE(review): machine-specific absolute path to the pickled
            # D4RL dataset; adjust for your setup.
            dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-random-v2.pkl',
            learning_rate=0.0001,
            target_update_freq=100,
            kappa=1.0,
            min_q_weight=4.0
        ),
        collect=dict(unroll_len=1, ),
        # Evaluate every 100 training iterations. The key was previously
        # misspelled `evalu_freq`; `eval_freq` is the evaluator's key.
        eval=dict(evaluator=dict(eval_freq=100, ), ),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=1000, ),
        ),
    ),
)
halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
main_config = halfcheetah_dt_config
# Companion "create" config: names which env, env manager and policy
# implementations to use by their `type` strings. `import_names` lists
# modules to import for the env (presumably so its registration takes
# effect -- confirm against DI-engine's config compiler).
halfcheetah_dt_create_config = EasyDict(
    {
        'env': {
            'type': 'mujoco',
            'import_names': ['dizoo.mujoco.envs.mujoco_env'],
        },
        'env_manager': {'type': 'subprocess'},
        'policy': {'type': 'dt'},
    }
)
create_config = halfcheetah_dt_create_config
if __name__ == "__main__":
    # Script entry point: train DT with seed 0 for at most 1000 iterations.
    # Imported lazily so that merely importing this config module does not
    # pull in the training pipeline.
    from ding.entry import serial_pipeline_dt
    # Deep-copy so the pipeline cannot mutate the module-level config objects.
    config = deepcopy([main_config, create_config])
    serial_pipeline_dt(config, seed=0, max_train_iter=1000)
|