prashanthgowni commited on
Commit
c5855e2
1 Parent(s): 6f329c8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -3
README.md CHANGED
@@ -19,6 +19,8 @@ model-index:
19
  value: 277.82 +/- 22.28
20
  name: mean_reward
21
  verified: false
 
 
22
  ---
23
 
24
  # **PPO** Agent playing **LunarLander-v2**
@@ -30,8 +32,41 @@ TODO: Add your code
30
 
31
 
32
  ```python
33
- from stable_baselines3 import ...
 
 
 
34
  from huggingface_sb3 import load_from_hub
35
 
36
- ...
37
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  value: 277.82 +/- 22.28
20
  name: mean_reward
21
  verified: false
22
+ language:
23
+ - en
24
  ---
25
 
26
  # **PPO** Agent playing **LunarLander-v2**
 
32
 
33
 
34
  ```python
35
+ from stable_baselines3 import PPO
36
+ from stable_baselines3.common.env_util import make_vec_env
37
+ from stable_baselines3.common.evaluation import evaluate_policy
38
+
39
  from huggingface_sb3 import load_from_hub
40
 
41
+
42
+ # Download the model checkpoint
43
+ model_checkpoint = load_from_hub("prashanthgowni/ppo-LunarLander-v2", "ppo-LunarLander-v2")
44
+ # Create a vectorized environment
45
+ env = make_vec_env("LunarLander-v2", n_envs=1)
46
+
47
+ # Load the model
48
+ model = PPO.load(model_checkpoint, env=env)
49
+
50
+ # Evaluate
51
+ print("Evaluating model")
52
+ mean_reward, std_reward = evaluate_policy(
53
+ model,
54
+ env,
55
+ n_eval_episodes=30,
56
+ deterministic=True,
57
+ )
58
+ print(f"Mean reward = {mean_reward:.2f} +/- {std_reward}")
59
+
60
+ # Start a new episode
61
+ obs = env.reset()
62
+
63
+ try:
64
+ while True:
65
+ action, state = model.predict(obs, deterministic=True)
66
+ obs, reward, done, info = env.step(action)
67
+ env.render()
68
+
69
+ except KeyboardInterrupt:
70
+ pass
71
+
72
+ ```