File size: 3,693 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
import torch
from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper

# Path of a trained checkpoint to evaluate. When None, no weights are loaded and
# the policy is evaluated with a freshly initialized (untrained) model.
model_dir = None

# Evaluation experiment config. main() only runs an evaluation pass; the
# learn/collect entries below are not read by main() itself.
metadrive_basic_config = {
    'exp_name': 'metadrive_onppo_eval_seed0',
    'env': {
        'metadrive': {
            'use_render': True,
            'traffic_density': 0.10,  # How densely other vehicles occupy the roads, in [0, 1].
            'map': 'XSOS',  # Int or string shorthand used to fill map_config.
            'horizon': 4000,  # Maximum number of steps.
            'driving_reward': 1.0,  # Rewards the agent for moving forward.
            'speed_reward': 0.10,  # Rewards the agent for driving at high speed.
            'use_lateral_reward': False,  # Lane-keeping reward switch.
            'out_of_road_penalty': 40.0,  # Penalty that discourages leaving the road.
            'crash_vehicle_penalty': 40.0,  # Penalty that discourages collisions.
            'decision_repeat': 20,  # Reciprocal of the decision frequency.
            'out_of_route_done': True,  # End the episode when the agent drives off the road.
            'show_bird_view': False,  # Evaluation only: whether to draw the five-channel bird-view image.
        },
        'manager': {
            'shared_memory': False,
            'max_retry': 2,
            'context': 'spawn',
        },
        'n_evaluator_episode': 16,
        'stop_value': 255,
        'collector_env_num': 1,
        'evaluator_env_num': 1,
    },
    'policy': {
        'cuda': True,
        'action_space': 'continuous',
        'model': {
            'obs_shape': [5, 84, 84],
            'action_shape': 2,
            'action_space': 'continuous',
            'bound_type': 'tanh',
            'encoder_hidden_size_list': [128, 128, 64],
        },
        'learn': {
            'epoch_per_collect': 10,
            'batch_size': 64,
            'learning_rate': 3e-4,
            'entropy_weight': 0.001,
            'value_weight': 0.5,
            'clip_ratio': 0.02,
            'adv_norm': False,
            'value_norm': True,
            'grad_clip_value': 10,
        },
        'collect': {'n_sample': 1000},
        'eval': {'evaluator': {'eval_freq': 1000}},
    },
}
# main() reads this config with attribute access (cfg.env.metadrive...).
main_config = EasyDict(metadrive_basic_config)


def wrapped_env(env_cfg, wrapper_cfg=None):
    """Create one MetaDrive PPO environment wrapped in ``DriveEnvWrapper``.

    Args:
        env_cfg: Environment config handed to ``MetaDrivePPOOriginEnv``.
        wrapper_cfg: Optional config forwarded unchanged to ``DriveEnvWrapper``
            (main() passes ``{'show_bird_view': ...}``).

    Returns:
        The ``DriveEnvWrapper`` instance around the new environment.
    """
    base_env = MetaDrivePPOOriginEnv(env_cfg)
    return DriveEnvWrapper(base_env, wrapper_cfg)


def main(cfg):
    """Evaluate a PPO policy on MetaDrive and return the evaluator's result.

    Builds ``cfg.env.evaluator_env_num`` wrapped MetaDrive environments, a VAC
    model and a ``PPOPolicy``. If the module-level ``model_dir`` is set, it
    loads that checkpoint. It then runs one ``InteractionSerialEvaluator.eval()``
    pass. The evaluator and the TensorBoard writer are closed even if
    evaluation raises.

    Args:
        cfg: Experiment config (an ``EasyDict`` such as ``main_config``). It is
            passed through ``compile_config`` before use.

    Returns:
        Whatever ``evaluator.eval()`` returns. The original script unpacked it
        as ``stop, rate`` and discarded it; callers that ignore the return
        value are unaffected.
    """
    cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator)
    evaluator_env_num = cfg.env.evaluator_env_num
    wrapper_cfg = {'show_bird_view': cfg.env.metadrive.show_bird_view}
    evaluator_env = BaseEnvManager(
        env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager,
    )
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    if model_dir is not None:
        # map_location='cpu' allows a checkpoint saved on GPU to be loaded on any machine.
        # NOTE(review): torch.load unpickles the file; only point model_dir at trusted checkpoints.
        policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    try:
        return evaluator.eval()
    finally:
        # Release the evaluator (and, presumably through it, the env manager) even
        # when eval() raises; previously close() ran only on the success path.
        evaluator.close()
        # tensorboardX ignores a second close(), so this is safe even if
        # evaluator.close() already closed the writer.
        tb_logger.close()


if __name__ == '__main__':
    # Script entry point: run the evaluation with the module-level config.
    main(cfg=main_config)