ledmands commited on
Commit
c37ff18
·
1 Parent(s): ebb75df

Modified watch_agent.py to include ability to give an argument to adjust repeat action probability.

Browse files
Files changed (1) hide show
  1. agents/version_2/watch_agent.py +16 -10
agents/version_2/watch_agent.py CHANGED
@@ -3,20 +3,26 @@ from stable_baselines3.common.evaluation import evaluate_policy
3
  from stable_baselines3.common.monitor import Monitor
4
  import gymnasium as gym
5
 
6
- MODEL_NAME = "ALE-Pacman-v5-control"
7
 
8
- # the saved model does not contain the replay buffer
9
- loaded_model = DQN.load(MODEL_NAME)
10
- # print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
 
 
 
 
 
11
 
12
- # now the loaded replay is not empty anymore
13
- # print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
14
 
 
15
 
16
  # Retrieve the environment
17
- eval_env = Monitor(gym.make("ALE/Pacman-v5", render_mode="human", ))
18
 
19
  # Evaluate the policy
20
- mean_reward, std_reward = evaluate_policy(loaded_model.policy, eval_env, n_eval_episodes=10, deterministic=False, )
21
-
22
- print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
 
3
  from stable_baselines3.common.monitor import Monitor
4
  import gymnasium as gym
5
 
6
+ import argparse
7
 
8
+ # This script should have some options
9
+ # 1. Turn off the stochasticity as determined by the ALEv5
10
+ # Even if deterministic is set to true in evaluate policy, the environment will ignore this 25% of the time
11
+ # To compensate for this, we can set the repeat action probability to 0
12
+
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("-r", "--repeat_action_probability", help="repeat action probability", type=float, default=0.25)
15
+ args = parser.parse_args()
16
 
17
+ MODEL_NAME = "ALE-Pacman-v5"
18
+ rpt_act_prob = args.repeat_action_probability
19
 
20
+ loaded_model = DQN.load(MODEL_NAME)
21
 
22
  # Retrieve the environment
23
+ eval_env = Monitor(gym.make("ALE/Pacman-v5", render_mode="rgb_array", repeat_action_probability=rpt_act_prob))
24
 
25
  # Evaluate the policy
26
+ mean_rwd, std_rwd = evaluate_policy(loaded_model.policy, eval_env, n_eval_episodes=1)
27
+ print("mean rwd: ", mean_rwd)
28
+ print("std rwd: ", std_rwd)