# dizoo/metadrive/config/metadrive_onppo_eval_config.py
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import torch
from ding.envs import BaseEnvManager
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper

# Path to a trained checkpoint to load; if None, the policy will be initialized from scratch.
model_dir = None
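# For example: model_dir = './ckpt_best.pth.tar' (an illustrative filename, not shipped with the repo).
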
metadrive_basic_config = dict(
exp_name='metadrive_onppo_eval_seed0',
env=dict(
metadrive=dict(
use_render=True,
            traffic_density=0.10,  # Density of vehicles on the road, in the range [0, 1]
            map='XSOS',  # Int or string: an easy way to fill map_config
            horizon=4000,  # Max number of steps per episode
            driving_reward=1.0,  # Reward to encourage the agent to move forward
            speed_reward=0.10,  # Reward to encourage the agent to drive at high speed
            use_lateral_reward=False,  # Whether to add a reward for lane keeping
            out_of_road_penalty=40.0,  # Penalty to discourage driving out of the road
            crash_vehicle_penalty=40.0,  # Penalty to discourage collisions with vehicles
            decision_repeat=20,  # Number of simulation steps per policy decision (reciprocal of decision frequency)
            out_of_route_done=True,  # End the episode if the agent drives out of the road
            show_bird_view=False,  # Evaluation only: whether to draw the five-channel bird-view image
),
        manager=dict(
            shared_memory=False,  # Shared-memory transport is only used by subprocess env managers
            max_retry=2,  # Max retries when an env method raises an exception
            context='spawn',  # Multiprocessing start method for subprocess env managers
        ),
n_evaluator_episode=16,
stop_value=255,
collector_env_num=1,
evaluator_env_num=1,
),
policy=dict(
cuda=True,
action_space='continuous',
model=dict(
obs_shape=[5, 84, 84],
action_shape=2,
action_space='continuous',
bound_type='tanh',
encoder_hidden_size_list=[128, 128, 64],
),
learn=dict(
epoch_per_collect=10,
batch_size=64,
learning_rate=3e-4,
entropy_weight=0.001,
value_weight=0.5,
clip_ratio=0.02,
adv_norm=False,
value_norm=True,
grad_clip_value=10,
),
collect=dict(n_sample=1000, ),
eval=dict(evaluator=dict(eval_freq=1000, ), ),
),
)
main_config = EasyDict(metadrive_basic_config)


def wrapped_env(env_cfg, wrapper_cfg=None):
    # Wrap the raw MetaDrive env so that it conforms to DI-engine's BaseEnv interface.
    return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)


def main(cfg):
cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator)
evaluator_env_num = cfg.env.evaluator_env_num
show_bird_view = cfg.env.metadrive.show_bird_view
wrapper_cfg = {'show_bird_view': show_bird_view}
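    # Create the evaluation environments; BaseEnvManager runs them serially in the main process.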
evaluator_env = BaseEnvManager(
env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)],
cfg=cfg.env.manager,
)
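    # Build the actor-critic network (VAC) and the PPO policy, then optionally load a trained checkpoint.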
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)
if model_dir is not None:
policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
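    # The evaluator rolls out the policy in eval mode and logs results to TensorBoard.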
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)
stop, rate = evaluator.eval()
evaluator.close()


if __name__ == '__main__':
    main(main_config)
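
# Minimal usage sketch: run this file directly to launch one round of evaluation, e.g.
#   python metadrive_onppo_eval_config.py
# MetaDrive must be installed (``pip install metadrive-simulator``), and ``use_render=True``
# assumes a display is available; set it to False on headless machines.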