File size: 2,504 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from easydict import EasyDict
from copy import deepcopy
# Decision Transformer (DT) experiment config for BipedalWalker-v3.
# The policy learns from a pre-collected dataset (``policy.learn.dataset_path``);
# the file/experiment names suggest ~1000 episodes gathered by a SAC agent --
# TODO confirm.
#
# NOTE(review): ``policy.log_dir`` and ``policy.learn.dataset_path`` are
# absolute paths on one developer's machine; point them at local copies
# before running.
bipedalwalker_dt_config = dict(
    exp_name='bipedalwalker_dt_1000eps_seed0',
    env=dict(
        env_name='BipedalWalker-v3',
        collector_env_num=8,
        evaluator_env_num=5,
        act_scale=True,
        n_evaluator_episode=5,
        stop_value=300,  # stop when the return reaches 300
        rew_clip=True,  # clip rewards
        replay_path=None,
    ),
    policy=dict(
        stop_value=300,
        device='cuda',
        env_name='BipedalWalker-v3',
        rtg_target=300,  # max target return-to-go
        max_eval_ep_len=1000,  # max length of one evaluation episode
        num_eval_ep=10,  # number of evaluation episodes
        batch_size=64,
        wt_decay=1e-4,
        warmup_steps=10000,
        num_updates_per_iter=100,
        # Transformer hyper-parameters. They mirror the ``model`` sub-dict
        # below (n_blocks, embed_dim <-> h_dim, context_len, n_heads,
        # dropout_p <-> drop_p); presumably both must agree -- keep in sync.
        context_len=20,
        n_blocks=3,
        embed_dim=128,
        n_heads=1,
        dropout_p=0.1,
        log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/box2d/bipedalwalker/dt_data/dt_log_1000eps',
        model=dict(
            state_dim=24,
            act_dim=4,
            n_blocks=3,
            h_dim=128,
            context_len=20,
            n_heads=1,
            drop_p=0.1,
            continuous=True,
        ),
        # NOTE(review): discount_factor, nstep, learn.target_update_freq,
        # learn.kappa, learn.min_q_weight, other.eps and other.replay_buffer
        # look carried over from a Q-learning (CQL-style) config; confirm
        # whether the DT policy reads them before removing.
        discount_factor=0.999,
        nstep=3,
        learn=dict(
            dataset_path='/home/wangzilin/research/dt/sac_data_1000eps.pkl',
            learning_rate=0.0001,
            target_update_freq=100,
            kappa=1.0,
            min_q_weight=4.0,
        ),
        collect=dict(unroll_len=1, ),
        # DI-engine's serial evaluator reads ``eval_freq``; the previous
        # spelling ``evalu_freq`` was not a recognised key, so the intended
        # frequency never took effect.
        eval=dict(evaluator=dict(eval_freq=100, ), ),
        other=dict(
            eps=dict(
                type='exp',
                start=0.95,
                end=0.1,
                decay=10000,
            ),
            replay_buffer=dict(replay_buffer_size=1000, ),
        ),
    ),
)
bipedalwalker_dt_config = EasyDict(bipedalwalker_dt_config)
main_config = bipedalwalker_dt_config
# "Create" config: selects the registered component types for this run --
# the env type (with the module(s) in ``import_names``, presumably imported so
# the env gets registered), the env-manager type and the policy type ('dt').
# Both names below are bound to the same EasyDict object.
create_config = bipedalwalker_dt_create_config = EasyDict({
    'env': {
        'type': 'bipedalwalker',
        'import_names': ['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
    },
    'env_manager': {'type': 'subprocess'},
    'policy': {'type': 'dt'},
})
if __name__ == "__main__":
    # The pipeline import lives inside the guard, so importing this module
    # only builds the config objects.
    from ding.entry import serial_pipeline_dt

    # Hand the pipeline a deep copy so any mutation it performs leaves the
    # module-level configs untouched.
    serial_pipeline_dt(
        deepcopy([main_config, create_config]),
        seed=0,
        max_train_iter=1000,
    )
|