init space
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- DI-engine +0 -1
- DI-engine/.flake8 +4 -0
- DI-engine/.gitignore +1431 -0
- DI-engine/.style.yapf +11 -0
- DI-engine/CHANGELOG +489 -0
- DI-engine/CODE_OF_CONDUCT.md +128 -0
- DI-engine/CONTRIBUTING.md +7 -0
- DI-engine/LICENSE +202 -0
- DI-engine/Makefile +71 -0
- DI-engine/README.md +475 -0
- DI-engine/cloc.sh +69 -0
- DI-engine/codecov.yml +8 -0
- DI-engine/conda/conda_build_config.yaml +2 -0
- DI-engine/conda/meta.yaml +35 -0
- DI-engine/ding/__init__.py +12 -0
- DI-engine/ding/bonus/__init__.py +132 -0
- DI-engine/ding/bonus/a2c.py +460 -0
- DI-engine/ding/bonus/c51.py +459 -0
- DI-engine/ding/bonus/common.py +22 -0
- DI-engine/ding/bonus/config.py +326 -0
- DI-engine/ding/bonus/ddpg.py +456 -0
- DI-engine/ding/bonus/dqn.py +460 -0
- DI-engine/ding/bonus/model.py +245 -0
- DI-engine/ding/bonus/pg.py +453 -0
- DI-engine/ding/bonus/ppo_offpolicy.py +471 -0
- DI-engine/ding/bonus/ppof.py +509 -0
- DI-engine/ding/bonus/sac.py +457 -0
- DI-engine/ding/bonus/sql.py +461 -0
- DI-engine/ding/bonus/td3.py +455 -0
- DI-engine/ding/compatibility.py +9 -0
- DI-engine/ding/config/__init__.py +4 -0
- DI-engine/ding/config/config.py +579 -0
- DI-engine/ding/config/example/A2C/__init__.py +17 -0
- DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py +43 -0
- DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py +38 -0
- DI-engine/ding/config/example/C51/__init__.py +23 -0
- DI-engine/ding/config/example/C51/gym_lunarlander_v2.py +52 -0
- DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py +54 -0
- DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py +54 -0
- DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py +54 -0
- DI-engine/ding/config/example/DDPG/__init__.py +29 -0
- DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py +45 -0
- DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py +53 -0
- DI-engine/ding/config/example/DDPG/gym_hopper_v3.py +53 -0
- DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py +60 -0
- DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py +52 -0
- DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py +53 -0
- DI-engine/ding/config/example/DQN/__init__.py +23 -0
- DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py +53 -0
- DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py +50 -0
DI-engine
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
Subproject commit a57bc3024b938c881aaf6511d1fb26296cd98601
|
|
|
|
DI-engine/.flake8
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731
|
3 |
+
max-line-length=120
|
4 |
+
statistics
|
DI-engine/.gitignore
ADDED
@@ -0,0 +1,1431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
2 |
+
### ArchLinuxPackages template
|
3 |
+
*.tar
|
4 |
+
*.tar.*
|
5 |
+
*.jar
|
6 |
+
*.exe
|
7 |
+
*.msi
|
8 |
+
*.zip
|
9 |
+
*.tgz
|
10 |
+
*.log
|
11 |
+
*.log.*
|
12 |
+
*.sig
|
13 |
+
*.mov
|
14 |
+
*.pkl
|
15 |
+
|
16 |
+
pkg/
|
17 |
+
src/
|
18 |
+
impala_log/
|
19 |
+
|
20 |
+
### CVS template
|
21 |
+
/CVS/*
|
22 |
+
**/CVS/*
|
23 |
+
.cvsignore
|
24 |
+
*/.cvsignore
|
25 |
+
|
26 |
+
### LibreOffice template
|
27 |
+
# LibreOffice locks
|
28 |
+
.~lock.*#
|
29 |
+
|
30 |
+
### CUDA template
|
31 |
+
*.i
|
32 |
+
*.ii
|
33 |
+
*.gpu
|
34 |
+
*.ptx
|
35 |
+
*.cubin
|
36 |
+
*.fatbin
|
37 |
+
|
38 |
+
### Eclipse template
|
39 |
+
*.bin
|
40 |
+
.metadata
|
41 |
+
bin/
|
42 |
+
tmp/
|
43 |
+
*.tmp
|
44 |
+
*.bak
|
45 |
+
*.swp
|
46 |
+
*~.nib
|
47 |
+
local.properties
|
48 |
+
.settings/
|
49 |
+
.loadpath
|
50 |
+
.recommenders
|
51 |
+
|
52 |
+
# External tool builders
|
53 |
+
.externalToolBuilders/
|
54 |
+
|
55 |
+
# Locally stored "Eclipse launch configurations"
|
56 |
+
*.launch
|
57 |
+
|
58 |
+
# PyDev specific (Python IDE for Eclipse)
|
59 |
+
*.pydevproject
|
60 |
+
|
61 |
+
# CDT-specific (C/C++ Development Tooling)
|
62 |
+
.cproject
|
63 |
+
|
64 |
+
# CDT- autotools
|
65 |
+
.autotools
|
66 |
+
|
67 |
+
# Java annotation processor (APT)
|
68 |
+
.factorypath
|
69 |
+
|
70 |
+
# PDT-specific (PHP Development Tools)
|
71 |
+
.buildpath
|
72 |
+
|
73 |
+
# sbteclipse plugin
|
74 |
+
.target
|
75 |
+
|
76 |
+
# Tern plugin
|
77 |
+
.tern-project
|
78 |
+
|
79 |
+
# TeXlipse plugin
|
80 |
+
.texlipse
|
81 |
+
|
82 |
+
# STS (Spring Tool Suite)
|
83 |
+
.springBeans
|
84 |
+
|
85 |
+
# Code Recommenders
|
86 |
+
.recommenders/
|
87 |
+
|
88 |
+
# Annotation Processing
|
89 |
+
.apt_generated/
|
90 |
+
.apt_generated_test/
|
91 |
+
|
92 |
+
# Scala IDE specific (Scala & Java development for Eclipse)
|
93 |
+
.cache-main
|
94 |
+
.scala_dependencies
|
95 |
+
.worksheet
|
96 |
+
|
97 |
+
# Uncomment this line if you wish to ignore the project description file.
|
98 |
+
# Typically, this file would be tracked if it contains build/dependency configurations:
|
99 |
+
#.project
|
100 |
+
|
101 |
+
### SVN template
|
102 |
+
.svn/
|
103 |
+
|
104 |
+
### Images template
|
105 |
+
# JPEG
|
106 |
+
*.jpg
|
107 |
+
*.jpeg
|
108 |
+
*.jpe
|
109 |
+
*.jif
|
110 |
+
*.jfif
|
111 |
+
*.jfi
|
112 |
+
|
113 |
+
# JPEG 2000
|
114 |
+
*.jp2
|
115 |
+
*.j2k
|
116 |
+
*.jpf
|
117 |
+
*.jpx
|
118 |
+
*.jpm
|
119 |
+
*.mj2
|
120 |
+
|
121 |
+
# JPEG XR
|
122 |
+
*.jxr
|
123 |
+
*.hdp
|
124 |
+
*.wdp
|
125 |
+
|
126 |
+
# Graphics Interchange Format
|
127 |
+
*.gif
|
128 |
+
*.mp4
|
129 |
+
*.mpg
|
130 |
+
|
131 |
+
# RAW
|
132 |
+
*.raw
|
133 |
+
|
134 |
+
# Web P
|
135 |
+
*.webp
|
136 |
+
|
137 |
+
# Portable Network Graphics
|
138 |
+
*.png
|
139 |
+
|
140 |
+
# Animated Portable Network Graphics
|
141 |
+
*.apng
|
142 |
+
|
143 |
+
# Multiple-image Network Graphics
|
144 |
+
*.mng
|
145 |
+
|
146 |
+
# Tagged Image File Format
|
147 |
+
*.tiff
|
148 |
+
*.tif
|
149 |
+
|
150 |
+
# Scalable Vector Graphics
|
151 |
+
*.svg
|
152 |
+
*.svgz
|
153 |
+
|
154 |
+
# Portable Document Format
|
155 |
+
*.pdf
|
156 |
+
|
157 |
+
# X BitMap
|
158 |
+
*.xbm
|
159 |
+
|
160 |
+
# BMP
|
161 |
+
*.bmp
|
162 |
+
*.dib
|
163 |
+
|
164 |
+
# ICO
|
165 |
+
*.ico
|
166 |
+
|
167 |
+
# 3D Images
|
168 |
+
*.3dm
|
169 |
+
*.max
|
170 |
+
|
171 |
+
### Diff template
|
172 |
+
*.patch
|
173 |
+
*.diff
|
174 |
+
|
175 |
+
### JetBrains template
|
176 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
177 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
178 |
+
|
179 |
+
# User-specific stuff
|
180 |
+
.idea/**/workspace.xml
|
181 |
+
.idea/**/tasks.xml
|
182 |
+
.idea/**/usage.statistics.xml
|
183 |
+
.idea/**/dictionaries
|
184 |
+
.idea/**/shelf
|
185 |
+
|
186 |
+
# Generated files
|
187 |
+
.idea/**/contentModel.xml
|
188 |
+
|
189 |
+
# Sensitive or high-churn files
|
190 |
+
.idea/**/dataSources/
|
191 |
+
.idea/**/dataSources.ids
|
192 |
+
.idea/**/dataSources.local.xml
|
193 |
+
.idea/**/sqlDataSources.xml
|
194 |
+
.idea/**/dynamic.xml
|
195 |
+
.idea/**/uiDesigner.xml
|
196 |
+
.idea/**/dbnavigator.xml
|
197 |
+
|
198 |
+
# Gradle
|
199 |
+
.idea/**/gradle.xml
|
200 |
+
.idea/**/libraries
|
201 |
+
|
202 |
+
# Gradle and Maven with auto-import
|
203 |
+
# When using Gradle or Maven with auto-import, you should exclude module files,
|
204 |
+
# since they will be recreated, and may cause churn. Uncomment if using
|
205 |
+
# auto-import.
|
206 |
+
# .idea/artifacts
|
207 |
+
# .idea/compiler.xml
|
208 |
+
# .idea/jarRepositories.xml
|
209 |
+
# .idea/modules.xml
|
210 |
+
# .idea/*.iml
|
211 |
+
# .idea/modules
|
212 |
+
# *.iml
|
213 |
+
# *.ipr
|
214 |
+
|
215 |
+
# CMake
|
216 |
+
cmake-build-*/
|
217 |
+
|
218 |
+
# Mongo Explorer plugin
|
219 |
+
.idea/**/mongoSettings.xml
|
220 |
+
|
221 |
+
# File-based project format
|
222 |
+
*.iws
|
223 |
+
|
224 |
+
# IntelliJ
|
225 |
+
out/
|
226 |
+
|
227 |
+
# mpeltonen/sbt-idea plugin
|
228 |
+
.idea_modules/
|
229 |
+
|
230 |
+
# JIRA plugin
|
231 |
+
atlassian-ide-plugin.xml
|
232 |
+
|
233 |
+
# Cursive Clojure plugin
|
234 |
+
.idea/replstate.xml
|
235 |
+
|
236 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
237 |
+
com_crashlytics_export_strings.xml
|
238 |
+
crashlytics.properties
|
239 |
+
crashlytics-build.properties
|
240 |
+
fabric.properties
|
241 |
+
|
242 |
+
# Editor-based Rest Client
|
243 |
+
.idea/httpRequests
|
244 |
+
|
245 |
+
# Android studio 3.1+ serialized cache file
|
246 |
+
.idea/caches/build_file_checksums.ser
|
247 |
+
|
248 |
+
### CodeIgniter template
|
249 |
+
*/config/development
|
250 |
+
*/logs/log-*.php
|
251 |
+
!*/logs/index.html
|
252 |
+
*/cache/*
|
253 |
+
!*/cache/index.html
|
254 |
+
!*/cache/.htaccess
|
255 |
+
|
256 |
+
user_guide_src/build/*
|
257 |
+
user_guide_src/cilexer/build/*
|
258 |
+
user_guide_src/cilexer/dist/*
|
259 |
+
user_guide_src/cilexer/pycilexer.egg-info/*
|
260 |
+
|
261 |
+
#codeigniter 3
|
262 |
+
application/logs/*
|
263 |
+
!application/logs/index.html
|
264 |
+
!application/logs/.htaccess
|
265 |
+
/vendor/
|
266 |
+
|
267 |
+
### Emacs template
|
268 |
+
# -*- mode: gitignore; -*-
|
269 |
+
*~
|
270 |
+
\#*\#
|
271 |
+
/.emacs.desktop
|
272 |
+
/.emacs.desktop.lock
|
273 |
+
*.elc
|
274 |
+
auto-save-list
|
275 |
+
tramp
|
276 |
+
.\#*
|
277 |
+
|
278 |
+
# Org-mode
|
279 |
+
.org-id-locations
|
280 |
+
*_archive
|
281 |
+
|
282 |
+
# flymake-mode
|
283 |
+
*_flymake.*
|
284 |
+
|
285 |
+
# eshell files
|
286 |
+
/eshell/history
|
287 |
+
/eshell/lastdir
|
288 |
+
|
289 |
+
# elpa packages
|
290 |
+
/elpa/
|
291 |
+
|
292 |
+
# reftex files
|
293 |
+
*.rel
|
294 |
+
|
295 |
+
# AUCTeX auto folder
|
296 |
+
/auto/
|
297 |
+
|
298 |
+
# cask packages
|
299 |
+
.cask/
|
300 |
+
dist/
|
301 |
+
|
302 |
+
# Flycheck
|
303 |
+
flycheck_*.el
|
304 |
+
|
305 |
+
# server auth directory
|
306 |
+
/server/
|
307 |
+
|
308 |
+
# projectiles files
|
309 |
+
.projectile
|
310 |
+
|
311 |
+
# directory configuration
|
312 |
+
.dir-locals.el
|
313 |
+
|
314 |
+
# network security
|
315 |
+
/network-security.data
|
316 |
+
|
317 |
+
|
318 |
+
### Windows template
|
319 |
+
# Windows thumbnail cache files
|
320 |
+
Thumbs.db
|
321 |
+
Thumbs.db:encryptable
|
322 |
+
ehthumbs.db
|
323 |
+
ehthumbs_vista.db
|
324 |
+
|
325 |
+
# Dump file
|
326 |
+
*.stackdump
|
327 |
+
|
328 |
+
# Folder config file
|
329 |
+
[Dd]esktop.ini
|
330 |
+
|
331 |
+
# Recycle Bin used on file shares
|
332 |
+
$RECYCLE.BIN/
|
333 |
+
|
334 |
+
# Windows Installer files
|
335 |
+
*.cab
|
336 |
+
*.msix
|
337 |
+
*.msm
|
338 |
+
*.msp
|
339 |
+
|
340 |
+
# Windows shortcuts
|
341 |
+
*.lnk
|
342 |
+
|
343 |
+
### VisualStudioCode template
|
344 |
+
.vscode/*
|
345 |
+
!.vscode/settings.json
|
346 |
+
!.vscode/tasks.json
|
347 |
+
!.vscode/launch.json
|
348 |
+
!.vscode/extensions.json
|
349 |
+
*.code-workspace
|
350 |
+
|
351 |
+
# Local History for Visual Studio Code
|
352 |
+
.history/
|
353 |
+
|
354 |
+
### CMake template
|
355 |
+
CMakeLists.txt.user
|
356 |
+
CMakeCache.txt
|
357 |
+
CMakeFiles
|
358 |
+
CMakeScripts
|
359 |
+
Testing
|
360 |
+
cmake_install.cmake
|
361 |
+
install_manifest.txt
|
362 |
+
compile_commands.json
|
363 |
+
CTestTestfile.cmake
|
364 |
+
_deps
|
365 |
+
|
366 |
+
### VisualStudio template
|
367 |
+
## Ignore Visual Studio temporary files, build results, and
|
368 |
+
## files generated by popular Visual Studio add-ons.
|
369 |
+
##
|
370 |
+
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
371 |
+
|
372 |
+
# User-specific files
|
373 |
+
*.rsuser
|
374 |
+
*.suo
|
375 |
+
*.user
|
376 |
+
*.userosscache
|
377 |
+
*.sln.docstates
|
378 |
+
|
379 |
+
# User-specific files (MonoDevelop/Xamarin Studio)
|
380 |
+
*.userprefs
|
381 |
+
|
382 |
+
# Mono auto generated files
|
383 |
+
mono_crash.*
|
384 |
+
|
385 |
+
# Build results
|
386 |
+
[Dd]ebug/
|
387 |
+
[Dd]ebugPublic/
|
388 |
+
[Rr]elease/
|
389 |
+
[Rr]eleases/
|
390 |
+
x64/
|
391 |
+
x86/
|
392 |
+
[Ww][Ii][Nn]32/
|
393 |
+
[Aa][Rr][Mm]/
|
394 |
+
[Aa][Rr][Mm]64/
|
395 |
+
bld/
|
396 |
+
[Bb]in/
|
397 |
+
[Oo]bj/
|
398 |
+
[Ll]og/
|
399 |
+
[Ll]ogs/
|
400 |
+
|
401 |
+
# Visual Studio 2015/2017 cache/options directory
|
402 |
+
.vs/
|
403 |
+
# Uncomment if you have tasks that create the project's static files in wwwroot
|
404 |
+
#wwwroot/
|
405 |
+
|
406 |
+
# Visual Studio 2017 auto generated files
|
407 |
+
Generated\ Files/
|
408 |
+
|
409 |
+
# MSTest test Results
|
410 |
+
[Tt]est[Rr]esult*/
|
411 |
+
[Bb]uild[Ll]og.*
|
412 |
+
|
413 |
+
# NUnit
|
414 |
+
*.VisualState.xml
|
415 |
+
TestResult.xml
|
416 |
+
nunit-*.xml
|
417 |
+
|
418 |
+
# Build Results of an ATL Project
|
419 |
+
[Dd]ebugPS/
|
420 |
+
[Rr]eleasePS/
|
421 |
+
dlldata.c
|
422 |
+
|
423 |
+
# Benchmark Results
|
424 |
+
BenchmarkDotNet.Artifacts/
|
425 |
+
|
426 |
+
# .NET Core
|
427 |
+
project.lock.json
|
428 |
+
project.fragment.lock.json
|
429 |
+
artifacts/
|
430 |
+
|
431 |
+
# ASP.NET Scaffolding
|
432 |
+
ScaffoldingReadMe.txt
|
433 |
+
|
434 |
+
# StyleCop
|
435 |
+
StyleCopReport.xml
|
436 |
+
|
437 |
+
# Files built by Visual Studio
|
438 |
+
*_i.c
|
439 |
+
*_p.c
|
440 |
+
*_h.h
|
441 |
+
*.ilk
|
442 |
+
*.meta
|
443 |
+
*.obj
|
444 |
+
*.iobj
|
445 |
+
*.pch
|
446 |
+
*.pdb
|
447 |
+
*.ipdb
|
448 |
+
*.pgc
|
449 |
+
*.pgd
|
450 |
+
*.rsp
|
451 |
+
*.sbr
|
452 |
+
*.tlb
|
453 |
+
*.tli
|
454 |
+
*.tlh
|
455 |
+
*.tmp_proj
|
456 |
+
*_wpftmp.csproj
|
457 |
+
*.vspscc
|
458 |
+
*.vssscc
|
459 |
+
.builds
|
460 |
+
*.pidb
|
461 |
+
*.svclog
|
462 |
+
*.scc
|
463 |
+
|
464 |
+
# Chutzpah Test files
|
465 |
+
_Chutzpah*
|
466 |
+
|
467 |
+
# Visual C++ cache files
|
468 |
+
ipch/
|
469 |
+
*.aps
|
470 |
+
*.ncb
|
471 |
+
*.opendb
|
472 |
+
*.opensdf
|
473 |
+
*.sdf
|
474 |
+
*.cachefile
|
475 |
+
*.VC.db
|
476 |
+
*.VC.VC.opendb
|
477 |
+
|
478 |
+
# Visual Studio profiler
|
479 |
+
*.psess
|
480 |
+
*.vsp
|
481 |
+
*.vspx
|
482 |
+
*.sap
|
483 |
+
|
484 |
+
# Visual Studio Trace Files
|
485 |
+
*.e2e
|
486 |
+
|
487 |
+
# TFS 2012 Local Workspace
|
488 |
+
$tf/
|
489 |
+
|
490 |
+
# Guidance Automation Toolkit
|
491 |
+
*.gpState
|
492 |
+
|
493 |
+
# ReSharper is a .NET coding add-in
|
494 |
+
_ReSharper*/
|
495 |
+
*.[Rr]e[Ss]harper
|
496 |
+
*.DotSettings.user
|
497 |
+
|
498 |
+
# TeamCity is a build add-in
|
499 |
+
_TeamCity*
|
500 |
+
|
501 |
+
# DotCover is a Code Coverage Tool
|
502 |
+
*.dotCover
|
503 |
+
|
504 |
+
# AxoCover is a Code Coverage Tool
|
505 |
+
.axoCover/*
|
506 |
+
!.axoCover/settings.json
|
507 |
+
|
508 |
+
# Coverlet is a free, cross platform Code Coverage Tool
|
509 |
+
coverage*.json
|
510 |
+
coverage*.xml
|
511 |
+
coverage*.info
|
512 |
+
|
513 |
+
# Visual Studio code coverage results
|
514 |
+
*.coverage
|
515 |
+
*.coveragexml
|
516 |
+
|
517 |
+
# NCrunch
|
518 |
+
_NCrunch_*
|
519 |
+
.*crunch*.local.xml
|
520 |
+
nCrunchTemp_*
|
521 |
+
|
522 |
+
# MightyMoose
|
523 |
+
*.mm.*
|
524 |
+
AutoTest.Net/
|
525 |
+
|
526 |
+
# Web workbench (sass)
|
527 |
+
.sass-cache/
|
528 |
+
|
529 |
+
# Installshield output folder
|
530 |
+
[Ee]xpress/
|
531 |
+
|
532 |
+
# DocProject is a documentation generator add-in
|
533 |
+
DocProject/buildhelp/
|
534 |
+
DocProject/Help/*.HxT
|
535 |
+
DocProject/Help/*.HxC
|
536 |
+
DocProject/Help/*.hhc
|
537 |
+
DocProject/Help/*.hhk
|
538 |
+
DocProject/Help/*.hhp
|
539 |
+
DocProject/Help/Html2
|
540 |
+
DocProject/Help/html
|
541 |
+
|
542 |
+
# Click-Once directory
|
543 |
+
publish/
|
544 |
+
|
545 |
+
# Publish Web Output
|
546 |
+
*.[Pp]ublish.xml
|
547 |
+
*.azurePubxml
|
548 |
+
# Note: Comment the next line if you want to checkin your web deploy settings,
|
549 |
+
# but database connection strings (with potential passwords) will be unencrypted
|
550 |
+
*.pubxml
|
551 |
+
*.publishproj
|
552 |
+
|
553 |
+
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
554 |
+
# checkin your Azure Web App publish settings, but sensitive information contained
|
555 |
+
# in these scripts will be unencrypted
|
556 |
+
PublishScripts/
|
557 |
+
|
558 |
+
# NuGet Packages
|
559 |
+
*.nupkg
|
560 |
+
# NuGet Symbol Packages
|
561 |
+
*.snupkg
|
562 |
+
# The packages folder can be ignored because of Package Restore
|
563 |
+
**/[Pp]ackages/*
|
564 |
+
# except build/, which is used as an MSBuild target.
|
565 |
+
!**/[Pp]ackages/build/
|
566 |
+
# Uncomment if necessary however generally it will be regenerated when needed
|
567 |
+
#!**/[Pp]ackages/repositories.config
|
568 |
+
# NuGet v3's project.json files produces more ignorable files
|
569 |
+
*.nuget.props
|
570 |
+
*.nuget.targets
|
571 |
+
|
572 |
+
# Microsoft Azure Build Output
|
573 |
+
csx/
|
574 |
+
*.build.csdef
|
575 |
+
|
576 |
+
# Microsoft Azure Emulator
|
577 |
+
ecf/
|
578 |
+
rcf/
|
579 |
+
|
580 |
+
# Windows Store app package directories and files
|
581 |
+
AppPackages/
|
582 |
+
BundleArtifacts/
|
583 |
+
Package.StoreAssociation.xml
|
584 |
+
_pkginfo.txt
|
585 |
+
*.appx
|
586 |
+
*.appxbundle
|
587 |
+
*.appxupload
|
588 |
+
|
589 |
+
# Visual Studio cache files
|
590 |
+
# files ending in .cache can be ignored
|
591 |
+
*.[Cc]ache
|
592 |
+
# but keep track of directories ending in .cache
|
593 |
+
!?*.[Cc]ache/
|
594 |
+
|
595 |
+
# Others
|
596 |
+
ClientBin/
|
597 |
+
~$*
|
598 |
+
*.dbmdl
|
599 |
+
*.dbproj.schemaview
|
600 |
+
*.jfm
|
601 |
+
*.pfx
|
602 |
+
*.publishsettings
|
603 |
+
orleans.codegen.cs
|
604 |
+
|
605 |
+
# Including strong name files can present a security risk
|
606 |
+
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
607 |
+
#*.snk
|
608 |
+
|
609 |
+
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
610 |
+
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
611 |
+
#bower_components/
|
612 |
+
|
613 |
+
# RIA/Silverlight projects
|
614 |
+
Generated_Code/
|
615 |
+
|
616 |
+
# Backup & report files from converting an old project file
|
617 |
+
# to a newer Visual Studio version. Backup files are not needed,
|
618 |
+
# because we have git ;-)
|
619 |
+
_UpgradeReport_Files/
|
620 |
+
Backup*/
|
621 |
+
UpgradeLog*.XML
|
622 |
+
UpgradeLog*.htm
|
623 |
+
ServiceFabricBackup/
|
624 |
+
*.rptproj.bak
|
625 |
+
|
626 |
+
# SQL Server files
|
627 |
+
*.mdf
|
628 |
+
*.ldf
|
629 |
+
*.ndf
|
630 |
+
|
631 |
+
# Business Intelligence projects
|
632 |
+
*.rdl.data
|
633 |
+
*.bim.layout
|
634 |
+
*.bim_*.settings
|
635 |
+
*.rptproj.rsuser
|
636 |
+
*- [Bb]ackup.rdl
|
637 |
+
*- [Bb]ackup ([0-9]).rdl
|
638 |
+
*- [Bb]ackup ([0-9][0-9]).rdl
|
639 |
+
|
640 |
+
# Microsoft Fakes
|
641 |
+
FakesAssemblies/
|
642 |
+
|
643 |
+
# GhostDoc plugin setting file
|
644 |
+
*.GhostDoc.xml
|
645 |
+
|
646 |
+
# Node.js Tools for Visual Studio
|
647 |
+
.ntvs_analysis.dat
|
648 |
+
node_modules/
|
649 |
+
|
650 |
+
# Visual Studio 6 build log
|
651 |
+
*.plg
|
652 |
+
|
653 |
+
# Visual Studio 6 workspace options file
|
654 |
+
*.opt
|
655 |
+
|
656 |
+
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
657 |
+
*.vbw
|
658 |
+
|
659 |
+
# Visual Studio LightSwitch build output
|
660 |
+
**/*.HTMLClient/GeneratedArtifacts
|
661 |
+
**/*.DesktopClient/GeneratedArtifacts
|
662 |
+
**/*.DesktopClient/ModelManifest.xml
|
663 |
+
**/*.Server/GeneratedArtifacts
|
664 |
+
**/*.Server/ModelManifest.xml
|
665 |
+
_Pvt_Extensions
|
666 |
+
|
667 |
+
# Paket dependency manager
|
668 |
+
.paket/paket.exe
|
669 |
+
paket-files/
|
670 |
+
|
671 |
+
# FAKE - F# Make
|
672 |
+
.fake/
|
673 |
+
|
674 |
+
# CodeRush personal settings
|
675 |
+
.cr/personal
|
676 |
+
|
677 |
+
# Python Tools for Visual Studio (PTVS)
|
678 |
+
__pycache__/
|
679 |
+
*.pyc
|
680 |
+
|
681 |
+
# Cake - Uncomment if you are using it
|
682 |
+
# tools/**
|
683 |
+
# !tools/packages.config
|
684 |
+
|
685 |
+
# Tabs Studio
|
686 |
+
*.tss
|
687 |
+
|
688 |
+
# Telerik's JustMock configuration file
|
689 |
+
*.jmconfig
|
690 |
+
|
691 |
+
# BizTalk build output
|
692 |
+
*.btp.cs
|
693 |
+
*.btm.cs
|
694 |
+
*.odx.cs
|
695 |
+
*.xsd.cs
|
696 |
+
|
697 |
+
# OpenCover UI analysis results
|
698 |
+
OpenCover/
|
699 |
+
|
700 |
+
# Azure Stream Analytics local run output
|
701 |
+
ASALocalRun/
|
702 |
+
|
703 |
+
# MSBuild Binary and Structured Log
|
704 |
+
*.binlog
|
705 |
+
|
706 |
+
# NVidia Nsight GPU debugger configuration file
|
707 |
+
*.nvuser
|
708 |
+
|
709 |
+
# MFractors (Xamarin productivity tool) working folder
|
710 |
+
.mfractor/
|
711 |
+
|
712 |
+
# Local History for Visual Studio
|
713 |
+
.localhistory/
|
714 |
+
|
715 |
+
# BeatPulse healthcheck temp database
|
716 |
+
healthchecksdb
|
717 |
+
|
718 |
+
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
719 |
+
MigrationBackup/
|
720 |
+
|
721 |
+
# Ionide (cross platform F# VS Code tools) working folder
|
722 |
+
.ionide/
|
723 |
+
|
724 |
+
# Fody - auto-generated XML schema
|
725 |
+
FodyWeavers.xsd
|
726 |
+
|
727 |
+
### Python template
|
728 |
+
# Byte-compiled / optimized / DLL files
|
729 |
+
*.py[cod]
|
730 |
+
*$py.class
|
731 |
+
|
732 |
+
# C extensions
|
733 |
+
*.so
|
734 |
+
|
735 |
+
# Distribution / packaging
|
736 |
+
.Python
|
737 |
+
build/
|
738 |
+
develop-eggs/
|
739 |
+
downloads/
|
740 |
+
eggs/
|
741 |
+
.eggs/
|
742 |
+
lib/
|
743 |
+
lib64/
|
744 |
+
parts/
|
745 |
+
sdist/
|
746 |
+
var/
|
747 |
+
wheels/
|
748 |
+
share/python-wheels/
|
749 |
+
*.egg-info/
|
750 |
+
.installed.cfg
|
751 |
+
*.egg
|
752 |
+
MANIFEST
|
753 |
+
|
754 |
+
# PyInstaller
|
755 |
+
# Usually these files are written by a python script from a template
|
756 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
757 |
+
*.manifest
|
758 |
+
*.spec
|
759 |
+
|
760 |
+
# Installer logs
|
761 |
+
pip-log.txt
|
762 |
+
pip-delete-this-directory.txt
|
763 |
+
|
764 |
+
# Unit test / coverage reports
|
765 |
+
htmlcov/
|
766 |
+
.tox/
|
767 |
+
.nox/
|
768 |
+
.coverage
|
769 |
+
.coverage.*
|
770 |
+
.cache
|
771 |
+
nosetests.xml
|
772 |
+
coverage.xml
|
773 |
+
*.cover
|
774 |
+
*.py,cover
|
775 |
+
.hypothesis/
|
776 |
+
.pytest_cache/
|
777 |
+
cover/
|
778 |
+
|
779 |
+
# Translations
|
780 |
+
*.mo
|
781 |
+
*.pot
|
782 |
+
|
783 |
+
# Django stuff:
|
784 |
+
local_settings.py
|
785 |
+
db.sqlite3
|
786 |
+
db.sqlite3-journal
|
787 |
+
|
788 |
+
# Flask stuff:
|
789 |
+
instance/
|
790 |
+
.webassets-cache
|
791 |
+
|
792 |
+
# Scrapy stuff:
|
793 |
+
.scrapy
|
794 |
+
|
795 |
+
# Sphinx documentation
|
796 |
+
docs/_build/
|
797 |
+
|
798 |
+
# PyBuilder
|
799 |
+
.pybuilder/
|
800 |
+
target/
|
801 |
+
|
802 |
+
# Jupyter Notebook
|
803 |
+
.ipynb_checkpoints
|
804 |
+
|
805 |
+
# IPython
|
806 |
+
profile_default/
|
807 |
+
ipython_config.py
|
808 |
+
|
809 |
+
# pyenv
|
810 |
+
# For a library or package, you might want to ignore these files since the code is
|
811 |
+
# intended to run in multiple environments; otherwise, check them in:
|
812 |
+
# .python-version
|
813 |
+
|
814 |
+
# pipenv
|
815 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
816 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
817 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
818 |
+
# install all needed dependencies.
|
819 |
+
#Pipfile.lock
|
820 |
+
|
821 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
822 |
+
__pypackages__/
|
823 |
+
|
824 |
+
# Celery stuff
|
825 |
+
celerybeat-schedule
|
826 |
+
celerybeat.pid
|
827 |
+
|
828 |
+
# SageMath parsed files
|
829 |
+
*.sage.py
|
830 |
+
|
831 |
+
# Environments
|
832 |
+
.env
|
833 |
+
.venv
|
834 |
+
venv/
|
835 |
+
env.bak/
|
836 |
+
venv.bak/
|
837 |
+
|
838 |
+
# Spyder project settings
|
839 |
+
.spyderproject
|
840 |
+
.spyproject
|
841 |
+
|
842 |
+
# Rope project settings
|
843 |
+
.ropeproject
|
844 |
+
|
845 |
+
# mkdocs documentation
|
846 |
+
/site
|
847 |
+
|
848 |
+
# mypy
|
849 |
+
.mypy_cache/
|
850 |
+
.dmypy.json
|
851 |
+
dmypy.json
|
852 |
+
|
853 |
+
# Pyre type checker
|
854 |
+
.pyre/
|
855 |
+
|
856 |
+
# pytype static type analyzer
|
857 |
+
.pytype/
|
858 |
+
|
859 |
+
# Cython debug symbols
|
860 |
+
cython_debug/
|
861 |
+
|
862 |
+
### Backup template
|
863 |
+
*.gho
|
864 |
+
*.ori
|
865 |
+
*.orig
|
866 |
+
|
867 |
+
### Node template
|
868 |
+
# Logs
|
869 |
+
logs
|
870 |
+
npm-debug.log*
|
871 |
+
yarn-debug.log*
|
872 |
+
yarn-error.log*
|
873 |
+
lerna-debug.log*
|
874 |
+
|
875 |
+
# Diagnostic reports (https://nodejs.org/api/report.html)
|
876 |
+
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
|
877 |
+
|
878 |
+
# Runtime data
|
879 |
+
pids
|
880 |
+
*.pid
|
881 |
+
*.seed
|
882 |
+
*.pid.lock
|
883 |
+
|
884 |
+
# Directory for instrumented libs generated by jscoverage/JSCover
|
885 |
+
lib-cov
|
886 |
+
|
887 |
+
# Coverage directory used by tools like istanbul
|
888 |
+
coverage
|
889 |
+
*.lcov
|
890 |
+
|
891 |
+
# nyc test coverage
|
892 |
+
.nyc_output
|
893 |
+
|
894 |
+
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
895 |
+
.grunt
|
896 |
+
|
897 |
+
# Bower dependency directory (https://bower.io/)
|
898 |
+
bower_components
|
899 |
+
|
900 |
+
# node-waf configuration
|
901 |
+
.lock-wscript
|
902 |
+
|
903 |
+
# Compiled binary addons (https://nodejs.org/api/addons.html)
|
904 |
+
build/Release
|
905 |
+
|
906 |
+
# Dependency directories
|
907 |
+
jspm_packages/
|
908 |
+
|
909 |
+
# Snowpack dependency directory (https://snowpack.dev/)
|
910 |
+
web_modules/
|
911 |
+
|
912 |
+
# TypeScript cache
|
913 |
+
*.tsbuildinfo
|
914 |
+
|
915 |
+
# Optional npm cache directory
|
916 |
+
.npm
|
917 |
+
|
918 |
+
# Optional eslint cache
|
919 |
+
.eslintcache
|
920 |
+
|
921 |
+
# Microbundle cache
|
922 |
+
.rpt2_cache/
|
923 |
+
.rts2_cache_cjs/
|
924 |
+
.rts2_cache_es/
|
925 |
+
.rts2_cache_umd/
|
926 |
+
|
927 |
+
# Optional REPL history
|
928 |
+
.node_repl_history
|
929 |
+
|
930 |
+
# Output of 'npm pack'
|
931 |
+
|
932 |
+
# Yarn Integrity file
|
933 |
+
.yarn-integrity
|
934 |
+
|
935 |
+
# dotenv environment variables file
|
936 |
+
.env.test
|
937 |
+
|
938 |
+
# parcel-bundler cache (https://parceljs.org/)
|
939 |
+
.parcel-cache
|
940 |
+
|
941 |
+
# Next.js build output
|
942 |
+
.next
|
943 |
+
out
|
944 |
+
|
945 |
+
# Nuxt.js build / generate output
|
946 |
+
.nuxt
|
947 |
+
dist
|
948 |
+
|
949 |
+
# Gatsby files
|
950 |
+
.cache/
|
951 |
+
# Comment in the public line in if your project uses Gatsby and not Next.js
|
952 |
+
# https://nextjs.org/blog/next-9-1#public-directory-support
|
953 |
+
# public
|
954 |
+
|
955 |
+
# vuepress build output
|
956 |
+
.vuepress/dist
|
957 |
+
|
958 |
+
# Serverless directories
|
959 |
+
.serverless/
|
960 |
+
|
961 |
+
# FuseBox cache
|
962 |
+
.fusebox/
|
963 |
+
|
964 |
+
# DynamoDB Local files
|
965 |
+
.dynamodb/
|
966 |
+
|
967 |
+
# TernJS port file
|
968 |
+
.tern-port
|
969 |
+
|
970 |
+
# Stores VSCode versions used for testing VSCode extensions
|
971 |
+
.vscode-test
|
972 |
+
|
973 |
+
# yarn v2
|
974 |
+
.yarn/cache
|
975 |
+
.yarn/unplugged
|
976 |
+
.yarn/build-state.yml
|
977 |
+
.yarn/install-state.gz
|
978 |
+
.pnp.*
|
979 |
+
|
980 |
+
### VirtualEnv template
|
981 |
+
# Virtualenv
|
982 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
983 |
+
[Bb]in
|
984 |
+
[Ii]nclude
|
985 |
+
[Ll]ib
|
986 |
+
[Ll]ib64
|
987 |
+
[Ll]ocal
|
988 |
+
pyvenv.cfg
|
989 |
+
pip-selfcheck.json
|
990 |
+
|
991 |
+
### macOS template
|
992 |
+
# General
|
993 |
+
.DS_Store
|
994 |
+
.AppleDouble
|
995 |
+
.LSOverride
|
996 |
+
|
997 |
+
# Icon must end with two \r
|
998 |
+
Icon
|
999 |
+
|
1000 |
+
# Thumbnails
|
1001 |
+
._*
|
1002 |
+
|
1003 |
+
# Files that might appear in the root of a volume
|
1004 |
+
.DocumentRevisions-V100
|
1005 |
+
.fseventsd
|
1006 |
+
.Spotlight-V100
|
1007 |
+
.TemporaryItems
|
1008 |
+
.Trashes
|
1009 |
+
.VolumeIcon.icns
|
1010 |
+
.com.apple.timemachine.donotpresent
|
1011 |
+
|
1012 |
+
# Directories potentially created on remote AFP share
|
1013 |
+
.AppleDB
|
1014 |
+
.AppleDesktop
|
1015 |
+
Network Trash Folder
|
1016 |
+
Temporary Items
|
1017 |
+
.apdisk
|
1018 |
+
|
1019 |
+
### Go template
|
1020 |
+
# Binaries for programs and plugins
|
1021 |
+
*.exe~
|
1022 |
+
*.dll
|
1023 |
+
*.dylib
|
1024 |
+
|
1025 |
+
# Test binary, built with `go test -c`
|
1026 |
+
*.test
|
1027 |
+
|
1028 |
+
# Output of the go coverage tool, specifically when used with LiteIDE
|
1029 |
+
*.out
|
1030 |
+
|
1031 |
+
# Dependency directories (remove the comment below to include it)
|
1032 |
+
# vendor/
|
1033 |
+
|
1034 |
+
### C template
|
1035 |
+
# Prerequisites
|
1036 |
+
*.d
|
1037 |
+
|
1038 |
+
# Object files
|
1039 |
+
*.o
|
1040 |
+
*.ko
|
1041 |
+
*.elf
|
1042 |
+
|
1043 |
+
# Linker output
|
1044 |
+
*.map
|
1045 |
+
*.exp
|
1046 |
+
|
1047 |
+
# Precompiled Headers
|
1048 |
+
*.gch
|
1049 |
+
|
1050 |
+
# Libraries
|
1051 |
+
*.lib
|
1052 |
+
*.a
|
1053 |
+
*.la
|
1054 |
+
*.lo
|
1055 |
+
|
1056 |
+
# Shared objects (inc. Windows DLLs)
|
1057 |
+
*.so.*
|
1058 |
+
|
1059 |
+
# Executables
|
1060 |
+
*.app
|
1061 |
+
*.i*86
|
1062 |
+
*.x86_64
|
1063 |
+
*.hex
|
1064 |
+
|
1065 |
+
# Debug files
|
1066 |
+
*.dSYM/
|
1067 |
+
*.su
|
1068 |
+
*.idb
|
1069 |
+
|
1070 |
+
# Kernel Module Compile Results
|
1071 |
+
*.mod*
|
1072 |
+
*.cmd
|
1073 |
+
.tmp_versions/
|
1074 |
+
modules.order
|
1075 |
+
Module.symvers
|
1076 |
+
Mkfile.old
|
1077 |
+
dkms.conf
|
1078 |
+
|
1079 |
+
### Example user template template
|
1080 |
+
### Example user template
|
1081 |
+
|
1082 |
+
# IntelliJ project files
|
1083 |
+
.idea
|
1084 |
+
*.iml
|
1085 |
+
gen
|
1086 |
+
### TextMate template
|
1087 |
+
*.tmproj
|
1088 |
+
*.tmproject
|
1089 |
+
tmtags
|
1090 |
+
|
1091 |
+
### Anjuta template
|
1092 |
+
# Local configuration folder and symbol database
|
1093 |
+
/.anjuta/
|
1094 |
+
/.anjuta_sym_db.db
|
1095 |
+
|
1096 |
+
### XilinxISE template
|
1097 |
+
# intermediate build files
|
1098 |
+
*.bgn
|
1099 |
+
*.bit
|
1100 |
+
*.bld
|
1101 |
+
*.cmd_log
|
1102 |
+
*.drc
|
1103 |
+
*.ll
|
1104 |
+
*.lso
|
1105 |
+
*.msd
|
1106 |
+
*.msk
|
1107 |
+
*.ncd
|
1108 |
+
*.ngc
|
1109 |
+
*.ngd
|
1110 |
+
*.ngr
|
1111 |
+
*.pad
|
1112 |
+
*.par
|
1113 |
+
*.pcf
|
1114 |
+
*.prj
|
1115 |
+
*.ptwx
|
1116 |
+
*.rbb
|
1117 |
+
*.rbd
|
1118 |
+
*.stx
|
1119 |
+
*.syr
|
1120 |
+
*.twr
|
1121 |
+
*.twx
|
1122 |
+
*.unroutes
|
1123 |
+
*.ut
|
1124 |
+
*.xpi
|
1125 |
+
*.xst
|
1126 |
+
*_bitgen.xwbt
|
1127 |
+
*_envsettings.html
|
1128 |
+
*_map.map
|
1129 |
+
*_map.mrp
|
1130 |
+
*_map.ngm
|
1131 |
+
*_map.xrpt
|
1132 |
+
*_ngdbuild.xrpt
|
1133 |
+
*_pad.csv
|
1134 |
+
*_pad.txt
|
1135 |
+
*_par.xrpt
|
1136 |
+
*_summary.html
|
1137 |
+
*_summary.xml
|
1138 |
+
*_usage.xml
|
1139 |
+
*_xst.xrpt
|
1140 |
+
|
1141 |
+
# iMPACT generated files
|
1142 |
+
_impactbatch.log
|
1143 |
+
impact.xsl
|
1144 |
+
impact_impact.xwbt
|
1145 |
+
ise_impact.cmd
|
1146 |
+
webtalk_impact.xml
|
1147 |
+
|
1148 |
+
# Core Generator generated files
|
1149 |
+
xaw2verilog.log
|
1150 |
+
|
1151 |
+
# project-wide generated files
|
1152 |
+
*.gise
|
1153 |
+
par_usage_statistics.html
|
1154 |
+
usage_statistics_webtalk.html
|
1155 |
+
webtalk.log
|
1156 |
+
webtalk_pn.xml
|
1157 |
+
|
1158 |
+
# generated folders
|
1159 |
+
iseconfig/
|
1160 |
+
xlnx_auto_0_xdb/
|
1161 |
+
xst/
|
1162 |
+
_ngo/
|
1163 |
+
_xmsgs/
|
1164 |
+
|
1165 |
+
### TortoiseGit template
|
1166 |
+
# Project-level settings
|
1167 |
+
/.tgitconfig
|
1168 |
+
|
1169 |
+
### C++ template
|
1170 |
+
# Prerequisites
|
1171 |
+
|
1172 |
+
# Compiled Object files
|
1173 |
+
*.slo
|
1174 |
+
|
1175 |
+
# Precompiled Headers
|
1176 |
+
|
1177 |
+
# Compiled Dynamic libraries
|
1178 |
+
|
1179 |
+
# Fortran module files
|
1180 |
+
*.mod
|
1181 |
+
*.smod
|
1182 |
+
|
1183 |
+
# Compiled Static libraries
|
1184 |
+
*.lai
|
1185 |
+
|
1186 |
+
# Executables
|
1187 |
+
|
1188 |
+
### SublimeText template
|
1189 |
+
# Cache files for Sublime Text
|
1190 |
+
*.tmlanguage.cache
|
1191 |
+
*.tmPreferences.cache
|
1192 |
+
*.stTheme.cache
|
1193 |
+
|
1194 |
+
# Workspace files are user-specific
|
1195 |
+
*.sublime-workspace
|
1196 |
+
|
1197 |
+
# Project files should be checked into the repository, unless a significant
|
1198 |
+
# proportion of contributors will probably not be using Sublime Text
|
1199 |
+
# *.sublime-project
|
1200 |
+
|
1201 |
+
# SFTP configuration file
|
1202 |
+
sftp-config.json
|
1203 |
+
sftp-config-alt*.json
|
1204 |
+
|
1205 |
+
# Package control specific files
|
1206 |
+
Package Control.last-run
|
1207 |
+
Package Control.ca-list
|
1208 |
+
Package Control.ca-bundle
|
1209 |
+
Package Control.system-ca-bundle
|
1210 |
+
Package Control.cache/
|
1211 |
+
Package Control.ca-certs/
|
1212 |
+
Package Control.merged-ca-bundle
|
1213 |
+
Package Control.user-ca-bundle
|
1214 |
+
oscrypto-ca-bundle.crt
|
1215 |
+
bh_unicode_properties.cache
|
1216 |
+
|
1217 |
+
# Sublime-github package stores a github token in this file
|
1218 |
+
# https://packagecontrol.io/packages/sublime-github
|
1219 |
+
GitHub.sublime-settings
|
1220 |
+
|
1221 |
+
### Vim template
|
1222 |
+
# Swap
|
1223 |
+
[._]*.s[a-v][a-z]
|
1224 |
+
!*.svg # comment out if you don't need vector files
|
1225 |
+
[._]*.sw[a-p]
|
1226 |
+
[._]s[a-rt-v][a-z]
|
1227 |
+
[._]ss[a-gi-z]
|
1228 |
+
[._]sw[a-p]
|
1229 |
+
|
1230 |
+
# Session
|
1231 |
+
Session.vim
|
1232 |
+
Sessionx.vim
|
1233 |
+
|
1234 |
+
# Temporary
|
1235 |
+
.netrwhist
|
1236 |
+
# Auto-generated tag files
|
1237 |
+
tags
|
1238 |
+
# Persistent undo
|
1239 |
+
[._]*.un~
|
1240 |
+
|
1241 |
+
### Autotools template
|
1242 |
+
# http://www.gnu.org/software/automake
|
1243 |
+
|
1244 |
+
Makefile.in
|
1245 |
+
/ar-lib
|
1246 |
+
/mdate-sh
|
1247 |
+
/py-compile
|
1248 |
+
/test-driver
|
1249 |
+
/ylwrap
|
1250 |
+
.deps/
|
1251 |
+
.dirstamp
|
1252 |
+
|
1253 |
+
# http://www.gnu.org/software/autoconf
|
1254 |
+
|
1255 |
+
autom4te.cache
|
1256 |
+
/autoscan.log
|
1257 |
+
/autoscan-*.log
|
1258 |
+
/aclocal.m4
|
1259 |
+
/compile
|
1260 |
+
/config.guess
|
1261 |
+
/config.h.in
|
1262 |
+
/config.log
|
1263 |
+
/config.status
|
1264 |
+
/config.sub
|
1265 |
+
/configure
|
1266 |
+
/configure.scan
|
1267 |
+
/depcomp
|
1268 |
+
/install-sh
|
1269 |
+
/missing
|
1270 |
+
/stamp-h1
|
1271 |
+
|
1272 |
+
# https://www.gnu.org/software/libtool/
|
1273 |
+
|
1274 |
+
/ltmain.sh
|
1275 |
+
|
1276 |
+
# http://www.gnu.org/software/texinfo
|
1277 |
+
|
1278 |
+
/texinfo.tex
|
1279 |
+
|
1280 |
+
# http://www.gnu.org/software/m4/
|
1281 |
+
|
1282 |
+
m4/libtool.m4
|
1283 |
+
m4/ltoptions.m4
|
1284 |
+
m4/ltsugar.m4
|
1285 |
+
m4/ltversion.m4
|
1286 |
+
m4/lt~obsolete.m4
|
1287 |
+
|
1288 |
+
# Generated Makefile
|
1289 |
+
# (meta build system like autotools,
|
1290 |
+
# can automatically generate from config.status script
|
1291 |
+
# (which is called by configure script))
|
1292 |
+
|
1293 |
+
### Lua template
|
1294 |
+
# Compiled Lua sources
|
1295 |
+
luac.out
|
1296 |
+
|
1297 |
+
# luarocks build files
|
1298 |
+
*.src.rock
|
1299 |
+
*.tar.gz
|
1300 |
+
|
1301 |
+
# Object files
|
1302 |
+
*.os
|
1303 |
+
|
1304 |
+
# Precompiled Headers
|
1305 |
+
|
1306 |
+
# Libraries
|
1307 |
+
*.def
|
1308 |
+
|
1309 |
+
# Shared objects (inc. Windows DLLs)
|
1310 |
+
|
1311 |
+
# Executables
|
1312 |
+
|
1313 |
+
|
1314 |
+
### Vagrant template
|
1315 |
+
# General
|
1316 |
+
.vagrant/
|
1317 |
+
|
1318 |
+
# Log files (if you are creating logs in debug mode, uncomment this)
|
1319 |
+
# *.log
|
1320 |
+
|
1321 |
+
### Xcode template
|
1322 |
+
# Xcode
|
1323 |
+
#
|
1324 |
+
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
|
1325 |
+
|
1326 |
+
## User settings
|
1327 |
+
xcuserdata/
|
1328 |
+
|
1329 |
+
## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
|
1330 |
+
*.xcscmblueprint
|
1331 |
+
*.xccheckout
|
1332 |
+
|
1333 |
+
## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
|
1334 |
+
DerivedData/
|
1335 |
+
*.moved-aside
|
1336 |
+
*.pbxuser
|
1337 |
+
!default.pbxuser
|
1338 |
+
*.mode1v3
|
1339 |
+
!default.mode1v3
|
1340 |
+
*.mode2v3
|
1341 |
+
!default.mode2v3
|
1342 |
+
*.perspectivev3
|
1343 |
+
!default.perspectivev3
|
1344 |
+
|
1345 |
+
## Gcc Patch
|
1346 |
+
/*.gcno
|
1347 |
+
|
1348 |
+
### Linux template
|
1349 |
+
|
1350 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
1351 |
+
.fuse_hidden*
|
1352 |
+
|
1353 |
+
# KDE directory preferences
|
1354 |
+
.directory
|
1355 |
+
|
1356 |
+
# Linux trash folder which might appear on any partition or disk
|
1357 |
+
.Trash-*
|
1358 |
+
|
1359 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
1360 |
+
.nfs*
|
1361 |
+
|
1362 |
+
### GitBook template
|
1363 |
+
# Node rules:
|
1364 |
+
## Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
|
1365 |
+
|
1366 |
+
## Dependency directory
|
1367 |
+
## Commenting this out is preferred by some people, see
|
1368 |
+
## https://docs.npmjs.com/misc/faq#should-i-check-my-node_modules-folder-into-git
|
1369 |
+
node_modules
|
1370 |
+
|
1371 |
+
# Book build output
|
1372 |
+
_book
|
1373 |
+
|
1374 |
+
# eBook build output
|
1375 |
+
*.epub
|
1376 |
+
*.mobi
|
1377 |
+
|
1378 |
+
### CodeSniffer template
|
1379 |
+
# gitignore for the PHP Codesniffer framework
|
1380 |
+
# website: https://github.com/squizlabs/PHP_CodeSniffer
|
1381 |
+
#
|
1382 |
+
# Recommended template: PHP.gitignore
|
1383 |
+
|
1384 |
+
/wpcs/*
|
1385 |
+
|
1386 |
+
### PuTTY template
|
1387 |
+
# Private key
|
1388 |
+
*.ppk
|
1389 |
+
*_pb2.py
|
1390 |
+
*.pth
|
1391 |
+
*.pth.tar
|
1392 |
+
*.pt
|
1393 |
+
*.npy
|
1394 |
+
__pycache__
|
1395 |
+
*.egg-info
|
1396 |
+
experiment_config.yaml
|
1397 |
+
api-log/
|
1398 |
+
log/
|
1399 |
+
htmlcov
|
1400 |
+
*.lock
|
1401 |
+
.coverage*
|
1402 |
+
/test_*
|
1403 |
+
.python-version
|
1404 |
+
/name.txt
|
1405 |
+
/summary_log
|
1406 |
+
policy_*
|
1407 |
+
/data
|
1408 |
+
.vscode
|
1409 |
+
formatted_*
|
1410 |
+
**/exp
|
1411 |
+
**/benchmark
|
1412 |
+
**/model_zoo
|
1413 |
+
*ckpt*
|
1414 |
+
log*
|
1415 |
+
*.puml.png
|
1416 |
+
*.puml.eps
|
1417 |
+
*.puml.svg
|
1418 |
+
default*
|
1419 |
+
events.*
|
1420 |
+
|
1421 |
+
# DI-engine special key
|
1422 |
+
*default_logger.txt
|
1423 |
+
*default_tb_logger
|
1424 |
+
*evaluate.txt
|
1425 |
+
*total_config.py
|
1426 |
+
eval_config.py
|
1427 |
+
collect_demo_data_config.py
|
1428 |
+
!ding/**/*.py
|
1429 |
+
events.*
|
1430 |
+
|
1431 |
+
evogym/*
|
DI-engine/.style.yapf
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[style]
|
2 |
+
# For explanation and more information: https://github.com/google/yapf
|
3 |
+
BASED_ON_STYLE=pep8
|
4 |
+
DEDENT_CLOSING_BRACKETS=True
|
5 |
+
SPLIT_BEFORE_FIRST_ARGUMENT=True
|
6 |
+
ALLOW_SPLIT_BEFORE_DICT_VALUE=False
|
7 |
+
JOIN_MULTIPLE_LINES=False
|
8 |
+
COLUMN_LIMIT=120
|
9 |
+
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True
|
10 |
+
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2
|
11 |
+
SPACES_AROUND_POWER_OPERATOR=True
|
DI-engine/CHANGELOG
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2023.11.06(v0.5.0)
|
2 |
+
- env: add tabmwp env (#667)
|
3 |
+
- env: polish anytrading env issues (#731)
|
4 |
+
- algo: add PromptPG algorithm (#667)
|
5 |
+
- algo: add Plan Diffuser algorithm (#700)
|
6 |
+
- algo: add new pipeline implementation of IMPALA algorithm (#713)
|
7 |
+
- algo: add dropout layers to DQN-style algorithms (#712)
|
8 |
+
- feature: add new pipeline agent for sac/ddpg/a2c/ppo and Hugging Face support (#637) (#730) (#737)
|
9 |
+
- feature: add more unittest cases for model (#728)
|
10 |
+
- feature: add collector logging in new pipeline (#735)
|
11 |
+
- fix: logger middleware problems (#715)
|
12 |
+
- fix: ppo parallel bug (#709)
|
13 |
+
- fix: typo in optimizer_helper.py (#726)
|
14 |
+
- fix: mlp dropout if condition bug
|
15 |
+
- fix: drex collecting data unittest bugs
|
16 |
+
- style: polish env manager/wrapper comments and API doc (#742)
|
17 |
+
- style: polish model comments and API doc (#722) (#729) (#734) (#736) (#741)
|
18 |
+
- style: polish policy comments and API doc (#732)
|
19 |
+
- style: polish rl_utils comments and API doc (#724)
|
20 |
+
- style: polish torch_utils comments and API doc (#738)
|
21 |
+
- style: update README.md and Colab demo (#733)
|
22 |
+
- style: update metaworld docker image
|
23 |
+
|
24 |
+
2023.08.23(v0.4.9)
|
25 |
+
- env: add cliffwalking env (#677)
|
26 |
+
- env: add lunarlander ppo config and example
|
27 |
+
- algo: add BCQ offline RL algorithm (#640)
|
28 |
+
- algo: add Dreamerv3 model-based RL algorithm (#652)
|
29 |
+
- algo: add tensor stream merge network tools (#673)
|
30 |
+
- algo: add scatter connection model (#680)
|
31 |
+
- algo: refactor Decision Transformer in new pipeline and support img input and discrete output (#693)
|
32 |
+
- algo: add three variants of Bilinear classes and a FiLM class (#703)
|
33 |
+
- feature: polish offpolicy RL multi-gpu DDP training (#679)
|
34 |
+
- feature: add middleware for Ape-X distributed pipeline (#696)
|
35 |
+
- feature: add example for evaluating trained DQN (#706)
|
36 |
+
- fix: to_ndarray fails to assign dtype for scalars (#708)
|
37 |
+
- fix: evaluator return episode_info compatibility bug
|
38 |
+
- fix: cql example entry wrong config bug
|
39 |
+
- fix: enable_save_figure env interface
|
40 |
+
- fix: redundant env info bug in evaluator
|
41 |
+
- fix: to_item unittest bug
|
42 |
+
- style: polish and simplify requirements (#672)
|
43 |
+
- style: add Hugging Face Model Zoo badge (#674)
|
44 |
+
- style: add openxlab Model Zoo badge (#675)
|
45 |
+
- style: fix py37 macos ci bug and update default pytorch from 1.7.1 to 1.12.1 (#678)
|
46 |
+
- style: fix mujoco-py compatibility issue for cython<3 (#711)
|
47 |
+
- style: fix type spell error (#704)
|
48 |
+
- style: fix pypi release actions ubuntu 18.04 bug
|
49 |
+
- style: update contact information (e.g. wechat)
|
50 |
+
- style: polish algorithm doc tables
|
51 |
+
|
52 |
+
2023.05.25(v0.4.8)
|
53 |
+
- env: fix gym hybrid reward dtype bug (#664)
|
54 |
+
- env: fix atari env id noframeskip bug (#655)
|
55 |
+
- env: fix typo in gym any_trading env (#654)
|
56 |
+
- env: update td3bc d4rl config (#659)
|
57 |
+
- env: polish bipedalwalker config
|
58 |
+
- algo: add EDAC offline RL algorithm (#639)
|
59 |
+
- algo: add LN and GN norm_type support in ResBlock (#660)
|
60 |
+
- algo: add normal value norm baseline for PPOF (#658)
|
61 |
+
- algo: polish last layer init/norm in MLP (#650)
|
62 |
+
- algo: polish TD3 monitor variable
|
63 |
+
- feature: add MAPPO/MASAC task example (#661)
|
64 |
+
- feature: add PPO example for complex env observation (#644)
|
65 |
+
- feature: add barrier middleware (#570)
|
66 |
+
- fix: abnormal collector log and add record_random_collect option (#662)
|
67 |
+
- fix: to_item compatibility bug (#646)
|
68 |
+
- fix: trainer dtype transform compatibility bug
|
69 |
+
- fix: pettingzoo 1.23.0 compatibility bug
|
70 |
+
- fix: ensemble head unittest bug
|
71 |
+
- style: fix incompatible gym version bug in Dockerfile.env (#653)
|
72 |
+
- style: add more algorithm docs
|
73 |
+
|
74 |
+
2023.04.11(v0.4.7)
|
75 |
+
- env: add dmc2gym env support and baseline (#451)
|
76 |
+
- env: update pettingzoo to the latest version (#597)
|
77 |
+
- env: polish icm/rnd+onppo config bugs and add app_door_to_key env (#564)
|
78 |
+
- env: add lunarlander continuous TD3/SAC config
|
79 |
+
- env: polish lunarlander discrete C51 config
|
80 |
+
- algo: add Procedure Cloning (PC) imitation learning algorithm (#514)
|
81 |
+
- algo: add Munchausen Reinforcement Learning (MDQN) algorithm (#590)
|
82 |
+
- algo: add reward/value norm methods: popart & value rescale & symlog (#605)
|
83 |
+
- algo: polish reward model config and training pipeline (#624)
|
84 |
+
- algo: add PPOF reward space demo support (#608)
|
85 |
+
- algo: add PPOF Atari demo support (#589)
|
86 |
+
- algo: polish dqn default config and env examples (#611)
|
87 |
+
- algo: polish comment and clean code about SAC
|
88 |
+
- feature: add language model (e.g. GPT) training utils (#625)
|
89 |
+
- feature: remove policy cfg sub fields requirements (#620)
|
90 |
+
- feature: add full wandb support (#579)
|
91 |
+
- fix: confusing shallow copy operation about next_obs (#641)
|
92 |
+
- fix: unsqueeze action_args in PDQN when shape is 1 (#599)
|
93 |
+
- fix: evaluator return_info tensor type bug (#592)
|
94 |
+
- fix: deque buffer wrapper PER bug (#586)
|
95 |
+
- fix: reward model save method compatibility bug
|
96 |
+
- fix: logger assertion and unittest bug
|
97 |
+
- fix: bfs test py3.9 compatibility bug
|
98 |
+
- fix: zergling collector unittest bug
|
99 |
+
- style: add DI-engine torch-rpc p2p communication docker (#628)
|
100 |
+
- style: add D4RL docker (#591)
|
101 |
+
- style: correct typo in task (#617)
|
102 |
+
- style: correct typo in time_helper (#602)
|
103 |
+
- style: polish readme and add treetensor example
|
104 |
+
- style: update contributing doc
|
105 |
+
|
106 |
+
2023.02.16(v0.4.6)
|
107 |
+
- env: add metadrive env and related ppo config (#574)
|
108 |
+
- env: add acrobot env and related dqn config (#577)
|
109 |
+
- env: add carracing in box2d (#575)
|
110 |
+
- env: add new gym hybrid viz (#563)
|
111 |
+
- env: update cartpole IL config (#578)
|
112 |
+
- algo: add BDQ algorithm (#558)
|
113 |
+
- algo: add procedure cloning model (#573)
|
114 |
+
- feature: add simplified PPOF (PPO × Family) interface (#567) (#568) (#581) (#582)
|
115 |
+
- fix: to_device and prev_state bug when using ttorch (#571)
|
116 |
+
- fix: py38 and numpy unittest bugs (#565)
|
117 |
+
- fix: typo in contrastive_loss.py (#572)
|
118 |
+
- fix: dizoo envs pkg installation bugs
|
119 |
+
- fix: multi_trainer middleware unittest bug
|
120 |
+
- style: add evogym docker (#580)
|
121 |
+
- style: fix metaworld docker bug
|
122 |
+
- style: fix setuptools high version incompatibility bug
|
123 |
+
- style: extend treetensor lowest version
|
124 |
+
|
125 |
+
2022.12.13(v0.4.5)
|
126 |
+
- env: add beergame supply chain optimization env (#512)
|
127 |
+
- env: add env gym_pybullet_drones (#526)
|
128 |
+
- env: rename eval reward to episode return (#536)
|
129 |
+
- algo: add policy gradient algo implementation (#544)
|
130 |
+
- algo: add MADDPG algo implementation (#550)
|
131 |
+
- algo: add IMPALA continuous algo implementation (#551)
|
132 |
+
- algo: add MADQN algo implementation (#540)
|
133 |
+
- feature: add new task IMPALA-type distributed training scheme (#321)
|
134 |
+
- feature: add load and save method for replaybuffer (#542)
|
135 |
+
- feature: add more DingEnvWrapper example (#525)
|
136 |
+
- feature: add evaluator more info viz support (#538)
|
137 |
+
- feature: add trackback log for subprocess env manager (#534)
|
138 |
+
- fix: halfcheetah td3 config file (#537)
|
139 |
+
- fix: mujoco action_clip args compatibility bug (#535)
|
140 |
+
- fix: atari a2c config entry bug
|
141 |
+
- fix: drex unittest compatibility bug
|
142 |
+
- style: add Roadmap issue of DI-engine (#548)
|
143 |
+
- style: update related project link and new env doc
|
144 |
+
|
145 |
+
2022.10.31(v0.4.4)
|
146 |
+
- env: add modified gym-hybrid including moving, sliding and hardmove (#505) (#519)
|
147 |
+
- env: add evogym support (#495) (#527)
|
148 |
+
- env: add save_replay_gif option (#506)
|
149 |
+
- env: adapt minigrid_env and related config to latest MiniGrid v2.0.0 (#500)
|
150 |
+
- algo: add pcgrad optimizer (#489)
|
151 |
+
- algo: add some features in MLP and ResBlock (#511)
|
152 |
+
- algo: delete mcts related modules (#518)
|
153 |
+
- feature: add wandb middleware and demo (#488) (#523) (#528)
|
154 |
+
- feature: add new properties in Context (#499)
|
155 |
+
- feature: add single env policy wrapper for policy deployment
|
156 |
+
- feature: add custom model demo and doc
|
157 |
+
- fix: build logger args and unittests (#522)
|
158 |
+
- fix: total_loss calculation in PDQN (#504)
|
159 |
+
- fix: save gif function bug
|
160 |
+
- fix: level sample unittest bug
|
161 |
+
- style: update contact email address (#503)
|
162 |
+
- style: polish env log and resblock name
|
163 |
+
- style: add details button in readme
|
164 |
+
|
165 |
+
2022.09.23(v0.4.3)
|
166 |
+
- env: add rule-based gomoku expert (#465)
|
167 |
+
- algo: fix a2c policy batch size bug (#481)
|
168 |
+
- algo: enable activation option in collaq attention and mixer
|
169 |
+
- algo: minor fix about IBC (#477)
|
170 |
+
- feature: add IGM support (#486)
|
171 |
+
- feature: add tb logger middleware and demo
|
172 |
+
- fix: the type conversion in ding_env_wrapper (#483)
|
173 |
+
- fix: di-orchestrator version bug in unittest (#479)
|
174 |
+
- fix: data collection errors caused by shallow copies (#475)
|
175 |
+
- fix: gym==0.26.0 seed args bug
|
176 |
+
- style: add readme tutorial link(environment & algorithm) (#490) (#493)
|
177 |
+
- style: adjust location of the default_model method in policy (#453)
|
178 |
+
|
179 |
+
2022.09.08(v0.4.2)
|
180 |
+
- env: add rocket env (#449)
|
181 |
+
- env: updated pettingzoo env and improved related performance (#457)
|
182 |
+
- env: add mario env demo (#443)
|
183 |
+
- env: add MAPPO multi-agent config (#464)
|
184 |
+
- env: add mountain car (discrete action) environment (#452)
|
185 |
+
- env: fix multi-agent mujoco gym comaptibility bug
|
186 |
+
- env: fix gfootball env save_replay variable init bug
|
187 |
+
- algo: add IBC (Implicit Behaviour Cloning) algorithm (#401)
|
188 |
+
- algo: add BCO (Behaviour Cloning from Observation) algorithm (#270)
|
189 |
+
- algo: add continuous PPOPG algorithm (#414)
|
190 |
+
- algo: add PER in CollaQ (#472)
|
191 |
+
- algo: add activation option in QMIX and CollaQ
|
192 |
+
- feature: update ctx to dataclass (#467)
|
193 |
+
- fix: base_env FinalMeta bug about gym 0.25.0-0.25.1
|
194 |
+
- fix: config inplace modification bug
|
195 |
+
- fix: ding cli no argument problem
|
196 |
+
- fix: import errors after running setup.py (jinja2, markupsafe)
|
197 |
+
- fix: conda py3.6 and cross platform build bug
|
198 |
+
- style: add project state and datetime in log dir (#455)
|
199 |
+
- style: polish notes for q-learning model (#427)
|
200 |
+
- style: revision to mujoco dockerfile and validation (#474)
|
201 |
+
- style: add dockerfile for cityflow env
|
202 |
+
- style: polish default output log format
|
203 |
+
|
204 |
+
2022.08.12(v0.4.1)
|
205 |
+
- env: add gym trading env (#424)
|
206 |
+
- env: add board games env (tictactoe, gomuku, chess) (#356)
|
207 |
+
- env: add sokoban env (#397) (#429)
|
208 |
+
- env: add BC and DQN demo for gfootball (#418) (#423)
|
209 |
+
- env: add discrete pendulum env (#395)
|
210 |
+
- algo: add STEVE model-based algorithm (#363)
|
211 |
+
- algo: add PLR algorithm (#408)
|
212 |
+
- algo: plugin ST-DIM in PPO (#379)
|
213 |
+
- feature: add final result saving in training pipeline
|
214 |
+
- fix: random policy randomness bug
|
215 |
+
- fix: action_space seed compalbility bug
|
216 |
+
- fix: discard message sent by self in redis mq (#354)
|
217 |
+
- fix: remove pace controller (#400)
|
218 |
+
- fix: import error in serial_pipeline_trex (#410)
|
219 |
+
- fix: unittest hang and fail bug (#413)
|
220 |
+
- fix: DREX collect data unittest bug
|
221 |
+
- fix: remove unused import cv2
|
222 |
+
- fix: ding CLI env/policy option bug
|
223 |
+
- style: upgrade Python version from 3.6-3.8 to 3.7-3.9
|
224 |
+
- style: upgrade gym version from 0.20.0 to 0.25.0
|
225 |
+
- style: upgrade torch version from 1.10.0 to 1.12.0
|
226 |
+
- style: upgrade mujoco bin from 2.0.0 to 2.1.0
|
227 |
+
- style: add buffer api description (#371)
|
228 |
+
- style: polish VAE comments (#404)
|
229 |
+
- style: unittest for FQF (#412)
|
230 |
+
- style: add metaworld dockerfile (#432)
|
231 |
+
- style: remove opencv requirement in default setting
|
232 |
+
- style: update long description in setup.py
|
233 |
+
|
234 |
+
2022.06.21(v0.4.0)
|
235 |
+
- env: add MAPPO/MASAC all configs in SMAC (#310) **(SOTA results in SMAC!!!)**
|
236 |
+
- env: add dmc2gym env (#344) (#360)
|
237 |
+
- env: remove DI-star requirements of dizoo/smac, use official pysc2 (#302)
|
238 |
+
- env: add latest GAIL mujoco config (#298)
|
239 |
+
- env: polish procgen env (#311)
|
240 |
+
- env: add MBPO ant and humanoid config for mbpo (#314)
|
241 |
+
- env: fix slime volley env obs space bug when agent_vs_agent
|
242 |
+
- env: fix smac env obs space bug
|
243 |
+
- env: fix import path error in lunarlander (#362)
|
244 |
+
- algo: add Decision Transformer algorithm (#327) (#364)
|
245 |
+
- algo: add on-policy PPG algorithm (#312)
|
246 |
+
- algo: add DDPPO & add model-based SAC with lambda-return algorithm (#332)
|
247 |
+
- algo: add infoNCE loss and ST-DIM algorithm (#326)
|
248 |
+
- algo: add FQF distributional RL algorithm (#274)
|
249 |
+
- algo: add continuous BC algorithm (#318)
|
250 |
+
- algo: add pure policy gradient PPO algorithm (#382)
|
251 |
+
- algo: add SQIL + SAC algorithm (#348)
|
252 |
+
- algo: polish NGU and related modules (#283) (#343) (#353)
|
253 |
+
- algo: add marl distributional td loss (#331)
|
254 |
+
- feature: add new worker middleware (#236)
|
255 |
+
- feature: refactor model-based RL pipeline (ding/world_model) (#332)
|
256 |
+
- feature: refactor logging system in the whole DI-engine (#316)
|
257 |
+
- feature: add env supervisor design (#330)
|
258 |
+
- feature: support async reset for envpool env manager (#250)
|
259 |
+
- feature: add log videos to tensorboard (#320)
|
260 |
+
- feature: refactor impala cnn encoder interface (#378)
|
261 |
+
- fix: env save replay bug
|
262 |
+
- fix: transformer mask inplace operation bug
|
263 |
+
- fix: transtion_with_policy_data bug in SAC and PPG
|
264 |
+
- style: add dockerfile for ding:hpc image (#337)
|
265 |
+
- style: fix mpire 2.3.5 which handles default processes more elegantly (#306)
|
266 |
+
- style: use FORMAT_DIR instead of ./ding (#309)
|
267 |
+
- style: update quickstart colab link (#347)
|
268 |
+
- style: polish comments in ding/model/common (#315)
|
269 |
+
- style: update mujoco docker download path (#386)
|
270 |
+
- style: fix protobuf new version compatibility bug
|
271 |
+
- style: fix torch1.8.0 torch.div compatibility bug
|
272 |
+
- style: update doc links in readme
|
273 |
+
- style: add outline in readme and update wechat image
|
274 |
+
- style: update head image and refactor docker dir
|
275 |
+
|
276 |
+
2022.04.23(v0.3.1)
|
277 |
+
- env: polish and standardize dizoo config (#252) (#255) (#249) (#246) (#262) (#261) (#266) (#273) (#263) (#280) (#259) (#286) (#277) (#290) (#289) (#299)
|
278 |
+
- env: add GRF academic env and config (#281)
|
279 |
+
- env: update env inferface of GRF (#258)
|
280 |
+
- env: update D4RL offline RL env and config (#285)
|
281 |
+
- env: polish PomdpAtariEnv (#254)
|
282 |
+
- algo: DREX algorithm (#218)
|
283 |
+
- feature: separate mq and parallel modules, add redis (#247)
|
284 |
+
- feature: rename env variables; fix attach_to parameter (#244)
|
285 |
+
- feature: env implementation check (#275)
|
286 |
+
- feature: adjust and set the max column number of tabulate in log (#296)
|
287 |
+
- feature: add drop_extra option for sample collect
|
288 |
+
- feature: speed up GTrXL forward method + GRU unittest (#253) (#292)
|
289 |
+
- fix: add act_scale in DingEnvWrapper; fix envpool env manager (#245)
|
290 |
+
- fix: auto_reset=False and env_ref bug in env manager (#248)
|
291 |
+
- fix: data type and deepcopy bug in RND (#288)
|
292 |
+
- fix: share_memory bug and multi_mujoco env (#279)
|
293 |
+
- fix: some bugs in GTrXL (#276)
|
294 |
+
- fix: update gym_vector_env_manager and add more unittest (#241)
|
295 |
+
- fix: mdpolicy random collect bug (#293)
|
296 |
+
- fix: gym.wrapper save video replay bug
|
297 |
+
- fix: collect abnormal step format bug and add unittest
|
298 |
+
- test: add buffer benchmark & socket test (#284)
|
299 |
+
- style: upgrade mpire (#251)
|
300 |
+
- style: add GRF(google research football) docker (#256)
|
301 |
+
- style: update policy and gail comment
|
302 |
+
|
303 |
+
2022.03.24(v0.3.0)
|
304 |
+
- env: add bitfilp HER DQN benchmark (#192) (#193) (#197)
|
305 |
+
- env: slime volley league training demo (#229)
|
306 |
+
- algo: Gated TransformXL (GTrXL) algorithm (#136)
|
307 |
+
- algo: TD3 + VAE(HyAR) latent action algorithm (#152)
|
308 |
+
- algo: stochastic dueling network (#234)
|
309 |
+
- algo: use log prob instead of using prob in ACER (#186)
|
310 |
+
- feature: support envpool env manager (#228)
|
311 |
+
- feature: add league main and other improvements in new framework (#177) (#214)
|
312 |
+
- feature: add pace controller middleware in new framework (#198)
|
313 |
+
- feature: add auto recover option in new framework (#242)
|
314 |
+
- feature: add k8s parser in new framework (#243)
|
315 |
+
- feature: support async event handler and logger (#213)
|
316 |
+
- feautre: add grad norm calculator (#205)
|
317 |
+
- feautre: add gym vector env manager (#147)
|
318 |
+
- feautre: add train_iter and env_step in serial pipeline (#212)
|
319 |
+
- feautre: add rich logger handler (#219) (#223) (#232)
|
320 |
+
- feature: add naive lr_scheduler demo
|
321 |
+
- refactor: new BaseEnv and DingEnvWrapper (#171) (#231) (#240)
|
322 |
+
- polish: MAPPO and MASAC smac config (#209) (#239)
|
323 |
+
- polish: QMIX smac config (#175)
|
324 |
+
- polish: R2D2 atari config (#181)
|
325 |
+
- polish: A2C atari config (#189)
|
326 |
+
- polish: GAIL box2d and mujoco config (#188)
|
327 |
+
- polish: ACER atari config (#180)
|
328 |
+
- polish: SQIL atari config (#230)
|
329 |
+
- polish: TREX atari/mujoco config
|
330 |
+
- polish: IMPALA atari config
|
331 |
+
- polish: MBPO/D4PG mujoco config
|
332 |
+
- fix: random_collect compatible to episode collector (#190)
|
333 |
+
- fix: remove default n_sample/n_episode value in policy config (#185)
|
334 |
+
- fix: PDQN model bug on gpu device (#220)
|
335 |
+
- fix: TREX algorithm CLI bug (#182)
|
336 |
+
- fix: DQfD JE computation bug and move to AdamW optimizer (#191)
|
337 |
+
- fix: pytest problem for parallel middleware (#211)
|
338 |
+
- fix: mujoco numpy compatibility bug
|
339 |
+
- fix: markupsafe 2.1.0 bug
|
340 |
+
- fix: framework parallel module network emit bug
|
341 |
+
- fix: mpire bug and disable algotest in py3.8
|
342 |
+
- fix: lunarlander env import and env_id bug
|
343 |
+
- fix: icm unittest repeat name bug
|
344 |
+
- fix: buffer thruput close bug
|
345 |
+
- test: resnet unittest (#199)
|
346 |
+
- test: SAC/SQN unittest (#207)
|
347 |
+
- test: CQL/R2D3/GAIL unittest (#201)
|
348 |
+
- test: NGU td unittest (#210)
|
349 |
+
- test: model wrapper unittest (#215)
|
350 |
+
- test: MAQAC model unittest (#226)
|
351 |
+
- style: add doc docker (#221)
|
352 |
+
|
353 |
+
2022.01.01(v0.2.3)
|
354 |
+
- env: add multi-agent mujoco env (#146)
|
355 |
+
- env: add delay reward mujoco env (#145)
|
356 |
+
- env: fix port conflict in gym_soccer (#139)
|
357 |
+
- algo: MASAC algorithm (#112)
|
358 |
+
- algo: TREX algorithm (#119) (#144)
|
359 |
+
- algo: H-PPO hybrid action space algorithm (#140)
|
360 |
+
- algo: residual link in R2D2 (#150)
|
361 |
+
- algo: gumbel softmax (#169)
|
362 |
+
- algo: move actor_head_type to action_space field
|
363 |
+
- feature: new main pipeline and async/parallel framework (#142) (#166) (#168)
|
364 |
+
- feature: refactor buffer, separate algorithm and storage (#129)
|
365 |
+
- feature: cli in new pipeline(ditask) (#160)
|
366 |
+
- feature: add multiprocess tblogger, fix circular reference problem (#156)
|
367 |
+
- feature: add multiple seed cli
|
368 |
+
- feature: polish eps_greedy_multinomial_sample in model_wrapper (#154)
|
369 |
+
- fix: R2D3 abs priority problem (#158) (#161)
|
370 |
+
- fix: multi-discrete action space policies random action bug (#167)
|
371 |
+
- fix: doc generate bug with enum_tools (#155)
|
372 |
+
- style: more comments about R2D2 (#149)
|
373 |
+
- style: add doc about how to migrate a new env
|
374 |
+
- style: add doc about env tutorial in dizoo
|
375 |
+
- style: add conda auto release (#148)
|
376 |
+
- style: udpate zh doc link
|
377 |
+
- style: update kaggle tutorial link
|
378 |
+
|
379 |
+
2021.12.03(v0.2.2)
|
380 |
+
- env: apple key to door treasure env (#128)
|
381 |
+
- env: add bsuite memory benchmark (#138)
|
382 |
+
- env: polish atari impala config
|
383 |
+
- algo: Guided Cost IRL algorithm (#57)
|
384 |
+
- algo: ICM exploration algorithm (#41)
|
385 |
+
- algo: MP-DQN hybrid action space algorithm (#131)
|
386 |
+
- algo: add loss statistics and polish r2d3 pong config (#126)
|
387 |
+
- feautre: add renew env mechanism in env manager and update timeout mechanism (#127) (#134)
|
388 |
+
- fix: async subprocess env manager reset bug (#137)
|
389 |
+
- fix: keepdims name bug in model wrapper
|
390 |
+
- fix: on-policy ppo value norm bug
|
391 |
+
- fix: GAE and RND unittest bug
|
392 |
+
- fix: hidden state wrapper h tensor compatiblity
|
393 |
+
- fix: naive buffer auto config create bug
|
394 |
+
- style: add supporters list
|
395 |
+
|
396 |
+
2021.11.22(v0.2.1)
|
397 |
+
- env: gym-hybrid env (#86)
|
398 |
+
- env: gym-soccer (HFO) env (#94)
|
399 |
+
- env: Go-Bigger env baseline (#95)
|
400 |
+
- env: add the bipedalwalker config of sac and ppo (#121)
|
401 |
+
- algo: DQfD Imitation Learning algorithm (#48) (#98)
|
402 |
+
- algo: TD3BC offline RL algorithm (#88)
|
403 |
+
- algo: MBPO model-based RL algorithm (#113)
|
404 |
+
- algo: PADDPG hybrid action space algorithm (#109)
|
405 |
+
- algo: PDQN hybrid action space algorithm (#118)
|
406 |
+
- algo: fix R2D2 bugs and produce benchmark, add naive NGU (#40)
|
407 |
+
- algo: self-play training demo in slime_volley env (#23)
|
408 |
+
- algo: add example of GAIL entry + config for mujoco (#114)
|
409 |
+
- feature: enable arbitrary policy num in serial sample collector
|
410 |
+
- feautre: add torch DataParallel for single machine multi-GPU
|
411 |
+
- feature: add registry force_overwrite argument
|
412 |
+
- feature: add naive buffer periodic thruput seconds argument
|
413 |
+
- test: add pure docker setting test (#103)
|
414 |
+
- test: add unittest for dataset and evaluator (#107)
|
415 |
+
- test: add unittest for on-policy algorithm (#92)
|
416 |
+
- test: add unittest for ppo and td (MARL case) (#89)
|
417 |
+
- test: polish collector benchmark test
|
418 |
+
- fix: target model wrapper hard reset bug
|
419 |
+
- fix: fix learn state_dict target model bug
|
420 |
+
- fix: ppo bugs and update atari ppo offpolicy config (#108)
|
421 |
+
- fix: pyyaml version bug (#99)
|
422 |
+
- fix: small fix on bsuite environment (#117)
|
423 |
+
- fix: discrete cql unittest bug
|
424 |
+
- fix: release workflow bug
|
425 |
+
- fix: base policy model state_dict overlap bug
|
426 |
+
- fix: remove on_policy option in dizoo config and entry
|
427 |
+
- fix: remove torch in env
|
428 |
+
- style: gym version > 0.20.0
|
429 |
+
- style: torch version >= 1.1.0, <= 1.10.0
|
430 |
+
- style: ale-py == 0.7.0
|
431 |
+
|
432 |
+
2021.9.30(v0.2.0)
|
433 |
+
- env: overcooked env (#20)
|
434 |
+
- env: procgen env (#26)
|
435 |
+
- env: modified predator env (#30)
|
436 |
+
- env: d4rl env (#37)
|
437 |
+
- env: imagenet dataset (#27)
|
438 |
+
- env: bsuite env (#58)
|
439 |
+
- env: move atari_py to ale-py
|
440 |
+
- algo: SQIL algorithm (#25) (#44)
|
441 |
+
- algo: CQL algorithm (discrete/continuous) (#37) (#68)
|
442 |
+
- algo: MAPPO algorithm (#62)
|
443 |
+
- algo: WQMIX algorithm (#24)
|
444 |
+
- algo: D4PG algorithm (#76)
|
445 |
+
- algo: update multi discrete policy(dqn, ppo, rainbow) (#51) (#72)
|
446 |
+
- feature: image classification training pipeline (#27)
|
447 |
+
- feature: add force_reproducibility option in subprocess env manager
|
448 |
+
- feature: add/delete/restart replicas via cli for k8s
|
449 |
+
- feautre: add league metric (trueskill and elo) (#22)
|
450 |
+
- feature: add tb in naive buffer and modify tb in advanced buffer (#39)
|
451 |
+
- feature: add k8s launcher and di-orchestrator launcher, add related unittest (#45) (#49)
|
452 |
+
- feature: add hyper-parameter scheduler module (#38)
|
453 |
+
- feautre: add plot function (#59)
|
454 |
+
- fix: acer bug and update atari result (#21)
|
455 |
+
- fix: mappo nan bug and dict obs cannot unsqueeze bug (#54)
|
456 |
+
- fix: r2d2 hidden state and obs arange bug (#36) (#52)
|
457 |
+
- fix: ppo bug when use dual_clip and adv > 0
|
458 |
+
- fix: qmix double_q hidden state bug
|
459 |
+
- fix: spawn context problem in interaction unittest (#69)
|
460 |
+
- fix: formatted config no eval bug (#53)
|
461 |
+
- fix: the catch statments that will never succeed and system proxy bug (#71) (#79)
|
462 |
+
- fix: lunarlander config
|
463 |
+
- fix: c51 head dimension mismatch bug
|
464 |
+
- fix: mujoco config typo bug
|
465 |
+
- fix: ppg atari config bug
|
466 |
+
- fix: max use and priority update special branch bug in advanced_buffer
|
467 |
+
- style: add docker deploy in github workflow (#70) (#78) (#80)
|
468 |
+
- style: support PyTorch 1.9.0
|
469 |
+
- style: add algo/env list in README
|
470 |
+
- style: rename advanced_buffer register name to advanced
|
471 |
+
|
472 |
+
|
473 |
+
2021.8.3(v0.1.1)
|
474 |
+
- env: selfplay/league demo (#12)
|
475 |
+
- env: pybullet env (#16)
|
476 |
+
- env: minigrid env (#13)
|
477 |
+
- env: atari enduro config (#11)
|
478 |
+
- algo: on policy PPO (#9)
|
479 |
+
- algo: ACER algorithm (#14)
|
480 |
+
- feature: polish experiment directory structure (#10)
|
481 |
+
- refactor: split doc to new repo (#4)
|
482 |
+
- fix: atari env info action space bug
|
483 |
+
- fix: env manager retry wrapper raise exception info bug
|
484 |
+
- fix: dist entry disable-flask-log typo
|
485 |
+
- style: codestyle optimization by lgtm (#7)
|
486 |
+
- style: code/comment statistics badge
|
487 |
+
- style: github CI workflow
|
488 |
+
|
489 |
+
2021.7.8(v0.1.0)
|
DI-engine/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributor Covenant Code of Conduct
|
2 |
+
|
3 |
+
## Our Pledge
|
4 |
+
|
5 |
+
We as members, contributors, and leaders pledge to make participation in our
|
6 |
+
community a harassment-free experience for everyone, regardless of age, body
|
7 |
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
8 |
+
identity and expression, level of experience, education, socio-economic status,
|
9 |
+
nationality, personal appearance, race, religion, or sexual identity
|
10 |
+
and orientation.
|
11 |
+
|
12 |
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
13 |
+
diverse, inclusive, and healthy community.
|
14 |
+
|
15 |
+
## Our Standards
|
16 |
+
|
17 |
+
Examples of behavior that contributes to a positive environment for our
|
18 |
+
community include:
|
19 |
+
|
20 |
+
* Demonstrating empathy and kindness toward other people
|
21 |
+
* Being respectful of differing opinions, viewpoints, and experiences
|
22 |
+
* Giving and gracefully accepting constructive feedback
|
23 |
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
24 |
+
and learning from the experience
|
25 |
+
* Focusing on what is best not just for us as individuals, but for the
|
26 |
+
overall community
|
27 |
+
|
28 |
+
Examples of unacceptable behavior include:
|
29 |
+
|
30 |
+
* The use of sexualized language or imagery, and sexual attention or
|
31 |
+
advances of any kind
|
32 |
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
33 |
+
* Public or private harassment
|
34 |
+
* Publishing others' private information, such as a physical or email
|
35 |
+
address, without their explicit permission
|
36 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
37 |
+
professional setting
|
38 |
+
|
39 |
+
## Enforcement Responsibilities
|
40 |
+
|
41 |
+
Community leaders are responsible for clarifying and enforcing our standards of
|
42 |
+
acceptable behavior and will take appropriate and fair corrective action in
|
43 |
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
44 |
+
or harmful.
|
45 |
+
|
46 |
+
Community leaders have the right and responsibility to remove, edit, or reject
|
47 |
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
48 |
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
49 |
+
decisions when appropriate.
|
50 |
+
|
51 |
+
## Scope
|
52 |
+
|
53 |
+
This Code of Conduct applies within all community spaces, and also applies when
|
54 |
+
an individual is officially representing the community in public spaces.
|
55 |
+
Examples of representing our community include using an official e-mail address,
|
56 |
+
posting via an official social media account, or acting as an appointed
|
57 |
+
representative at an online or offline event.
|
58 |
+
|
59 |
+
## Enforcement
|
60 |
+
|
61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
62 |
+
reported to the community leaders responsible for enforcement at
|
63 |
+
opendilab.contact@gmail.com.
|
64 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
65 |
+
|
66 |
+
All community leaders are obligated to respect the privacy and security of the
|
67 |
+
reporter of any incident.
|
68 |
+
|
69 |
+
## Enforcement Guidelines
|
70 |
+
|
71 |
+
Community leaders will follow these Community Impact Guidelines in determining
|
72 |
+
the consequences for any action they deem in violation of this Code of Conduct:
|
73 |
+
|
74 |
+
### 1. Correction
|
75 |
+
|
76 |
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
77 |
+
unprofessional or unwelcome in the community.
|
78 |
+
|
79 |
+
**Consequence**: A private, written warning from community leaders, providing
|
80 |
+
clarity around the nature of the violation and an explanation of why the
|
81 |
+
behavior was inappropriate. A public apology may be requested.
|
82 |
+
|
83 |
+
### 2. Warning
|
84 |
+
|
85 |
+
**Community Impact**: A violation through a single incident or series
|
86 |
+
of actions.
|
87 |
+
|
88 |
+
**Consequence**: A warning with consequences for continued behavior. No
|
89 |
+
interaction with the people involved, including unsolicited interaction with
|
90 |
+
those enforcing the Code of Conduct, for a specified period of time. This
|
91 |
+
includes avoiding interactions in community spaces as well as external channels
|
92 |
+
like social media. Violating these terms may lead to a temporary or
|
93 |
+
permanent ban.
|
94 |
+
|
95 |
+
### 3. Temporary Ban
|
96 |
+
|
97 |
+
**Community Impact**: A serious violation of community standards, including
|
98 |
+
sustained inappropriate behavior.
|
99 |
+
|
100 |
+
**Consequence**: A temporary ban from any sort of interaction or public
|
101 |
+
communication with the community for a specified period of time. No public or
|
102 |
+
private interaction with the people involved, including unsolicited interaction
|
103 |
+
with those enforcing the Code of Conduct, is allowed during this period.
|
104 |
+
Violating these terms may lead to a permanent ban.
|
105 |
+
|
106 |
+
### 4. Permanent Ban
|
107 |
+
|
108 |
+
**Community Impact**: Demonstrating a pattern of violation of community
|
109 |
+
standards, including sustained inappropriate behavior, harassment of an
|
110 |
+
individual, or aggression toward or disparagement of classes of individuals.
|
111 |
+
|
112 |
+
**Consequence**: A permanent ban from any sort of public interaction within
|
113 |
+
the community.
|
114 |
+
|
115 |
+
## Attribution
|
116 |
+
|
117 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
118 |
+
version 2.0, available at
|
119 |
+
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
120 |
+
|
121 |
+
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
122 |
+
enforcement ladder](https://github.com/mozilla/diversity).
|
123 |
+
|
124 |
+
[homepage]: https://www.contributor-covenant.org
|
125 |
+
|
126 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
127 |
+
https://www.contributor-covenant.org/faq. Translations are available at
|
128 |
+
https://www.contributor-covenant.org/translations.
|
DI-engine/CONTRIBUTING.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
|
2 |
+
|
3 |
+
[GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html)
|
4 |
+
|
5 |
+
- [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
|
6 |
+
- [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
|
7 |
+
- [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review)
|
DI-engine/LICENSE
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
APPENDIX: How to apply the Apache License to your work.
|
180 |
+
|
181 |
+
To apply the Apache License to your work, attach the following
|
182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
183 |
+
replaced with your own identifying information. (Don't include
|
184 |
+
the brackets!) The text should be enclosed in the appropriate
|
185 |
+
comment syntax for the file format. We also recommend that a
|
186 |
+
file or class name and description of purpose be included on the
|
187 |
+
same "printed page" as the copyright notice for easier
|
188 |
+
identification within third-party archives.
|
189 |
+
|
190 |
+
Copyright 2017 Google Inc.
|
191 |
+
|
192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
193 |
+
you may not use this file except in compliance with the License.
|
194 |
+
You may obtain a copy of the License at
|
195 |
+
|
196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
197 |
+
|
198 |
+
Unless required by applicable law or agreed to in writing, software
|
199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
201 |
+
See the License for the specific language governing permissions and
|
202 |
+
limitations under the License.
|
DI-engine/Makefile
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CI ?=
|
2 |
+
|
3 |
+
# Directory variables
|
4 |
+
DING_DIR ?= ./ding
|
5 |
+
DIZOO_DIR ?= ./dizoo
|
6 |
+
RANGE_DIR ?=
|
7 |
+
TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
|
8 |
+
COV_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
|
9 |
+
FORMAT_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
|
10 |
+
PLATFORM_TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR}/entry/tests/test_serial_entry.py ${DING_DIR}/entry/tests/test_serial_entry_onpolicy.py)
|
11 |
+
|
12 |
+
# Workers command
|
13 |
+
WORKERS ?= 2
|
14 |
+
WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
|
15 |
+
|
16 |
+
# Duration command
|
17 |
+
DURATIONS ?= 10
|
18 |
+
DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},)
|
19 |
+
|
20 |
+
docs:
|
21 |
+
$(MAKE) -C ${DING_DIR}/docs html
|
22 |
+
|
23 |
+
unittest:
|
24 |
+
pytest ${TEST_DIR} \
|
25 |
+
--cov-report=xml \
|
26 |
+
--cov-report term-missing \
|
27 |
+
--cov=${COV_DIR} \
|
28 |
+
${DURATIONS_COMMAND} \
|
29 |
+
${WORKERS_COMMAND} \
|
30 |
+
-sv -m unittest \
|
31 |
+
|
32 |
+
algotest:
|
33 |
+
pytest ${TEST_DIR} \
|
34 |
+
${DURATIONS_COMMAND} \
|
35 |
+
-sv -m algotest
|
36 |
+
|
37 |
+
cudatest:
|
38 |
+
pytest ${TEST_DIR} \
|
39 |
+
-sv -m cudatest
|
40 |
+
|
41 |
+
envpooltest:
|
42 |
+
pytest ${TEST_DIR} \
|
43 |
+
-sv -m envpooltest
|
44 |
+
|
45 |
+
dockertest:
|
46 |
+
${DING_DIR}/scripts/docker-test-entry.sh
|
47 |
+
|
48 |
+
platformtest:
|
49 |
+
pytest ${TEST_DIR} \
|
50 |
+
--cov-report term-missing \
|
51 |
+
--cov=${COV_DIR} \
|
52 |
+
${WORKERS_COMMAND} \
|
53 |
+
-sv -m platformtest
|
54 |
+
|
55 |
+
benchmark:
|
56 |
+
pytest ${TEST_DIR} \
|
57 |
+
--durations=0 \
|
58 |
+
-sv -m benchmark
|
59 |
+
|
60 |
+
test: unittest # just for compatibility, can be changed later
|
61 |
+
|
62 |
+
cpu_test: unittest algotest benchmark
|
63 |
+
|
64 |
+
all_test: unittest algotest cudatest benchmark
|
65 |
+
|
66 |
+
format:
|
67 |
+
yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR}
|
68 |
+
format_test:
|
69 |
+
bash format.sh ${FORMAT_DIR} --test
|
70 |
+
flake_check:
|
71 |
+
flake8 ${FORMAT_DIR}
|
DI-engine/README.md
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<a href="https://di-engine-docs.readthedocs.io/en/latest/"><img width="1000px" height="auto" src="https://github.com/opendilab/DI-engine-docs/blob/main/source/images/head_image.png"></a>
|
3 |
+
</div>
|
4 |
+
|
5 |
+
---
|
6 |
+
|
7 |
+
[![Twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Ftwitter.com%2Fopendilab)](https://twitter.com/opendilab)
|
8 |
+
[![PyPI](https://img.shields.io/pypi/v/DI-engine)](https://pypi.org/project/DI-engine/)
|
9 |
+
![Conda](https://anaconda.org/opendilab/di-engine/badges/version.svg)
|
10 |
+
![Conda update](https://anaconda.org/opendilab/di-engine/badges/latest_release_date.svg)
|
11 |
+
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/DI-engine)
|
12 |
+
![PyTorch Version](https://img.shields.io/badge/dynamic/json?color=blue&label=pytorch&query=%24.pytorchVersion&url=https%3A%2F%2Fgist.githubusercontent.com/PaParaZz1/54c5c44eeb94734e276b2ed5770eba8d/raw/85b94a54933a9369f8843cc2cea3546152a75661/badges.json)
|
13 |
+
|
14 |
+
![Loc](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/loc.json)
|
15 |
+
![Comments](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/comments.json)
|
16 |
+
|
17 |
+
![Style](https://github.com/opendilab/DI-engine/actions/workflows/style.yml/badge.svg)
|
18 |
+
[![Read en Docs](https://github.com/opendilab/DI-engine/actions/workflows/doc.yml/badge.svg)](https://di-engine-docs.readthedocs.io/en/latest)
|
19 |
+
[![Read zh_CN Docs](https://img.shields.io/readthedocs/di-engine-docs?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://di-engine-docs.readthedocs.io/zh_CN/latest)
|
20 |
+
![Unittest](https://github.com/opendilab/DI-engine/actions/workflows/unit_test.yml/badge.svg)
|
21 |
+
![Algotest](https://github.com/opendilab/DI-engine/actions/workflows/algo_test.yml/badge.svg)
|
22 |
+
![deploy](https://github.com/opendilab/DI-engine/actions/workflows/deploy.yml/badge.svg)
|
23 |
+
[![codecov](https://codecov.io/gh/opendilab/DI-engine/branch/main/graph/badge.svg?token=B0Q15JI301)](https://codecov.io/gh/opendilab/DI-engine)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
![GitHub Org's stars](https://img.shields.io/github/stars/opendilab)
|
28 |
+
[![GitHub stars](https://img.shields.io/github/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
|
29 |
+
[![GitHub forks](https://img.shields.io/github/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network)
|
30 |
+
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/opendilab/DI-engine)
|
31 |
+
[![GitHub issues](https://img.shields.io/github/issues/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/issues)
|
32 |
+
[![GitHub pulls](https://img.shields.io/github/issues-pr/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/pulls)
|
33 |
+
[![Contributors](https://img.shields.io/github/contributors/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/graphs/contributors)
|
34 |
+
[![GitHub license](https://img.shields.io/github/license/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/blob/master/LICENSE)
|
35 |
+
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/OpenDILabCommunity)
|
36 |
+
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models?search=opendilab)
|
37 |
+
|
38 |
+
Updated on 2023.12.05 DI-engine-v0.5.0
|
39 |
+
|
40 |
+
|
41 |
+
## Introduction to DI-engine
|
42 |
+
[Documentation](https://di-engine-docs.readthedocs.io/en/latest/) | [中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/) | [Tutorials](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/index.html) | [Feature](#feature) | [Task & Middleware](https://di-engine-docs.readthedocs.io/en/latest/03_system/index.html) | [TreeTensor](#general-data-container-treetensor) | [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
|
43 |
+
|
44 |
+
**DI-engine** is a generalized decision intelligence engine for PyTorch and JAX.
|
45 |
+
|
46 |
+
It provides **python-first** and **asynchronous-native** task and middleware abstractions, and modularly integrates several of the most important decision-making concepts: Env, Policy and Model. Based on the above mechanisms, DI-engine supports **various [deep reinforcement learning](https://di-engine-docs.readthedocs.io/en/latest/10_concepts/index.html) algorithms** with superior performance, high efficiency, well-organized [documentation](https://di-engine-docs.readthedocs.io/en/latest/) and [unittest](https://github.com/opendilab/DI-engine/actions):
|
47 |
+
|
48 |
+
- Most basic DRL algorithms: such as DQN, Rainbow, PPO, TD3, SAC, R2D2, IMPALA
|
49 |
+
- Multi-agent RL algorithms: such as QMIX, WQMIX, MAPPO, HAPPO, ACE
|
50 |
+
- Imitation learning algorithms (BC/IRL/GAIL): such as GAIL, SQIL, Guided Cost Learning, Implicit BC
|
51 |
+
- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
|
52 |
+
- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3, MuZero
|
53 |
+
- Exploration algorithms: HER, RND, ICM, NGU
|
54 |
+
- LLM + RL Algorithms: PPO-max, DPO, MPDPO
|
55 |
+
- Other algorithms: such as PER, PLR, PCGrad
|
56 |
+
|
57 |
+
**DI-engine** aims to **standardize different Decision Intelligence environments and applications**, supporting both academic research and prototype applications. Various training pipelines and customized decision AI applications are also supported:
|
58 |
+
|
59 |
+
<details open>
|
60 |
+
<summary>(Click to Collapse)</summary>
|
61 |
+
|
62 |
+
- Traditional academic environments
|
63 |
+
- [DI-zoo](https://github.com/opendilab/DI-engine#environment-versatility): various decision intelligence demonstrations and benchmark environments with DI-engine.
|
64 |
+
- Tutorial courses
|
65 |
+
- [PPOxFamily](https://github.com/opendilab/PPOxFamily): PPO x Family DRL Tutorial Course
|
66 |
+
- Real world decision AI applications
|
67 |
+
- [DI-star](https://github.com/opendilab/DI-star): Decision AI in StarCraftII
|
68 |
+
- [DI-drive](https://github.com/opendilab/DI-drive): Auto-driving platform
|
69 |
+
- [DI-sheep](https://github.com/opendilab/DI-sheep): Decision AI in 3 Tiles Game
|
70 |
+
- [DI-smartcross](https://github.com/opendilab/DI-smartcross): Decision AI in Traffic Light Control
|
71 |
+
- [DI-bioseq](https://github.com/opendilab/DI-bioseq): Decision AI in Biological Sequence Prediction and Searching
|
72 |
+
- [DI-1024](https://github.com/opendilab/DI-1024): Deep Reinforcement Learning + 1024 Game
|
73 |
+
- Research paper
|
74 |
+
- [InterFuser](https://github.com/opendilab/InterFuser): [CoRL 2022] Safety-Enhanced Autonomous Driving Using Interpretable Sensor Fusion Transformer
|
75 |
+
- [ACE](https://github.com/opendilab/ACE): [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
|
76 |
+
- [GoBigger](https://github.com/opendilab/GoBigger): [ICLR 2023] Multi-Agent Decision Intelligence Environment
|
77 |
+
- [DOS](https://github.com/opendilab/DOS): [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
|
78 |
+
- [LightZero](https://github.com/opendilab/LightZero): [NeurIPS 2023 Spotlight] A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
|
79 |
+
- [SO2](https://github.com/opendilab/SO2): [AAAI 2024] A Perspective of Q-value Estimation on Offline-to-Online Reinforcement Learning
|
80 |
+
- [LMDrive](https://github.com/opendilab/LMDrive): LMDrive: Closed-Loop End-to-End Driving with Large Language Models
|
81 |
+
- Docs and Tutorials
|
82 |
+
- [DI-engine-docs](https://github.com/opendilab/DI-engine-docs): Tutorials, best practice and the API reference.
|
83 |
+
- [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources
|
84 |
+
- [awesome-exploration-RL](https://github.com/opendilab/awesome-exploration-rl): A curated list of awesome exploration RL resources
|
85 |
+
- [awesome-decision-transformer](https://github.com/opendilab/awesome-decision-transformer): A curated list of Decision Transformer resources
|
86 |
+
- [awesome-RLHF](https://github.com/opendilab/awesome-RLHF): A curated list of reinforcement learning with human feedback resources
|
87 |
+
- [awesome-multi-modal-reinforcement-learning](https://github.com/opendilab/awesome-multi-modal-reinforcement-learning): A curated list of Multi-Modal Reinforcement Learning resources
|
88 |
+
- [awesome-AI-based-protein-design](https://github.com/opendilab/awesome-AI-based-protein-design): a collection of research papers for AI-based protein design
|
89 |
+
- [awesome-diffusion-model-in-rl](https://github.com/opendilab/awesome-diffusion-model-in-rl): A curated list of Diffusion Model in RL resources
|
90 |
+
- [awesome-end-to-end-autonomous-driving](https://github.com/opendilab/awesome-end-to-end-autonomous-driving): A curated list of awesome End-to-End Autonomous Driving resources
|
91 |
+
- [awesome-driving-behavior-prediction](https://github.com/opendilab/awesome-driving-behavior-prediction): A collection of research papers for Driving Behavior Prediction
|
92 |
+
</details>
|
93 |
+
|
94 |
+
On the low-level end, DI-engine comes with a set of highly re-usable modules, including [RL optimization functions](https://github.com/opendilab/DI-engine/tree/main/ding/rl_utils), [PyTorch utilities](https://github.com/opendilab/DI-engine/tree/main/ding/torch_utils) and [auxiliary tools](https://github.com/opendilab/DI-engine/tree/main/ding/utils).
|
95 |
+
|
96 |
+
BTW, **DI-engine** also has some special **system optimization and design** for efficient and robust large-scale RL training:
|
97 |
+
|
98 |
+
<details close>
|
99 |
+
<summary>(Click for Details)</summary>
|
100 |
+
|
101 |
+
- [treevalue](https://github.com/opendilab/treevalue): Tree-nested data structure
|
102 |
+
- [DI-treetensor](https://github.com/opendilab/DI-treetensor): Tree-nested PyTorch tensor Lib
|
103 |
+
- [DI-toolkit](https://github.com/opendilab/DI-toolkit): A simple toolkit package for decision intelligence
|
104 |
+
- [DI-orchestrator](https://github.com/opendilab/DI-orchestrator): RL Kubernetes Custom Resource and Operator Lib
|
105 |
+
- [DI-hpc](https://github.com/opendilab/DI-hpc): RL HPC OP Lib
|
106 |
+
- [DI-store](https://github.com/opendilab/DI-store): RL Object Store
|
107 |
+
</details>
|
108 |
+
|
109 |
+
Have fun with exploration and exploitation.
|
110 |
+
|
111 |
+
## Outline
|
112 |
+
|
113 |
+
- [Introduction to DI-engine](#introduction-to-di-engine)
|
114 |
+
- [Outline](#outline)
|
115 |
+
- [Installation](#installation)
|
116 |
+
- [Quick Start](#quick-start)
|
117 |
+
- [Feature](#feature)
|
118 |
+
- [Algorithm Versatility](#algorithm-versatility)
|
119 |
+
- [Environment Versatility](#environment-versatility)
|
120 |
+
- [General Data Container: TreeTensor](#general-data-container-treetensor)
|
121 |
+
- [Feedback and Contribution](#feedback-and-contribution)
|
122 |
+
- [Supporters](#supporters)
|
123 |
+
- [↳ Stargazers](#-stargazers)
|
124 |
+
- [↳ Forkers](#-forkers)
|
125 |
+
- [Citation](#citation)
|
126 |
+
- [License](#license)
|
127 |
+
|
128 |
+
## Installation
|
129 |
+
|
130 |
+
You can simply install DI-engine from PyPI with the following command:
|
131 |
+
```bash
|
132 |
+
pip install DI-engine
|
133 |
+
```
|
134 |
+
|
135 |
+
If you use Anaconda or Miniconda, you can install DI-engine from conda-forge through the following command:
|
136 |
+
```bash
|
137 |
+
conda install -c opendilab di-engine
|
138 |
+
```
|
139 |
+
|
140 |
+
For more information about installation, you can refer to [installation](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/installation.html).
|
141 |
+
|
142 |
+
And our dockerhub repo can be found [here](https://hub.docker.com/repository/docker/opendilab/ding),we prepare `base image` and `env image` with common RL environments.
|
143 |
+
|
144 |
+
<details close>
|
145 |
+
<summary>(Click for Details)</summary>
|
146 |
+
|
147 |
+
- base: opendilab/ding:nightly
|
148 |
+
- rpc: opendilab/ding:nightly-rpc
|
149 |
+
- atari: opendilab/ding:nightly-atari
|
150 |
+
- mujoco: opendilab/ding:nightly-mujoco
|
151 |
+
- dmc: opendilab/ding:nightly-dmc2gym
|
152 |
+
- metaworld: opendilab/ding:nightly-metaworld
|
153 |
+
- smac: opendilab/ding:nightly-smac
|
154 |
+
- grf: opendilab/ding:nightly-grf
|
155 |
+
- cityflow: opendilab/ding:nightly-cityflow
|
156 |
+
- evogym: opendilab/ding:nightly-evogym
|
157 |
+
- d4rl: opendilab/ding:nightly-d4rl
|
158 |
+
</details>
|
159 |
+
|
160 |
+
The detailed documentation are hosted on [doc](https://di-engine-docs.readthedocs.io/en/latest/) | [中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/).
|
161 |
+
|
162 |
+
## Quick Start
|
163 |
+
|
164 |
+
[3 Minutes Kickoff](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/first_rl_program.html)
|
165 |
+
|
166 |
+
[3 Minutes Kickoff (colab)](https://colab.research.google.com/drive/1_7L-QFDfeCvMvLJzRyBRUW5_Q6ESXcZ4)
|
167 |
+
|
168 |
+
[DI-engine Huggingface Kickoff (colab)](https://colab.research.google.com/drive/1UH1GQOjcHrmNSaW77hnLGxFJrLSLwCOk)
|
169 |
+
|
170 |
+
[How to migrate a new **RL Env**](https://di-engine-docs.readthedocs.io/en/latest/11_dizoo/index.html) | [如何迁移一个新的**强化学习环境**](https://di-engine-docs.readthedocs.io/zh_CN/latest/11_dizoo/index_zh.html)
|
171 |
+
|
172 |
+
[How to customize the neural network model](https://di-engine-docs.readthedocs.io/en/latest/04_best_practice/custom_model.html) | [如何定制策略使用的**神经网络模型**](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/custom_model_zh.html)
|
173 |
+
|
174 |
+
[测试/部署 **强化学习策略** 的样例](https://github.com/opendilab/DI-engine/blob/main/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py)
|
175 |
+
|
176 |
+
[新老 pipeline 的异同对比](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/diff_in_new_pipeline_zh.html)
|
177 |
+
|
178 |
+
|
179 |
+
## Feature
|
180 |
+
### Algorithm Versatility
|
181 |
+
|
182 |
+
<details open>
|
183 |
+
<summary>(Click to Collapse)</summary>
|
184 |
+
|
185 |
+
![discrete](https://img.shields.io/badge/-discrete-brightgreen) discrete means discrete action space, which is only label in normal DRL algorithms (1-23)
|
186 |
+
|
187 |
+
![continuous](https://img.shields.io/badge/-continous-green) means continuous action space, which is only label in normal DRL algorithms (1-23)
|
188 |
+
|
189 |
+
![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) means hybrid (discrete + continuous) action space (1-23)
|
190 |
+
|
191 |
+
![dist](https://img.shields.io/badge/-distributed-blue) [Distributed Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/distributed_rl.html)|[分布式强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/distributed_rl_zh.html)
|
192 |
+
|
193 |
+
![MARL](https://img.shields.io/badge/-MARL-yellow) [Multi-Agent Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/multi_agent_cooperation_rl.html)|[多智能体强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/multi_agent_cooperation_rl_zh.html)
|
194 |
+
|
195 |
+
![exp](https://img.shields.io/badge/-exploration-orange) [Exploration Mechanisms in Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/exploration_rl.html)|[强化学习中的探索机制](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/exploration_rl_zh.html)
|
196 |
+
|
197 |
+
![IL](https://img.shields.io/badge/-IL-purple) [Imitation Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/imitation_learning.html)|[模仿学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/imitation_learning_zh.html)
|
198 |
+
|
199 |
+
![offline](https://img.shields.io/badge/-offlineRL-darkblue) [Offiline Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/offline_rl.html)|[离线强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/offline_rl_zh.html)
|
200 |
+
|
201 |
+
|
202 |
+
![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) [Model-Based Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/model_based_rl.html)|[基于模型的强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/model_based_rl_zh.html)
|
203 |
+
|
204 |
+
![other](https://img.shields.io/badge/-other-lightgrey) means other sub-direction algorithms, usually as plugin-in in the whole pipeline
|
205 |
+
|
206 |
+
P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
| No. | Algorithm | Label | Doc and Implementation | Runnable Demo |
|
211 |
+
| :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
|
212 |
+
| 1 | [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [DQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqn.html)<br>[DQN中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/dqn_zh.html)<br>[policy/dqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u cartpole_dqn_main.py / ding -m serial -c cartpole_dqn_config.py -s 0 |
|
213 |
+
| 2 | [C51](https://arxiv.org/pdf/1707.06887.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [C51 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/c51.html)<br>[policy/c51](https://github.com/opendilab/DI-engine/blob/main/ding/policy/c51.py) | ding -m serial -c cartpole_c51_config.py -s 0 |
|
214 |
+
| 3 | [QRDQN](https://arxiv.org/pdf/1710.10044.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [QRDQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qrdqn.html)<br>[policy/qrdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qrdqn.py) | ding -m serial -c cartpole_qrdqn_config.py -s 0 |
|
215 |
+
| 4 | [IQN](https://arxiv.org/pdf/1806.06923.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/iqn.html)<br>[policy/iqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/iqn.py) | ding -m serial -c cartpole_iqn_config.py -s 0 |
|
216 |
+
| 5 | [FQF](https://arxiv.org/pdf/1911.02140.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [FQF doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/fqf.html)<br>[policy/fqf](https://github.com/opendilab/DI-engine/blob/main/ding/policy/fqf.py) | ding -m serial -c cartpole_fqf_config.py -s 0 |
|
217 |
+
| 6 | [Rainbow](https://arxiv.org/pdf/1710.02298.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [Rainbow doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rainbow.html)<br>[policy/rainbow](https://github.com/opendilab/DI-engine/blob/main/ding/policy/rainbow.py) | ding -m serial -c cartpole_rainbow_config.py -s 0 |
|
218 |
+
| 7 | [SQL](https://arxiv.org/pdf/1702.08165.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [SQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sql.html)<br>[policy/sql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sql.py) | ding -m serial -c cartpole_sql_config.py -s 0 |
|
219 |
+
| 8 | [R2D2](https://openreview.net/forum?id=r1lyTjAqYX) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [R2D2 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d2.html)<br>[policy/r2d2](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d2.py) | ding -m serial -c cartpole_r2d2_config.py -s 0 |
|
220 |
+
| 9 | [PG](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)<br>[policy/pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pg.py) | ding -m serial -c cartpole_pg_config.py -s 0 |
|
221 |
+
| 10 | [PromptPG](https://arxiv.org/abs/2209.14610) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_pg.py) | ding -m serial_onpolicy -c tabmwp_pg_config.py -s 0 |
|
222 |
+
| 11 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)<br>[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 |
|
223 |
+
| 12 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)<br>[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 |
|
224 |
+
| 13 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)<br>[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py |
|
225 |
+
| 14 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)<br>[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 |
|
226 |
+
| 15 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)<br>[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 |
|
227 |
+
| 16 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 |
|
228 |
+
| 17 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)<br>[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 |
|
229 |
+
| 18 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)<br>[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
|
230 |
+
| 19 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)<br>[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
|
231 |
+
| 20 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
|
232 |
+
| 21 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
|
233 |
+
| 22 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 |
|
234 |
+
| 23 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
|
235 |
+
| 24 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py |
|
236 |
+
| 25 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)<br>[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
|
237 |
+
| 26 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)<br>[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
|
238 |
+
| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
|
239 |
+
| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)<br>[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
|
240 |
+
| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)<br>[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
|
241 |
+
| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
|
242 |
+
| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)<br>[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
|
243 |
+
| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)<br>[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
|
244 |
+
| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)<br>[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
|
245 |
+
| 34 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)<br>[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)<br>[policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py |
|
246 |
+
| 35 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)<br>[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
|
247 |
+
| 36 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)<br>[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
|
248 |
+
| 37 | [Implicit Behavorial Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py) <br> [model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py |
|
249 |
+
| 38 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py |
|
250 |
+
| 39 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)<br>[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
|
251 |
+
| 40 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)<br>[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py |
|
252 |
+
| 41 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)<br>[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)<br>[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
|
253 |
+
| 42 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)<br>[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
|
254 |
+
| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
|
255 |
+
| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u d4rl_dt_mujoco.py |
|
256 |
+
| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
|
257 |
+
| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
|
258 |
+
| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
|
259 |
+
| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
|
260 |
+
| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
|
261 |
+
| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
|
262 |
+
| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
|
263 |
+
| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
|
264 |
+
| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
|
265 |
+
| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
|
266 |
+
| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
|
267 |
+
</details>
|
268 |
+
|
269 |
+
|
270 |
+
### Environment Versatility
|
271 |
+
<details open>
|
272 |
+
<summary>(Click to Collapse)</summary>
|
273 |
+
|
274 |
+
| No | Environment | Label | Visualization | Code and Doc Links |
|
275 |
+
| :--: | :--------------------------------------: | :---------------------------------: | :--------------------------------:|:---------------------------------------------------------: |
|
276 |
+
| 1 | [Atari](https://github.com/openai/gym/tree/master/gym/envs/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/atari/atari.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/atari/envs) <br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) |
|
277 |
+
| 2 | [box2d/bipedalwalker](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/bipedalwalker/original.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/bipedalwalker/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/bipedalwalker.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bipedalwalker_zh.html) |
|
278 |
+
| 3 | [box2d/lunarlander](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/box2d/lunarlander/lunarlander.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/lunarlander/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/lunarlander.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/lunarlander_zh.html) |
|
279 |
+
| 4 | [classic_control/cartpole](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/cartpole/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/cartpole.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/cartpole_zh.html) |
|
280 |
+
| 5 | [classic_control/pendulum](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/classic_control/pendulum/pendulum.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/pendulum/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pendulum.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pendulum_zh.html) |
|
281 |
+
| 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.classic_control)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/competitive_rl_zh.html) |
|
282 |
+
| 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.gfootball/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball_zh.html) |
|
283 |
+
| 8 | [minigrid](https://github.com/maximecb/gym-minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/minigrid/minigrid.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/minigrid/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid_zh.html) |
|
284 |
+
| 9 | [MuJoCo](https://github.com/openai/gym/tree/master/gym/envs/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/majoco/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco_zh.html) |
|
285 |
+
| 10 | [PettingZoo](https://github.com/Farama-Foundation/PettingZoo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/petting_zoo/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pettingzoo.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pettingzoo_zh.html) |
|
286 |
+
| 11 | [overcooked](https://github.com/HumanCompatibleAI/overcooked-demo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/overcooked/overcooked.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/overcooded/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/overcooked.html) |
|
287 |
+
| 12 | [procgen](https://github.com/openai/procgen) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/procgen/coinrun.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/procgen)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/procgen.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/procgen_zh.html) |
|
288 |
+
| 13 | [pybullet](https://github.com/benelot/pybullet-gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/pybullet/pybullet.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pybullet/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pybullet_zh.html) |
|
289 |
+
| 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/smac.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/smac_zh.html) |
|
290 |
+
| 15 | [d4rl](https://github.com/rail-berkeley/d4rl) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | ![ori](dizoo/d4rl/d4rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/d4rl)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/d4rl_zh.html) |
|
291 |
+
| 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) |
|
292 |
+
| 17 | pomdp atari | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pomdp/envs) |
|
293 |
+
| 18 | [bsuite](https://github.com/deepmind/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/bsuite/bsuite.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bsuite/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs//bsuite.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bsuite_zh.html) |
|
294 |
+
| 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/image_cls_zh.html) |
|
295 |
+
| 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/slime_volleyball.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/slime_volleyball_zh.html) |
|
296 |
+
| 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_hybrid.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_hybrid_zh.html) |
|
297 |
+
| 22 | [GoBigger](https://github.com/opendilab/GoBigger) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen)![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](./dizoo/gobigger_overview.gif) | [dizoo link](https://github.com/opendilab/GoBigger-Challenge-2021/tree/main/di_baseline)<br>[env tutorial](https://gobigger.readthedocs.io/en/latest/index.html)<br>[环境指南](https://gobigger.readthedocs.io/zh_CN/latest/) |
|
298 |
+
| 23 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_soccer_zh.html) |
|
299 |
+
| 24 |[multiagent_mujoco](https://github.com/schroederdewitt/multiagent_mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_mujoco/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/mujoco_zh.html) |
|
300 |
+
| 25 |bitflip | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/bitflip/bitflip.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bitflip/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bitflip_zh.html) |
|
301 |
+
| 26 |[sokoban](https://github.com/mpSchrader/gym-sokoban) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![Game 2](https://github.com/mpSchrader/gym-sokoban/raw/default/docs/Animations/solved_4.gif?raw=true) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/sokoban/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/sokoban.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/sokoban_zh.html) |
|
302 |
+
| 27 |[gym_anytrading](https://github.com/AminHP/gym-anytrading) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/gym_anytrading/envs/position.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_anytrading) <br> [env tutorial](https://github.com/opendilab/DI-engine/blob/main/dizoo/gym_anytrading/envs/README.md) |
|
303 |
+
| 28 |[mario](https://github.com/Kautenja/gym-super-mario-bros) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/mario/mario.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mario) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_super_mario_bros.html) <br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_super_mario_bros_zh.html) |
|
304 |
+
| 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) |
|
305 |
+
| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/Evogym_zh.html) |
|
306 |
+
| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym_pybullet_drones/gym_pybullet_drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>环境指南 |
|
307 |
+
| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>环境指南 |
|
308 |
+
| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) |
|
309 |
+
| 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) <br> ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)<br>环境指南 |
|
310 |
+
| 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
|
311 |
+
| 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)<br> env tutorial <br> 环境指南 |
|
312 |
+
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南|
|
313 |
+
|
314 |
+
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
|
315 |
+
|
316 |
+
![continuous](https://img.shields.io/badge/-continous-green) means continuous action space
|
317 |
+
|
318 |
+
![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) means hybrid (discrete + continuous) action space
|
319 |
+
|
320 |
+
![MARL](https://img.shields.io/badge/-MARL-yellow) means multi-agent RL environment
|
321 |
+
|
322 |
+
![sparse](https://img.shields.io/badge/-sparse%20reward-orange) means environment which is related to exploration and sparse reward
|
323 |
+
|
324 |
+
![offline](https://img.shields.io/badge/-offlineRL-darkblue) means offline RL environment
|
325 |
+
|
326 |
+
![IL](https://img.shields.io/badge/-IL/SL-purple) means Imitation Learning or Supervised Learning Dataset
|
327 |
+
|
328 |
+
![selfplay](https://img.shields.io/badge/-selfplay-blue) means environment that allows agent VS agent battle
|
329 |
+
|
330 |
+
P.S. some enviroments in Atari, such as **MontezumaRevenge**, are also the sparse reward type.
|
331 |
+
</details>
|
332 |
+
|
333 |
+
|
334 |
+
### General Data Container: TreeTensor
|
335 |
+
|
336 |
+
DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as the basic data container in various components, which is ease of use and consistent across different code modules such as environment definition, data processing and DRL optimization. Here are some concrete code examples:
|
337 |
+
|
338 |
+
- TreeTensor can easily extend all the operations of `torch.Tensor` to nested data:
|
339 |
+
<details close>
|
340 |
+
<summary>(Click for Details)</summary>
|
341 |
+
|
342 |
+
```python
|
343 |
+
import treetensor.torch as ttorch
|
344 |
+
|
345 |
+
|
346 |
+
# create random tensor
|
347 |
+
data = ttorch.randn({'a': (3, 2), 'b': {'c': (3, )}})
|
348 |
+
# clone+detach tensor
|
349 |
+
data_clone = data.clone().detach()
|
350 |
+
# access tree structure like attribute
|
351 |
+
a = data.a
|
352 |
+
c = data.b.c
|
353 |
+
# stack/cat/split
|
354 |
+
stacked_data = ttorch.stack([data, data_clone], 0)
|
355 |
+
cat_data = ttorch.cat([data, data_clone], 0)
|
356 |
+
data, data_clone = ttorch.split(stacked_data, 1)
|
357 |
+
# reshape
|
358 |
+
data = data.unsqueeze(-1)
|
359 |
+
data = data.squeeze(-1)
|
360 |
+
flatten_data = data.view(-1)
|
361 |
+
# indexing
|
362 |
+
data_0 = data[0]
|
363 |
+
data_1to2 = data[1:2]
|
364 |
+
# execute math calculations
|
365 |
+
data = data.sin()
|
366 |
+
data.b.c.cos_().clamp_(-1, 1)
|
367 |
+
data += data ** 2
|
368 |
+
# backward
|
369 |
+
data.requires_grad_(True)
|
370 |
+
loss = data.arctan().mean()
|
371 |
+
loss.backward()
|
372 |
+
# print shape
|
373 |
+
print(data.shape)
|
374 |
+
# result
|
375 |
+
# <Size 0x7fbd3346ddc0>
|
376 |
+
# ├── 'a' --> torch.Size([1, 3, 2])
|
377 |
+
# └── 'b' --> <Size 0x7fbd3346dd00>
|
378 |
+
# └── 'c' --> torch.Size([1, 3])
|
379 |
+
```
|
380 |
+
|
381 |
+
</details>
|
382 |
+
|
383 |
+
- TreeTensor can make it simple yet effective to implement classic deep reinforcement learning pipeline
|
384 |
+
<details close>
|
385 |
+
<summary>(Click for Details)</summary>
|
386 |
+
|
387 |
+
```diff
|
388 |
+
import torch
|
389 |
+
import treetensor.torch as ttorch
|
390 |
+
|
391 |
+
B = 4
|
392 |
+
|
393 |
+
|
394 |
+
def get_item():
|
395 |
+
return {
|
396 |
+
'obs': {
|
397 |
+
'scalar': torch.randn(12),
|
398 |
+
'image': torch.randn(3, 32, 32),
|
399 |
+
},
|
400 |
+
'action': torch.randint(0, 10, size=(1,)),
|
401 |
+
'reward': torch.rand(1),
|
402 |
+
'done': False,
|
403 |
+
}
|
404 |
+
|
405 |
+
|
406 |
+
data = [get_item() for _ in range(B)]
|
407 |
+
|
408 |
+
|
409 |
+
# execute `stack` op
|
410 |
+
- def stack(data, dim):
|
411 |
+
- elem = data[0]
|
412 |
+
- if isinstance(elem, torch.Tensor):
|
413 |
+
- return torch.stack(data, dim)
|
414 |
+
- elif isinstance(elem, dict):
|
415 |
+
- return {k: stack([item[k] for item in data], dim) for k in elem.keys()}
|
416 |
+
- elif isinstance(elem, bool):
|
417 |
+
- return torch.BoolTensor(data)
|
418 |
+
- else:
|
419 |
+
- raise TypeError("not support elem type: {}".format(type(elem)))
|
420 |
+
- stacked_data = stack(data, dim=0)
|
421 |
+
+ data = [ttorch.tensor(d) for d in data]
|
422 |
+
+ stacked_data = ttorch.stack(data, dim=0)
|
423 |
+
|
424 |
+
# validate
|
425 |
+
- assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
|
426 |
+
- assert stacked_data['action'].shape == (B, 1)
|
427 |
+
- assert stacked_data['reward'].shape == (B, 1)
|
428 |
+
- assert stacked_data['done'].shape == (B,)
|
429 |
+
- assert stacked_data['done'].dtype == torch.bool
|
430 |
+
+ assert stacked_data.obs.image.shape == (B, 3, 32, 32)
|
431 |
+
+ assert stacked_data.action.shape == (B, 1)
|
432 |
+
+ assert stacked_data.reward.shape == (B, 1)
|
433 |
+
+ assert stacked_data.done.shape == (B,)
|
434 |
+
+ assert stacked_data.done.dtype == torch.bool
|
435 |
+
```
|
436 |
+
|
437 |
+
</details>
|
438 |
+
|
439 |
+
## Feedback and Contribution
|
440 |
+
|
441 |
+
- [File an issue](https://github.com/opendilab/DI-engine/issues/new/choose) on Github
|
442 |
+
- Open or participate in our [forum](https://github.com/opendilab/DI-engine/discussions)
|
443 |
+
- Discuss on DI-engine [slack communication channel](https://join.slack.com/t/opendilab/shared_invite/zt-v9tmv4fp-nUBAQEH1_Kuyu_q4plBssQ)
|
444 |
+
- Discuss on DI-engine's WeChat group (i.e. add us on WeChat: ding314assist)
|
445 |
+
|
446 |
+
<img src=https://github.com/opendilab/DI-engine/blob/main/assets/wechat.jpeg width=35% />
|
447 |
+
- Contact our email (opendilab@pjlab.org.cn)
|
448 |
+
- Contributes to our future plan [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
|
449 |
+
|
450 |
+
We appreciate all the feedbacks and contributions to improve DI-engine, both algorithms and system designs. And `CONTRIBUTING.md` offers some necessary information.
|
451 |
+
|
452 |
+
## Supporters
|
453 |
+
|
454 |
+
### ↳ Stargazers
|
455 |
+
|
456 |
+
[![Stargazers repo roster for @opendilab/DI-engine](https://reporoster.com/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
|
457 |
+
|
458 |
+
### ↳ Forkers
|
459 |
+
|
460 |
+
[![Forkers repo roster for @opendilab/DI-engine](https://reporoster.com/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network/members)
|
461 |
+
|
462 |
+
|
463 |
+
## Citation
|
464 |
+
```latex
|
465 |
+
@misc{ding,
|
466 |
+
title={DI-engine: OpenDILab Decision Intelligence Engine},
|
467 |
+
author={OpenDILab Contributors},
|
468 |
+
publisher={GitHub},
|
469 |
+
howpublished={\url{https://github.com/opendilab/DI-engine}},
|
470 |
+
year={2021},
|
471 |
+
}
|
472 |
+
```
|
473 |
+
|
474 |
+
## License
|
475 |
+
DI-engine released under the Apache 2.0 license.
|
DI-engine/cloc.sh
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# This scripts counts the lines of code and comments in all source files
|
4 |
+
# and prints the results to the command line. It uses the commandline tool
|
5 |
+
# "cloc". You can either pass --loc, --comments or --percentage to show the
|
6 |
+
# respective values only.
|
7 |
+
# Some parts below need to be adapted to your project!
|
8 |
+
|
9 |
+
# Get the location of this script.
|
10 |
+
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
|
11 |
+
|
12 |
+
# Run cloc - this counts code lines, blank lines and comment lines
|
13 |
+
# for the specified languages. You will need to change this accordingly.
|
14 |
+
# For C++, you could use "C++,C/C++ Header" for example.
|
15 |
+
# We are only interested in the summary, therefore the tail -1
|
16 |
+
SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -1)"
|
17 |
+
|
18 |
+
# The $SUMMARY is one line of a markdown table and looks like this:
|
19 |
+
# SUM:|101|3123|2238|10783
|
20 |
+
# We use the following command to split it into an array.
|
21 |
+
IFS='|' read -r -a TOKENS <<< "$SUMMARY"
|
22 |
+
|
23 |
+
# Store the individual tokens for better readability.
|
24 |
+
NUMBER_OF_FILES=${TOKENS[1]}
|
25 |
+
COMMENT_LINES=${TOKENS[3]}
|
26 |
+
LINES_OF_CODE=${TOKENS[4]}
|
27 |
+
|
28 |
+
# To make the estimate of commented lines more accurate, we have to
|
29 |
+
# subtract any copyright header which is included in each file.
|
30 |
+
# For Fly-Pie, this header has the length of five lines.
|
31 |
+
# All dumb comments like those /////////// or those // ------------
|
32 |
+
# are also subtracted. As cloc does not count inline comments,
|
33 |
+
# the overall estimate should be rather conservative.
|
34 |
+
# Change the lines below according to your project.
|
35 |
+
DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
|
36 |
+
COMMENT_LINES=$(($COMMENT_LINES - 5 * $NUMBER_OF_FILES - $DUMB_COMMENTS))
|
37 |
+
|
38 |
+
# Print all results if no arguments are given.
|
39 |
+
if [[ $# -eq 0 ]] ; then
|
40 |
+
awk -v a=$LINES_OF_CODE \
|
41 |
+
'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
|
42 |
+
awk -v a=$COMMENT_LINES \
|
43 |
+
'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
|
44 |
+
awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
|
45 |
+
'BEGIN {printf "Comment Percentage: %6.1f%\n", 100*a/b}'
|
46 |
+
exit 0
|
47 |
+
fi
|
48 |
+
|
49 |
+
# Show lines of code if --loc is given.
|
50 |
+
if [[ $* == *--loc* ]]
|
51 |
+
then
|
52 |
+
awk -v a=$LINES_OF_CODE \
|
53 |
+
'BEGIN {printf "%.1fk\n", a/1000}'
|
54 |
+
fi
|
55 |
+
|
56 |
+
# Show lines of comments if --comments is given.
|
57 |
+
if [[ $* == *--comments* ]]
|
58 |
+
then
|
59 |
+
awk -v a=$COMMENT_LINES \
|
60 |
+
'BEGIN {printf "%.1fk\n", a/1000}'
|
61 |
+
fi
|
62 |
+
|
63 |
+
# Show precentage of comments if --percentage is given.
|
64 |
+
if [[ $* == *--percentage* ]]
|
65 |
+
then
|
66 |
+
awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
|
67 |
+
'BEGIN {printf "%.1f\n", 100*a/b}'
|
68 |
+
fi
|
69 |
+
|
DI-engine/codecov.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
coverage:
|
2 |
+
status:
|
3 |
+
project:
|
4 |
+
default:
|
5 |
+
# basic
|
6 |
+
target: auto
|
7 |
+
threshold: 0.5%
|
8 |
+
if_ci_failed: success #success, failure, error, ignore
|
DI-engine/conda/conda_build_config.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
python:
|
2 |
+
- 3.7
|
DI-engine/conda/meta.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{% set data = load_setup_py_data() %}
|
2 |
+
package:
|
3 |
+
name: di-engine
|
4 |
+
version: v0.5.0
|
5 |
+
|
6 |
+
source:
|
7 |
+
path: ..
|
8 |
+
|
9 |
+
build:
|
10 |
+
number: 0
|
11 |
+
script: python -m pip install . -vv
|
12 |
+
entry_points:
|
13 |
+
- ding = ding.entry.cli:cli
|
14 |
+
|
15 |
+
requirements:
|
16 |
+
build:
|
17 |
+
- python
|
18 |
+
- setuptools
|
19 |
+
run:
|
20 |
+
- python
|
21 |
+
|
22 |
+
test:
|
23 |
+
imports:
|
24 |
+
- ding
|
25 |
+
- dizoo
|
26 |
+
|
27 |
+
about:
|
28 |
+
home: https://github.com/opendilab/DI-engine
|
29 |
+
license: Apache-2.0
|
30 |
+
license_file: LICENSE
|
31 |
+
summary: DI-engine is a generalized Decision Intelligence engine (https://github.com/opendilab/DI-engine).
|
32 |
+
description: Please refer to https://di-engine-docs.readthedocs.io/en/latest/00_intro/index.html#what-is-di-engine
|
33 |
+
dev_url: https://github.com/opendilab/DI-engine
|
34 |
+
doc_url: https://di-engine-docs.readthedocs.io/en/latest/index.html
|
35 |
+
doc_source_url: https://github.com/opendilab/DI-engine-docs
|
DI-engine/ding/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
__TITLE__ = 'DI-engine'
|
4 |
+
__VERSION__ = 'v0.5.0'
|
5 |
+
__DESCRIPTION__ = 'Decision AI Engine'
|
6 |
+
__AUTHOR__ = "OpenDILab Contributors"
|
7 |
+
__AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn"
|
8 |
+
__version__ = __VERSION__
|
9 |
+
|
10 |
+
enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true'
|
11 |
+
enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
|
12 |
+
enable_numba = True
|
DI-engine/ding/bonus/__init__.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ding.config
|
2 |
+
from .a2c import A2CAgent
|
3 |
+
from .c51 import C51Agent
|
4 |
+
from .ddpg import DDPGAgent
|
5 |
+
from .dqn import DQNAgent
|
6 |
+
from .pg import PGAgent
|
7 |
+
from .ppof import PPOF
|
8 |
+
from .ppo_offpolicy import PPOOffPolicyAgent
|
9 |
+
from .sac import SACAgent
|
10 |
+
from .sql import SQLAgent
|
11 |
+
from .td3 import TD3Agent
|
12 |
+
|
13 |
+
supported_algo = dict(
|
14 |
+
A2C=A2CAgent,
|
15 |
+
C51=C51Agent,
|
16 |
+
DDPG=DDPGAgent,
|
17 |
+
DQN=DQNAgent,
|
18 |
+
PG=PGAgent,
|
19 |
+
PPOF=PPOF,
|
20 |
+
PPOOffPolicy=PPOOffPolicyAgent,
|
21 |
+
SAC=SACAgent,
|
22 |
+
SQL=SQLAgent,
|
23 |
+
TD3=TD3Agent,
|
24 |
+
)
|
25 |
+
|
26 |
+
supported_algo_list = list(supported_algo.keys())
|
27 |
+
|
28 |
+
|
29 |
+
def env_supported(algo: str = None) -> list:
|
30 |
+
"""
|
31 |
+
return list of the envs that supported by di-engine.
|
32 |
+
"""
|
33 |
+
|
34 |
+
if algo is not None:
|
35 |
+
if algo.upper() == "A2C":
|
36 |
+
return list(ding.config.example.A2C.supported_env.keys())
|
37 |
+
elif algo.upper() == "C51":
|
38 |
+
return list(ding.config.example.C51.supported_env.keys())
|
39 |
+
elif algo.upper() == "DDPG":
|
40 |
+
return list(ding.config.example.DDPG.supported_env.keys())
|
41 |
+
elif algo.upper() == "DQN":
|
42 |
+
return list(ding.config.example.DQN.supported_env.keys())
|
43 |
+
elif algo.upper() == "PG":
|
44 |
+
return list(ding.config.example.PG.supported_env.keys())
|
45 |
+
elif algo.upper() == "PPOF":
|
46 |
+
return list(ding.config.example.PPOF.supported_env.keys())
|
47 |
+
elif algo.upper() == "PPOOFFPOLICY":
|
48 |
+
return list(ding.config.example.PPOOffPolicy.supported_env.keys())
|
49 |
+
elif algo.upper() == "SAC":
|
50 |
+
return list(ding.config.example.SAC.supported_env.keys())
|
51 |
+
elif algo.upper() == "SQL":
|
52 |
+
return list(ding.config.example.SQL.supported_env.keys())
|
53 |
+
elif algo.upper() == "TD3":
|
54 |
+
return list(ding.config.example.TD3.supported_env.keys())
|
55 |
+
else:
|
56 |
+
raise ValueError("The algo {} is not supported by di-engine.".format(algo))
|
57 |
+
else:
|
58 |
+
supported_env = set()
|
59 |
+
supported_env.update(ding.config.example.A2C.supported_env.keys())
|
60 |
+
supported_env.update(ding.config.example.C51.supported_env.keys())
|
61 |
+
supported_env.update(ding.config.example.DDPG.supported_env.keys())
|
62 |
+
supported_env.update(ding.config.example.DQN.supported_env.keys())
|
63 |
+
supported_env.update(ding.config.example.PG.supported_env.keys())
|
64 |
+
supported_env.update(ding.config.example.PPOF.supported_env.keys())
|
65 |
+
supported_env.update(ding.config.example.PPOOffPolicy.supported_env.keys())
|
66 |
+
supported_env.update(ding.config.example.SAC.supported_env.keys())
|
67 |
+
supported_env.update(ding.config.example.SQL.supported_env.keys())
|
68 |
+
supported_env.update(ding.config.example.TD3.supported_env.keys())
|
69 |
+
# return the list of the envs
|
70 |
+
return list(supported_env)
|
71 |
+
|
72 |
+
|
73 |
+
supported_env = env_supported()
|
74 |
+
|
75 |
+
|
76 |
+
def algo_supported(env_id: str = None) -> list:
|
77 |
+
"""
|
78 |
+
return list of the algos that supported by di-engine.
|
79 |
+
"""
|
80 |
+
if env_id is not None:
|
81 |
+
algo = []
|
82 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.A2C.supported_env.keys()]:
|
83 |
+
algo.append("A2C")
|
84 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.C51.supported_env.keys()]:
|
85 |
+
algo.append("C51")
|
86 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.DDPG.supported_env.keys()]:
|
87 |
+
algo.append("DDPG")
|
88 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.DQN.supported_env.keys()]:
|
89 |
+
algo.append("DQN")
|
90 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.PG.supported_env.keys()]:
|
91 |
+
algo.append("PG")
|
92 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.PPOF.supported_env.keys()]:
|
93 |
+
algo.append("PPOF")
|
94 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.PPOOffPolicy.supported_env.keys()]:
|
95 |
+
algo.append("PPOOffPolicy")
|
96 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.SAC.supported_env.keys()]:
|
97 |
+
algo.append("SAC")
|
98 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.SQL.supported_env.keys()]:
|
99 |
+
algo.append("SQL")
|
100 |
+
if env_id.upper() in [item.upper() for item in ding.config.example.TD3.supported_env.keys()]:
|
101 |
+
algo.append("TD3")
|
102 |
+
|
103 |
+
if len(algo) == 0:
|
104 |
+
raise ValueError("The env {} is not supported by di-engine.".format(env_id))
|
105 |
+
return algo
|
106 |
+
else:
|
107 |
+
return supported_algo_list
|
108 |
+
|
109 |
+
|
110 |
+
def is_supported(env_id: str = None, algo: str = None) -> bool:
|
111 |
+
"""
|
112 |
+
Check if the env-algo pair is supported by di-engine.
|
113 |
+
"""
|
114 |
+
if env_id is not None and env_id.upper() in [item.upper() for item in supported_env.keys()]:
|
115 |
+
if algo is not None and algo.upper() in supported_algo_list:
|
116 |
+
if env_id.upper() in env_supported(algo):
|
117 |
+
return True
|
118 |
+
else:
|
119 |
+
return False
|
120 |
+
elif algo is None:
|
121 |
+
return True
|
122 |
+
else:
|
123 |
+
return False
|
124 |
+
elif env_id is None:
|
125 |
+
if algo is not None and algo.upper() in supported_algo_list:
|
126 |
+
return True
|
127 |
+
elif algo is None:
|
128 |
+
raise ValueError("Please specify the env or algo.")
|
129 |
+
else:
|
130 |
+
return False
|
131 |
+
else:
|
132 |
+
return False
|
DI-engine/ding/bonus/a2c.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, trainer, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
|
11 |
+
gae_estimator, final_ctx_saver
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import A2CPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import VAC
|
19 |
+
from ding.model import model_wrap
|
20 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
21 |
+
from ding.config.example.A2C import supported_env_cfg
|
22 |
+
from ding.config.example.A2C import supported_env
|
23 |
+
|
24 |
+
|
25 |
+
class A2CAgent:
|
26 |
+
"""
|
27 |
+
Overview:
|
28 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
29 |
+
Advantage Actor Critic(A2C).
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.a2c import A2CAgent
|
41 |
+
>>> print(A2CAgent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for A2C algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of A2C algorithm, which should be an instance of class \
|
71 |
+
:class:`ding.model.VAC`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of A2C algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/A2C/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with A2C algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = A2CAgent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
|
91 |
+
>>> agent = A2CAgent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = VAC(**cfg.policy.model)
|
94 |
+
>>> agent = A2CAgent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = A2CAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
A2CAgent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
A2CAgent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": A2CPolicy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=A2CPolicy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = VAC(**self.cfg.policy.model)
|
140 |
+
self.policy = A2CPolicy(self.cfg.policy, model=model)
|
141 |
+
if policy_state_dict is not None:
|
142 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
143 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
144 |
+
|
145 |
+
def train(
|
146 |
+
self,
|
147 |
+
step: int = int(1e7),
|
148 |
+
collector_env_num: int = 4,
|
149 |
+
evaluator_env_num: int = 4,
|
150 |
+
n_iter_log_show: int = 500,
|
151 |
+
n_iter_save_ckpt: int = 1000,
|
152 |
+
context: Optional[str] = None,
|
153 |
+
debug: bool = False,
|
154 |
+
wandb_sweep: bool = False,
|
155 |
+
) -> TrainingReturn:
|
156 |
+
"""
|
157 |
+
Overview:
|
158 |
+
Train the agent with A2C algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
159 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
160 |
+
recorded and saved by wandb.
|
161 |
+
Arguments:
|
162 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
163 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
164 |
+
If not specified, it will be set according to the configuration.
|
165 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
166 |
+
If not specified, it will be set according to the configuration.
|
167 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
168 |
+
Default to 1000.
|
169 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
170 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
171 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
172 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
173 |
+
subprocess environment manager will be used.
|
174 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
175 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
176 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
177 |
+
Returns:
|
178 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
179 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
180 |
+
"""
|
181 |
+
|
182 |
+
if debug:
|
183 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
184 |
+
logging.debug(self.policy._model)
|
185 |
+
# define env and policy
|
186 |
+
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
|
187 |
+
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
|
188 |
+
|
189 |
+
with task.start(ctx=OnlineRLContext()):
|
190 |
+
task.use(
|
191 |
+
interaction_evaluator(
|
192 |
+
self.cfg,
|
193 |
+
self.policy.eval_mode,
|
194 |
+
evaluator_env,
|
195 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
196 |
+
)
|
197 |
+
)
|
198 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
199 |
+
task.use(
|
200 |
+
StepCollector(
|
201 |
+
self.cfg,
|
202 |
+
self.policy.collect_mode,
|
203 |
+
collector_env,
|
204 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
205 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
206 |
+
)
|
207 |
+
)
|
208 |
+
task.use(gae_estimator(self.cfg, self.policy.collect_mode))
|
209 |
+
task.use(trainer(self.cfg, self.policy.learn_mode))
|
210 |
+
task.use(
|
211 |
+
wandb_online_logger(
|
212 |
+
metric_list=self.policy._monitor_vars_learn(),
|
213 |
+
model=self.policy._model,
|
214 |
+
anonymous=True,
|
215 |
+
project_name=self.exp_name,
|
216 |
+
wandb_sweep=wandb_sweep,
|
217 |
+
)
|
218 |
+
)
|
219 |
+
task.use(termination_checker(max_env_step=step))
|
220 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
221 |
+
task.run()
|
222 |
+
|
223 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
224 |
+
|
225 |
+
def deploy(
|
226 |
+
self,
|
227 |
+
enable_save_replay: bool = False,
|
228 |
+
concatenate_all_replay: bool = False,
|
229 |
+
replay_save_path: str = None,
|
230 |
+
seed: Optional[Union[int, List]] = None,
|
231 |
+
debug: bool = False
|
232 |
+
) -> EvalReturn:
|
233 |
+
"""
|
234 |
+
Overview:
|
235 |
+
Deploy the agent with A2C algorithm by interacting with the environment, during which the replay video \
|
236 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
237 |
+
Arguments:
|
238 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
239 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
240 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
241 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
242 |
+
the replay video of each episode will be saved separately.
|
243 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
244 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
245 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
246 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
247 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
248 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
249 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
250 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
251 |
+
subprocess environment manager will be used.
|
252 |
+
Returns:
|
253 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
254 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
255 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
256 |
+
"""
|
257 |
+
|
258 |
+
if debug:
|
259 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
260 |
+
# define env and policy
|
261 |
+
env = self.env.clone(caller='evaluator')
|
262 |
+
|
263 |
+
if seed is not None and isinstance(seed, int):
|
264 |
+
seeds = [seed]
|
265 |
+
elif seed is not None and isinstance(seed, list):
|
266 |
+
seeds = seed
|
267 |
+
else:
|
268 |
+
seeds = [self.seed]
|
269 |
+
|
270 |
+
returns = []
|
271 |
+
images = []
|
272 |
+
if enable_save_replay:
|
273 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
274 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
275 |
+
else:
|
276 |
+
logging.warning('No video would be generated during the deploy.')
|
277 |
+
if concatenate_all_replay:
|
278 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
279 |
+
concatenate_all_replay = False
|
280 |
+
|
281 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
282 |
+
|
283 |
+
if self.cfg.policy.action_space == 'continuous':
|
284 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
|
285 |
+
elif self.cfg.policy.action_space == 'discrete':
|
286 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
|
287 |
+
else:
|
288 |
+
raise NotImplementedError
|
289 |
+
|
290 |
+
def _forward(obs):
|
291 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
292 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
293 |
+
if cuda and torch.cuda.is_available():
|
294 |
+
obs = obs.cuda()
|
295 |
+
action = forward_fn(obs, mode='compute_actor')["action"]
|
296 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
297 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
298 |
+
return action
|
299 |
+
|
300 |
+
return _forward
|
301 |
+
|
302 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
303 |
+
|
304 |
+
# reset first to make sure the env is in the initial state
|
305 |
+
# env will be reset again in the main loop
|
306 |
+
env.reset()
|
307 |
+
|
308 |
+
for seed in seeds:
|
309 |
+
env.seed(seed, dynamic_seed=False)
|
310 |
+
return_ = 0.
|
311 |
+
step = 0
|
312 |
+
obs = env.reset()
|
313 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
314 |
+
while True:
|
315 |
+
action = forward_fn(obs)
|
316 |
+
obs, rew, done, info = env.step(action)
|
317 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
318 |
+
return_ += rew
|
319 |
+
step += 1
|
320 |
+
if done:
|
321 |
+
break
|
322 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
323 |
+
returns.append(return_)
|
324 |
+
|
325 |
+
env.close()
|
326 |
+
|
327 |
+
if concatenate_all_replay:
|
328 |
+
images = np.concatenate(images, axis=0)
|
329 |
+
import imageio
|
330 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
331 |
+
|
332 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
333 |
+
|
334 |
+
def collect_data(
|
335 |
+
self,
|
336 |
+
env_num: int = 8,
|
337 |
+
save_data_path: Optional[str] = None,
|
338 |
+
n_sample: Optional[int] = None,
|
339 |
+
n_episode: Optional[int] = None,
|
340 |
+
context: Optional[str] = None,
|
341 |
+
debug: bool = False
|
342 |
+
) -> None:
|
343 |
+
"""
|
344 |
+
Overview:
|
345 |
+
Collect data with A2C algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
346 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
347 |
+
``exp_name/demo_data``.
|
348 |
+
Arguments:
|
349 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
350 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
351 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
352 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
353 |
+
If not specified, ``n_episode`` must be specified.
|
354 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
355 |
+
If not specified, ``n_sample`` must be specified.
|
356 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
357 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
358 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
359 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
360 |
+
subprocess environment manager will be used.
|
361 |
+
"""
|
362 |
+
|
363 |
+
if debug:
|
364 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
365 |
+
if n_episode is not None:
|
366 |
+
raise NotImplementedError
|
367 |
+
# define env and policy
|
368 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
369 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
370 |
+
|
371 |
+
if save_data_path is None:
|
372 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
373 |
+
|
374 |
+
# main execution task
|
375 |
+
with task.start(ctx=OnlineRLContext()):
|
376 |
+
task.use(
|
377 |
+
StepCollector(
|
378 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
379 |
+
)
|
380 |
+
)
|
381 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
382 |
+
task.run(max_step=1)
|
383 |
+
logging.info(
|
384 |
+
f'A2C collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
385 |
+
)
|
386 |
+
|
387 |
+
def batch_evaluate(
|
388 |
+
self,
|
389 |
+
env_num: int = 4,
|
390 |
+
n_evaluator_episode: int = 4,
|
391 |
+
context: Optional[str] = None,
|
392 |
+
debug: bool = False
|
393 |
+
) -> EvalReturn:
|
394 |
+
"""
|
395 |
+
Overview:
|
396 |
+
Evaluate the agent with A2C algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
397 |
+
environments. The evaluation result will be returned.
|
398 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
399 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
400 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
401 |
+
Arguments:
|
402 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
403 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
404 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
405 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
406 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
407 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
408 |
+
subprocess environment manager will be used.
|
409 |
+
Returns:
|
410 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
411 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
412 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
413 |
+
"""
|
414 |
+
|
415 |
+
if debug:
|
416 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
417 |
+
# define env and policy
|
418 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
419 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
420 |
+
|
421 |
+
# reset first to make sure the env is in the initial state
|
422 |
+
# env will be reset again in the main loop
|
423 |
+
env.launch()
|
424 |
+
env.reset()
|
425 |
+
|
426 |
+
evaluate_cfg = self.cfg
|
427 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
428 |
+
|
429 |
+
# main execution task
|
430 |
+
with task.start(ctx=OnlineRLContext()):
|
431 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
432 |
+
task.run(max_step=1)
|
433 |
+
|
434 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
435 |
+
|
436 |
+
@property
|
437 |
+
def best(self) -> 'A2CAgent':
|
438 |
+
"""
|
439 |
+
Overview:
|
440 |
+
Load the best model from the checkpoint directory, \
|
441 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
442 |
+
The return value is the agent with the best model.
|
443 |
+
Returns:
|
444 |
+
- (:obj:`A2CAgent`): The agent with the best model.
|
445 |
+
Examples:
|
446 |
+
>>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
|
447 |
+
>>> agent.train()
|
448 |
+
>>> agent = agent.best
|
449 |
+
|
450 |
+
.. note::
|
451 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
452 |
+
model will be replaced by the best model.
|
453 |
+
"""
|
454 |
+
|
455 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
456 |
+
# Load best model if it exists
|
457 |
+
if os.path.exists(best_model_file_path):
|
458 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
459 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
460 |
+
return self
|
DI-engine/ding/bonus/c51.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver, eps_greedy_handler, nstep_reward_enhancer
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import C51Policy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import C51DQN
|
19 |
+
from ding.model import model_wrap
|
20 |
+
from ding.data import DequeBuffer
|
21 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
22 |
+
from ding.config.example.C51 import supported_env_cfg
|
23 |
+
from ding.config.example.C51 import supported_env
|
24 |
+
|
25 |
+
|
26 |
+
class C51Agent:
|
27 |
+
"""
|
28 |
+
Overview:
|
29 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm C51.
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.c51 import C51Agent
|
41 |
+
>>> print(C51Agent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for C51 algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of C51 algorithm, which should be an instance of class \
|
71 |
+
:class:`ding.model.C51DQN`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of C51 algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/C51/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with C51 algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = C51Agent(env_id='LunarLander-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = C51Agent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLander-v2')
|
91 |
+
>>> agent = C51Agent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = C51DQN(**cfg.policy.model)
|
94 |
+
>>> agent = C51Agent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = C51Agent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
C51Agent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
C51Agent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": C51Policy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=C51Policy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = C51DQN(**self.cfg.policy.model)
|
140 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
141 |
+
self.policy = C51Policy(self.cfg.policy, model=model)
|
142 |
+
if policy_state_dict is not None:
|
143 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
144 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
145 |
+
|
146 |
+
def train(
|
147 |
+
self,
|
148 |
+
step: int = int(1e7),
|
149 |
+
collector_env_num: int = None,
|
150 |
+
evaluator_env_num: int = None,
|
151 |
+
n_iter_save_ckpt: int = 1000,
|
152 |
+
context: Optional[str] = None,
|
153 |
+
debug: bool = False,
|
154 |
+
wandb_sweep: bool = False,
|
155 |
+
) -> TrainingReturn:
|
156 |
+
"""
|
157 |
+
Overview:
|
158 |
+
Train the agent with C51 algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
159 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
160 |
+
recorded and saved by wandb.
|
161 |
+
Arguments:
|
162 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
163 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
164 |
+
If not specified, it will be set according to the configuration.
|
165 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
166 |
+
If not specified, it will be set according to the configuration.
|
167 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
168 |
+
Default to 1000.
|
169 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
170 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
171 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
172 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
173 |
+
subprocess environment manager will be used.
|
174 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
175 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
176 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
177 |
+
Returns:
|
178 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
179 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
180 |
+
"""
|
181 |
+
|
182 |
+
if debug:
|
183 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
184 |
+
logging.debug(self.policy._model)
|
185 |
+
# define env and policy
|
186 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
187 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
188 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
189 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
190 |
+
|
191 |
+
with task.start(ctx=OnlineRLContext()):
|
192 |
+
task.use(
|
193 |
+
interaction_evaluator(
|
194 |
+
self.cfg,
|
195 |
+
self.policy.eval_mode,
|
196 |
+
evaluator_env,
|
197 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
198 |
+
)
|
199 |
+
)
|
200 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
201 |
+
task.use(eps_greedy_handler(self.cfg))
|
202 |
+
task.use(
|
203 |
+
StepCollector(
|
204 |
+
self.cfg,
|
205 |
+
self.policy.collect_mode,
|
206 |
+
collector_env,
|
207 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
208 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
209 |
+
)
|
210 |
+
)
|
211 |
+
task.use(nstep_reward_enhancer(self.cfg))
|
212 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
213 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
214 |
+
task.use(
|
215 |
+
wandb_online_logger(
|
216 |
+
metric_list=self.policy._monitor_vars_learn(),
|
217 |
+
model=self.policy._model,
|
218 |
+
anonymous=True,
|
219 |
+
project_name=self.exp_name,
|
220 |
+
wandb_sweep=wandb_sweep,
|
221 |
+
)
|
222 |
+
)
|
223 |
+
task.use(termination_checker(max_env_step=step))
|
224 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
225 |
+
task.run()
|
226 |
+
|
227 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
228 |
+
|
229 |
+
def deploy(
|
230 |
+
self,
|
231 |
+
enable_save_replay: bool = False,
|
232 |
+
concatenate_all_replay: bool = False,
|
233 |
+
replay_save_path: str = None,
|
234 |
+
seed: Optional[Union[int, List]] = None,
|
235 |
+
debug: bool = False
|
236 |
+
) -> EvalReturn:
|
237 |
+
"""
|
238 |
+
Overview:
|
239 |
+
Deploy the agent with C51 algorithm by interacting with the environment, during which the replay video \
|
240 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
241 |
+
Arguments:
|
242 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
243 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
244 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
245 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
246 |
+
the replay video of each episode will be saved separately.
|
247 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
248 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
249 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
250 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
251 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
252 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
253 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
254 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
255 |
+
subprocess environment manager will be used.
|
256 |
+
Returns:
|
257 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
258 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
259 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
260 |
+
"""
|
261 |
+
|
262 |
+
if debug:
|
263 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
264 |
+
# define env and policy
|
265 |
+
env = self.env.clone(caller='evaluator')
|
266 |
+
|
267 |
+
if seed is not None and isinstance(seed, int):
|
268 |
+
seeds = [seed]
|
269 |
+
elif seed is not None and isinstance(seed, list):
|
270 |
+
seeds = seed
|
271 |
+
else:
|
272 |
+
seeds = [self.seed]
|
273 |
+
|
274 |
+
returns = []
|
275 |
+
images = []
|
276 |
+
if enable_save_replay:
|
277 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
278 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
279 |
+
else:
|
280 |
+
logging.warning('No video would be generated during the deploy.')
|
281 |
+
if concatenate_all_replay:
|
282 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
283 |
+
concatenate_all_replay = False
|
284 |
+
|
285 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
286 |
+
|
287 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
|
288 |
+
|
289 |
+
def _forward(obs):
|
290 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
291 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
292 |
+
if cuda and torch.cuda.is_available():
|
293 |
+
obs = obs.cuda()
|
294 |
+
action = forward_fn(obs)["action"]
|
295 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
296 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
297 |
+
return action
|
298 |
+
|
299 |
+
return _forward
|
300 |
+
|
301 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
302 |
+
|
303 |
+
# reset first to make sure the env is in the initial state
|
304 |
+
# env will be reset again in the main loop
|
305 |
+
env.reset()
|
306 |
+
|
307 |
+
for seed in seeds:
|
308 |
+
env.seed(seed, dynamic_seed=False)
|
309 |
+
return_ = 0.
|
310 |
+
step = 0
|
311 |
+
obs = env.reset()
|
312 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
313 |
+
while True:
|
314 |
+
action = forward_fn(obs)
|
315 |
+
obs, rew, done, info = env.step(action)
|
316 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
317 |
+
return_ += rew
|
318 |
+
step += 1
|
319 |
+
if done:
|
320 |
+
break
|
321 |
+
logging.info(f'C51 deploy is finished, final episode return with {step} steps is: {return_}')
|
322 |
+
returns.append(return_)
|
323 |
+
|
324 |
+
env.close()
|
325 |
+
|
326 |
+
if concatenate_all_replay:
|
327 |
+
images = np.concatenate(images, axis=0)
|
328 |
+
import imageio
|
329 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
330 |
+
|
331 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
332 |
+
|
333 |
+
def collect_data(
|
334 |
+
self,
|
335 |
+
env_num: int = 8,
|
336 |
+
save_data_path: Optional[str] = None,
|
337 |
+
n_sample: Optional[int] = None,
|
338 |
+
n_episode: Optional[int] = None,
|
339 |
+
context: Optional[str] = None,
|
340 |
+
debug: bool = False
|
341 |
+
) -> None:
|
342 |
+
"""
|
343 |
+
Overview:
|
344 |
+
Collect data with C51 algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
345 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
346 |
+
``exp_name/demo_data``.
|
347 |
+
Arguments:
|
348 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
349 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
350 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
351 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
352 |
+
If not specified, ``n_episode`` must be specified.
|
353 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
354 |
+
If not specified, ``n_sample`` must be specified.
|
355 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
356 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
357 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
358 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
359 |
+
subprocess environment manager will be used.
|
360 |
+
"""
|
361 |
+
|
362 |
+
if debug:
|
363 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
364 |
+
if n_episode is not None:
|
365 |
+
raise NotImplementedError
|
366 |
+
# define env and policy
|
367 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
368 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
369 |
+
|
370 |
+
if save_data_path is None:
|
371 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
372 |
+
|
373 |
+
# main execution task
|
374 |
+
with task.start(ctx=OnlineRLContext()):
|
375 |
+
task.use(
|
376 |
+
StepCollector(
|
377 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
378 |
+
)
|
379 |
+
)
|
380 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
381 |
+
task.run(max_step=1)
|
382 |
+
logging.info(
|
383 |
+
f'C51 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
384 |
+
)
|
385 |
+
|
386 |
+
def batch_evaluate(
|
387 |
+
self,
|
388 |
+
env_num: int = 4,
|
389 |
+
n_evaluator_episode: int = 4,
|
390 |
+
context: Optional[str] = None,
|
391 |
+
debug: bool = False
|
392 |
+
) -> EvalReturn:
|
393 |
+
"""
|
394 |
+
Overview:
|
395 |
+
Evaluate the agent with C51 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
396 |
+
environments. The evaluation result will be returned.
|
397 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
398 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
399 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
400 |
+
Arguments:
|
401 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
402 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
403 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
404 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
405 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
406 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
407 |
+
subprocess environment manager will be used.
|
408 |
+
Returns:
|
409 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
410 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
411 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
412 |
+
"""
|
413 |
+
|
414 |
+
if debug:
|
415 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
416 |
+
# define env and policy
|
417 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
418 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
419 |
+
|
420 |
+
# reset first to make sure the env is in the initial state
|
421 |
+
# env will be reset again in the main loop
|
422 |
+
env.launch()
|
423 |
+
env.reset()
|
424 |
+
|
425 |
+
evaluate_cfg = self.cfg
|
426 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
427 |
+
|
428 |
+
# main execution task
|
429 |
+
with task.start(ctx=OnlineRLContext()):
|
430 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
431 |
+
task.run(max_step=1)
|
432 |
+
|
433 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
434 |
+
|
435 |
+
@property
|
436 |
+
def best(self) -> 'C51Agent':
|
437 |
+
"""
|
438 |
+
Overview:
|
439 |
+
Load the best model from the checkpoint directory, \
|
440 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
441 |
+
The return value is the agent with the best model.
|
442 |
+
Returns:
|
443 |
+
- (:obj:`C51Agent`): The agent with the best model.
|
444 |
+
Examples:
|
445 |
+
>>> agent = C51Agent(env_id='LunarLander-v2')
|
446 |
+
>>> agent.train()
|
447 |
+
>>> agent = agent.best
|
448 |
+
|
449 |
+
.. note::
|
450 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
451 |
+
model will be replaced by the best model.
|
452 |
+
"""
|
453 |
+
|
454 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
455 |
+
# Load best model if it exists
|
456 |
+
if os.path.exists(best_model_file_path):
|
457 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
458 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
459 |
+
return self
|
DI-engine/ding/bonus/common.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class TrainingReturn:
|
7 |
+
'''
|
8 |
+
Attributions
|
9 |
+
wandb_url: The weight & biases (wandb) project url of the trainning experiment.
|
10 |
+
'''
|
11 |
+
wandb_url: str
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class EvalReturn:
|
16 |
+
'''
|
17 |
+
Attributions
|
18 |
+
eval_value: The mean of evaluation return.
|
19 |
+
eval_value_std: The standard deviation of evaluation return.
|
20 |
+
'''
|
21 |
+
eval_value: np.float32
|
22 |
+
eval_value_std: np.float32
|
DI-engine/ding/bonus/config.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import os
|
3 |
+
import gym
|
4 |
+
from ding.envs import BaseEnv, DingEnvWrapper
|
5 |
+
from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
|
6 |
+
EvalEpisodeReturnWrapper, TransposeWrapper, TimeLimitWrapper, FlatObsWrapper, GymToGymnasiumWrapper
|
7 |
+
from ding.policy import PPOFPolicy
|
8 |
+
|
9 |
+
|
10 |
+
def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
|
11 |
+
if algorithm == 'PPOF':
|
12 |
+
cfg = PPOFPolicy.default_config()
|
13 |
+
if env_id == 'LunarLander-v2':
|
14 |
+
cfg.n_sample = 512
|
15 |
+
cfg.value_norm = 'popart'
|
16 |
+
cfg.entropy_weight = 1e-3
|
17 |
+
elif env_id == 'LunarLanderContinuous-v2':
|
18 |
+
cfg.action_space = 'continuous'
|
19 |
+
cfg.n_sample = 400
|
20 |
+
elif env_id == 'BipedalWalker-v3':
|
21 |
+
cfg.learning_rate = 1e-3
|
22 |
+
cfg.action_space = 'continuous'
|
23 |
+
cfg.n_sample = 1024
|
24 |
+
elif env_id == 'Pendulum-v1':
|
25 |
+
cfg.action_space = 'continuous'
|
26 |
+
cfg.n_sample = 400
|
27 |
+
elif env_id == 'acrobot':
|
28 |
+
cfg.learning_rate = 1e-4
|
29 |
+
cfg.n_sample = 400
|
30 |
+
elif env_id == 'rocket_landing':
|
31 |
+
cfg.n_sample = 2048
|
32 |
+
cfg.adv_norm = False
|
33 |
+
cfg.model = dict(
|
34 |
+
encoder_hidden_size_list=[64, 64, 128],
|
35 |
+
actor_head_hidden_size=128,
|
36 |
+
critic_head_hidden_size=128,
|
37 |
+
)
|
38 |
+
elif env_id == 'drone_fly':
|
39 |
+
cfg.action_space = 'continuous'
|
40 |
+
cfg.adv_norm = False
|
41 |
+
cfg.epoch_per_collect = 5
|
42 |
+
cfg.learning_rate = 5e-5
|
43 |
+
cfg.n_sample = 640
|
44 |
+
elif env_id == 'hybrid_moving':
|
45 |
+
cfg.action_space = 'hybrid'
|
46 |
+
cfg.n_sample = 3200
|
47 |
+
cfg.entropy_weight = 0.03
|
48 |
+
cfg.batch_size = 320
|
49 |
+
cfg.adv_norm = False
|
50 |
+
cfg.model = dict(
|
51 |
+
encoder_hidden_size_list=[256, 128, 64, 64],
|
52 |
+
sigma_type='fixed',
|
53 |
+
fixed_sigma_value=0.3,
|
54 |
+
bound_type='tanh',
|
55 |
+
)
|
56 |
+
elif env_id == 'evogym_carrier':
|
57 |
+
cfg.action_space = 'continuous'
|
58 |
+
cfg.n_sample = 2048
|
59 |
+
cfg.batch_size = 256
|
60 |
+
cfg.epoch_per_collect = 10
|
61 |
+
cfg.learning_rate = 3e-3
|
62 |
+
elif env_id == 'mario':
|
63 |
+
cfg.n_sample = 256
|
64 |
+
cfg.batch_size = 64
|
65 |
+
cfg.epoch_per_collect = 2
|
66 |
+
cfg.learning_rate = 1e-3
|
67 |
+
cfg.model = dict(
|
68 |
+
encoder_hidden_size_list=[64, 64, 128],
|
69 |
+
critic_head_hidden_size=128,
|
70 |
+
actor_head_hidden_size=128,
|
71 |
+
)
|
72 |
+
elif env_id == 'di_sheep':
|
73 |
+
cfg.n_sample = 3200
|
74 |
+
cfg.batch_size = 320
|
75 |
+
cfg.epoch_per_collect = 10
|
76 |
+
cfg.learning_rate = 3e-4
|
77 |
+
cfg.adv_norm = False
|
78 |
+
cfg.entropy_weight = 0.001
|
79 |
+
elif env_id == 'procgen_bigfish':
|
80 |
+
cfg.n_sample = 16384
|
81 |
+
cfg.batch_size = 16384
|
82 |
+
cfg.epoch_per_collect = 10
|
83 |
+
cfg.learning_rate = 5e-4
|
84 |
+
cfg.model = dict(
|
85 |
+
encoder_hidden_size_list=[64, 128, 256],
|
86 |
+
critic_head_hidden_size=256,
|
87 |
+
actor_head_hidden_size=256,
|
88 |
+
)
|
89 |
+
elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
|
90 |
+
cfg.n_sample = 1024
|
91 |
+
cfg.batch_size = 128
|
92 |
+
cfg.epoch_per_collect = 10
|
93 |
+
cfg.learning_rate = 0.0001
|
94 |
+
cfg.model = dict(
|
95 |
+
encoder_hidden_size_list=[32, 64, 64, 128],
|
96 |
+
actor_head_hidden_size=128,
|
97 |
+
critic_head_hidden_size=128,
|
98 |
+
critic_head_layer_num=2,
|
99 |
+
)
|
100 |
+
elif env_id == 'PongNoFrameskip-v4':
|
101 |
+
cfg.n_sample = 3200
|
102 |
+
cfg.batch_size = 320
|
103 |
+
cfg.epoch_per_collect = 10
|
104 |
+
cfg.learning_rate = 3e-4
|
105 |
+
cfg.model = dict(
|
106 |
+
encoder_hidden_size_list=[64, 64, 128],
|
107 |
+
actor_head_hidden_size=128,
|
108 |
+
critic_head_hidden_size=128,
|
109 |
+
)
|
110 |
+
elif env_id == 'SpaceInvadersNoFrameskip-v4':
|
111 |
+
cfg.n_sample = 320
|
112 |
+
cfg.batch_size = 320
|
113 |
+
cfg.epoch_per_collect = 1
|
114 |
+
cfg.learning_rate = 1e-3
|
115 |
+
cfg.entropy_weight = 0.01
|
116 |
+
cfg.lr_scheduler = (2000, 0.1)
|
117 |
+
cfg.model = dict(
|
118 |
+
encoder_hidden_size_list=[64, 64, 128],
|
119 |
+
actor_head_hidden_size=128,
|
120 |
+
critic_head_hidden_size=128,
|
121 |
+
)
|
122 |
+
elif env_id == 'QbertNoFrameskip-v4':
|
123 |
+
cfg.n_sample = 3200
|
124 |
+
cfg.batch_size = 320
|
125 |
+
cfg.epoch_per_collect = 10
|
126 |
+
cfg.learning_rate = 5e-4
|
127 |
+
cfg.lr_scheduler = (1000, 0.1)
|
128 |
+
cfg.model = dict(
|
129 |
+
encoder_hidden_size_list=[64, 64, 128],
|
130 |
+
actor_head_hidden_size=128,
|
131 |
+
critic_head_hidden_size=128,
|
132 |
+
)
|
133 |
+
elif env_id == 'minigrid_fourroom':
|
134 |
+
cfg.n_sample = 3200
|
135 |
+
cfg.batch_size = 320
|
136 |
+
cfg.learning_rate = 3e-4
|
137 |
+
cfg.epoch_per_collect = 10
|
138 |
+
cfg.entropy_weight = 0.001
|
139 |
+
elif env_id == 'metadrive':
|
140 |
+
cfg.learning_rate = 3e-4
|
141 |
+
cfg.action_space = 'continuous'
|
142 |
+
cfg.entropy_weight = 0.001
|
143 |
+
cfg.n_sample = 3000
|
144 |
+
cfg.epoch_per_collect = 10
|
145 |
+
cfg.learning_rate = 0.0001
|
146 |
+
cfg.model = dict(
|
147 |
+
encoder_hidden_size_list=[32, 64, 64, 128],
|
148 |
+
actor_head_hidden_size=128,
|
149 |
+
critic_head_hidden_size=128,
|
150 |
+
critic_head_layer_num=2,
|
151 |
+
)
|
152 |
+
elif env_id == 'Hopper-v3':
|
153 |
+
cfg.action_space = "continuous"
|
154 |
+
cfg.n_sample = 3200
|
155 |
+
cfg.batch_size = 320
|
156 |
+
cfg.epoch_per_collect = 10
|
157 |
+
cfg.learning_rate = 3e-4
|
158 |
+
elif env_id == 'HalfCheetah-v3':
|
159 |
+
cfg.action_space = "continuous"
|
160 |
+
cfg.n_sample = 3200
|
161 |
+
cfg.batch_size = 320
|
162 |
+
cfg.epoch_per_collect = 10
|
163 |
+
cfg.learning_rate = 3e-4
|
164 |
+
elif env_id == 'Walker2d-v3':
|
165 |
+
cfg.action_space = "continuous"
|
166 |
+
cfg.n_sample = 3200
|
167 |
+
cfg.batch_size = 320
|
168 |
+
cfg.epoch_per_collect = 10
|
169 |
+
cfg.learning_rate = 3e-4
|
170 |
+
else:
|
171 |
+
raise KeyError("not supported env type: {}".format(env_id))
|
172 |
+
else:
|
173 |
+
raise KeyError("not supported algorithm type: {}".format(algorithm))
|
174 |
+
|
175 |
+
return cfg
|
176 |
+
|
177 |
+
|
178 |
+
def get_instance_env(env_id: str) -> BaseEnv:
|
179 |
+
if env_id == 'LunarLander-v2':
|
180 |
+
return DingEnvWrapper(gym.make('LunarLander-v2'))
|
181 |
+
elif env_id == 'LunarLanderContinuous-v2':
|
182 |
+
return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
|
183 |
+
elif env_id == 'BipedalWalker-v3':
|
184 |
+
return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
|
185 |
+
elif env_id == 'Pendulum-v1':
|
186 |
+
return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
|
187 |
+
elif env_id == 'acrobot':
|
188 |
+
return DingEnvWrapper(gym.make('Acrobot-v1'))
|
189 |
+
elif env_id == 'rocket_landing':
|
190 |
+
from dizoo.rocket.envs import RocketEnv
|
191 |
+
cfg = EasyDict({
|
192 |
+
'task': 'landing',
|
193 |
+
'max_steps': 800,
|
194 |
+
})
|
195 |
+
return RocketEnv(cfg)
|
196 |
+
elif env_id == 'drone_fly':
|
197 |
+
from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
|
198 |
+
cfg = EasyDict({
|
199 |
+
'env_id': 'flythrugate-aviary-v0',
|
200 |
+
'action_type': 'VEL',
|
201 |
+
})
|
202 |
+
return GymPybulletDronesEnv(cfg)
|
203 |
+
elif env_id == 'hybrid_moving':
|
204 |
+
import gym_hybrid
|
205 |
+
return DingEnvWrapper(gym.make('Moving-v0'))
|
206 |
+
elif env_id == 'evogym_carrier':
|
207 |
+
import evogym.envs
|
208 |
+
from evogym import sample_robot, WorldObject
|
209 |
+
path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
|
210 |
+
robot_object = WorldObject.from_json(path)
|
211 |
+
body = robot_object.get_structure()
|
212 |
+
return DingEnvWrapper(
|
213 |
+
gym.make('Carrier-v0', body=body),
|
214 |
+
cfg={
|
215 |
+
'env_wrapper': [
|
216 |
+
lambda env: TimeLimitWrapper(env, max_limit=300),
|
217 |
+
lambda env: EvalEpisodeReturnWrapper(env),
|
218 |
+
]
|
219 |
+
}
|
220 |
+
)
|
221 |
+
elif env_id == 'mario':
|
222 |
+
import gym_super_mario_bros
|
223 |
+
from nes_py.wrappers import JoypadSpace
|
224 |
+
return DingEnvWrapper(
|
225 |
+
JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v1"), [["right"], ["right", "A"]]),
|
226 |
+
cfg={
|
227 |
+
'env_wrapper': [
|
228 |
+
lambda env: MaxAndSkipWrapper(env, skip=4),
|
229 |
+
lambda env: WarpFrameWrapper(env, size=84),
|
230 |
+
lambda env: ScaledFloatFrameWrapper(env),
|
231 |
+
lambda env: FrameStackWrapper(env, n_frames=4),
|
232 |
+
lambda env: TimeLimitWrapper(env, max_limit=200),
|
233 |
+
lambda env: EvalEpisodeReturnWrapper(env),
|
234 |
+
]
|
235 |
+
}
|
236 |
+
)
|
237 |
+
elif env_id == 'di_sheep':
|
238 |
+
from sheep_env import SheepEnv
|
239 |
+
return DingEnvWrapper(SheepEnv(level=9))
|
240 |
+
elif env_id == 'procgen_bigfish':
|
241 |
+
return DingEnvWrapper(
|
242 |
+
gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
|
243 |
+
cfg={
|
244 |
+
'env_wrapper': [
|
245 |
+
lambda env: TransposeWrapper(env),
|
246 |
+
lambda env: ScaledFloatFrameWrapper(env),
|
247 |
+
lambda env: EvalEpisodeReturnWrapper(env),
|
248 |
+
]
|
249 |
+
},
|
250 |
+
seed_api=False,
|
251 |
+
)
|
252 |
+
elif env_id == 'Hopper-v3':
|
253 |
+
cfg = EasyDict(
|
254 |
+
env_id='Hopper-v3',
|
255 |
+
env_wrapper='mujoco_default',
|
256 |
+
act_scale=True,
|
257 |
+
rew_clip=True,
|
258 |
+
)
|
259 |
+
return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
|
260 |
+
elif env_id == 'HalfCheetah-v3':
|
261 |
+
cfg = EasyDict(
|
262 |
+
env_id='HalfCheetah-v3',
|
263 |
+
env_wrapper='mujoco_default',
|
264 |
+
act_scale=True,
|
265 |
+
rew_clip=True,
|
266 |
+
)
|
267 |
+
return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
|
268 |
+
elif env_id == 'Walker2d-v3':
|
269 |
+
cfg = EasyDict(
|
270 |
+
env_id='Walker2d-v3',
|
271 |
+
env_wrapper='mujoco_default',
|
272 |
+
act_scale=True,
|
273 |
+
rew_clip=True,
|
274 |
+
)
|
275 |
+
return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
|
276 |
+
|
277 |
+
elif env_id in [
|
278 |
+
'BowlingNoFrameskip-v4',
|
279 |
+
'BreakoutNoFrameskip-v4',
|
280 |
+
'GopherNoFrameskip-v4'
|
281 |
+
'KangarooNoFrameskip-v4',
|
282 |
+
'PongNoFrameskip-v4',
|
283 |
+
'QbertNoFrameskip-v4',
|
284 |
+
'SpaceInvadersNoFrameskip-v4',
|
285 |
+
]:
|
286 |
+
|
287 |
+
cfg = EasyDict({
|
288 |
+
'env_id': env_id,
|
289 |
+
'env_wrapper': 'atari_default',
|
290 |
+
})
|
291 |
+
ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
|
292 |
+
return ding_env_atari
|
293 |
+
elif env_id == 'minigrid_fourroom':
|
294 |
+
import gymnasium
|
295 |
+
return DingEnvWrapper(
|
296 |
+
gymnasium.make('MiniGrid-FourRooms-v0'),
|
297 |
+
cfg={
|
298 |
+
'env_wrapper': [
|
299 |
+
lambda env: GymToGymnasiumWrapper(env),
|
300 |
+
lambda env: FlatObsWrapper(env),
|
301 |
+
lambda env: TimeLimitWrapper(env, max_limit=300),
|
302 |
+
lambda env: EvalEpisodeReturnWrapper(env),
|
303 |
+
]
|
304 |
+
}
|
305 |
+
)
|
306 |
+
elif env_id == 'metadrive':
|
307 |
+
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
|
308 |
+
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
|
309 |
+
cfg = dict(
|
310 |
+
map='XSOS',
|
311 |
+
horizon=4000,
|
312 |
+
out_of_road_penalty=40.0,
|
313 |
+
crash_vehicle_penalty=40.0,
|
314 |
+
out_of_route_done=True,
|
315 |
+
)
|
316 |
+
cfg = EasyDict(cfg)
|
317 |
+
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
|
318 |
+
else:
|
319 |
+
raise KeyError("not supported env type: {}".format(env_id))
|
320 |
+
|
321 |
+
|
322 |
+
def get_hybrid_shape(action_space) -> EasyDict:
|
323 |
+
return EasyDict({
|
324 |
+
'action_type_shape': action_space[0].n,
|
325 |
+
'action_args_shape': action_space[1].shape,
|
326 |
+
})
|
DI-engine/ding/bonus/ddpg.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import DDPGPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import ContinuousQAC
|
19 |
+
from ding.data import DequeBuffer
|
20 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
21 |
+
from ding.config.example.DDPG import supported_env_cfg
|
22 |
+
from ding.config.example.DDPG import supported_env
|
23 |
+
|
24 |
+
|
25 |
+
class DDPGAgent:
|
26 |
+
"""
|
27 |
+
Overview:
|
28 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
29 |
+
Deep Deterministic Policy Gradient(DDPG).
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.ddpg import DDPGAgent
|
41 |
+
>>> print(DDPGAgent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for DDPG algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of DDPG algorithm, which should be an instance of class \
|
71 |
+
:class:`ding.model.ContinuousQAC`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of DDPG algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/DDPG/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with DDPG algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = DDPGAgent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
|
91 |
+
>>> agent = DDPGAgent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = ContinuousQAC(**cfg.policy.model)
|
94 |
+
>>> agent = DDPGAgent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = DDPGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
DDPGAgent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
DDPGAgent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": DDPGPolicy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=DDPGPolicy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = ContinuousQAC(**self.cfg.policy.model)
|
140 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
141 |
+
self.policy = DDPGPolicy(self.cfg.policy, model=model)
|
142 |
+
if policy_state_dict is not None:
|
143 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
144 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
145 |
+
|
146 |
+
def train(
|
147 |
+
self,
|
148 |
+
step: int = int(1e7),
|
149 |
+
collector_env_num: int = None,
|
150 |
+
evaluator_env_num: int = None,
|
151 |
+
n_iter_log_show: int = 500,
|
152 |
+
n_iter_save_ckpt: int = 1000,
|
153 |
+
context: Optional[str] = None,
|
154 |
+
debug: bool = False,
|
155 |
+
wandb_sweep: bool = False,
|
156 |
+
) -> TrainingReturn:
|
157 |
+
"""
|
158 |
+
Overview:
|
159 |
+
Train the agent with DDPG algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
160 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
161 |
+
recorded and saved by wandb.
|
162 |
+
Arguments:
|
163 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
164 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
165 |
+
If not specified, it will be set according to the configuration.
|
166 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
167 |
+
If not specified, it will be set according to the configuration.
|
168 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
169 |
+
Default to 1000.
|
170 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
171 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
172 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
173 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
174 |
+
subprocess environment manager will be used.
|
175 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
176 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
177 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
178 |
+
Returns:
|
179 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
180 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
181 |
+
"""
|
182 |
+
|
183 |
+
if debug:
|
184 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
185 |
+
logging.debug(self.policy._model)
|
186 |
+
# define env and policy
|
187 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
188 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
189 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
190 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
191 |
+
|
192 |
+
with task.start(ctx=OnlineRLContext()):
|
193 |
+
task.use(
|
194 |
+
interaction_evaluator(
|
195 |
+
self.cfg,
|
196 |
+
self.policy.eval_mode,
|
197 |
+
evaluator_env,
|
198 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
199 |
+
)
|
200 |
+
)
|
201 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
202 |
+
task.use(
|
203 |
+
StepCollector(
|
204 |
+
self.cfg,
|
205 |
+
self.policy.collect_mode,
|
206 |
+
collector_env,
|
207 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
208 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
209 |
+
)
|
210 |
+
)
|
211 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
212 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
213 |
+
task.use(
|
214 |
+
wandb_online_logger(
|
215 |
+
metric_list=self.policy._monitor_vars_learn(),
|
216 |
+
model=self.policy._model,
|
217 |
+
anonymous=True,
|
218 |
+
project_name=self.exp_name,
|
219 |
+
wandb_sweep=wandb_sweep,
|
220 |
+
)
|
221 |
+
)
|
222 |
+
task.use(termination_checker(max_env_step=step))
|
223 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
224 |
+
task.run()
|
225 |
+
|
226 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
227 |
+
|
228 |
+
def deploy(
|
229 |
+
self,
|
230 |
+
enable_save_replay: bool = False,
|
231 |
+
concatenate_all_replay: bool = False,
|
232 |
+
replay_save_path: str = None,
|
233 |
+
seed: Optional[Union[int, List]] = None,
|
234 |
+
debug: bool = False
|
235 |
+
) -> EvalReturn:
|
236 |
+
"""
|
237 |
+
Overview:
|
238 |
+
Deploy the agent with DDPG algorithm by interacting with the environment, during which the replay video \
|
239 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
240 |
+
Arguments:
|
241 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
242 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
243 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
244 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
245 |
+
the replay video of each episode will be saved separately.
|
246 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
247 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
248 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
249 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
250 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
251 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
252 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
253 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
254 |
+
subprocess environment manager will be used.
|
255 |
+
Returns:
|
256 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
257 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
258 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
259 |
+
"""
|
260 |
+
|
261 |
+
if debug:
|
262 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
263 |
+
# define env and policy
|
264 |
+
env = self.env.clone(caller='evaluator')
|
265 |
+
|
266 |
+
if seed is not None and isinstance(seed, int):
|
267 |
+
seeds = [seed]
|
268 |
+
elif seed is not None and isinstance(seed, list):
|
269 |
+
seeds = seed
|
270 |
+
else:
|
271 |
+
seeds = [self.seed]
|
272 |
+
|
273 |
+
returns = []
|
274 |
+
images = []
|
275 |
+
if enable_save_replay:
|
276 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
277 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
278 |
+
else:
|
279 |
+
logging.warning('No video would be generated during the deploy.')
|
280 |
+
if concatenate_all_replay:
|
281 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
282 |
+
concatenate_all_replay = False
|
283 |
+
|
284 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
285 |
+
|
286 |
+
def _forward(obs):
|
287 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
288 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
289 |
+
if cuda and torch.cuda.is_available():
|
290 |
+
obs = obs.cuda()
|
291 |
+
action = forward_fn(obs, mode='compute_actor')["action"]
|
292 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
293 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
294 |
+
return action
|
295 |
+
|
296 |
+
return _forward
|
297 |
+
|
298 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
299 |
+
|
300 |
+
# reset first to make sure the env is in the initial state
|
301 |
+
# env will be reset again in the main loop
|
302 |
+
env.reset()
|
303 |
+
|
304 |
+
for seed in seeds:
|
305 |
+
env.seed(seed, dynamic_seed=False)
|
306 |
+
return_ = 0.
|
307 |
+
step = 0
|
308 |
+
obs = env.reset()
|
309 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
310 |
+
while True:
|
311 |
+
action = forward_fn(obs)
|
312 |
+
obs, rew, done, info = env.step(action)
|
313 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
314 |
+
return_ += rew
|
315 |
+
step += 1
|
316 |
+
if done:
|
317 |
+
break
|
318 |
+
logging.info(f'DDPG deploy is finished, final episode return with {step} steps is: {return_}')
|
319 |
+
returns.append(return_)
|
320 |
+
|
321 |
+
env.close()
|
322 |
+
|
323 |
+
if concatenate_all_replay:
|
324 |
+
images = np.concatenate(images, axis=0)
|
325 |
+
import imageio
|
326 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
327 |
+
|
328 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
329 |
+
|
330 |
+
def collect_data(
|
331 |
+
self,
|
332 |
+
env_num: int = 8,
|
333 |
+
save_data_path: Optional[str] = None,
|
334 |
+
n_sample: Optional[int] = None,
|
335 |
+
n_episode: Optional[int] = None,
|
336 |
+
context: Optional[str] = None,
|
337 |
+
debug: bool = False
|
338 |
+
) -> None:
|
339 |
+
"""
|
340 |
+
Overview:
|
341 |
+
Collect data with DDPG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
342 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
343 |
+
``exp_name/demo_data``.
|
344 |
+
Arguments:
|
345 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
346 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
347 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
348 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
349 |
+
If not specified, ``n_episode`` must be specified.
|
350 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
351 |
+
If not specified, ``n_sample`` must be specified.
|
352 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
353 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
354 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
355 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
356 |
+
subprocess environment manager will be used.
|
357 |
+
"""
|
358 |
+
|
359 |
+
if debug:
|
360 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
361 |
+
if n_episode is not None:
|
362 |
+
raise NotImplementedError
|
363 |
+
# define env and policy
|
364 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
365 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
366 |
+
|
367 |
+
if save_data_path is None:
|
368 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
369 |
+
|
370 |
+
# main execution task
|
371 |
+
with task.start(ctx=OnlineRLContext()):
|
372 |
+
task.use(
|
373 |
+
StepCollector(
|
374 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
375 |
+
)
|
376 |
+
)
|
377 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
378 |
+
task.run(max_step=1)
|
379 |
+
logging.info(
|
380 |
+
f'DDPG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
381 |
+
)
|
382 |
+
|
383 |
+
def batch_evaluate(
|
384 |
+
self,
|
385 |
+
env_num: int = 4,
|
386 |
+
n_evaluator_episode: int = 4,
|
387 |
+
context: Optional[str] = None,
|
388 |
+
debug: bool = False
|
389 |
+
) -> EvalReturn:
|
390 |
+
"""
|
391 |
+
Overview:
|
392 |
+
Evaluate the agent with DDPG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
393 |
+
environments. The evaluation result will be returned.
|
394 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
395 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
396 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
397 |
+
Arguments:
|
398 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
399 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
400 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
401 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
402 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
403 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
404 |
+
subprocess environment manager will be used.
|
405 |
+
Returns:
|
406 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
407 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
408 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
409 |
+
"""
|
410 |
+
|
411 |
+
if debug:
|
412 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
413 |
+
# define env and policy
|
414 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
415 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
416 |
+
|
417 |
+
# reset first to make sure the env is in the initial state
|
418 |
+
# env will be reset again in the main loop
|
419 |
+
env.launch()
|
420 |
+
env.reset()
|
421 |
+
|
422 |
+
evaluate_cfg = self.cfg
|
423 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
424 |
+
|
425 |
+
# main execution task
|
426 |
+
with task.start(ctx=OnlineRLContext()):
|
427 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
428 |
+
task.run(max_step=1)
|
429 |
+
|
430 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
431 |
+
|
432 |
+
@property
|
433 |
+
def best(self) -> 'DDPGAgent':
|
434 |
+
"""
|
435 |
+
Overview:
|
436 |
+
Load the best model from the checkpoint directory, \
|
437 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
438 |
+
The return value is the agent with the best model.
|
439 |
+
Returns:
|
440 |
+
- (:obj:`DDPGAgent`): The agent with the best model.
|
441 |
+
Examples:
|
442 |
+
>>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
|
443 |
+
>>> agent.train()
|
444 |
+
>>> agent = agent.best
|
445 |
+
|
446 |
+
.. note::
|
447 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
448 |
+
model will be replaced by the best model.
|
449 |
+
"""
|
450 |
+
|
451 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
452 |
+
# Load best model if it exists
|
453 |
+
if os.path.exists(best_model_file_path):
|
454 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
455 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
456 |
+
return self
|
DI-engine/ding/bonus/dqn.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import DQNPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import DQN
|
19 |
+
from ding.model import model_wrap
|
20 |
+
from ding.data import DequeBuffer
|
21 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
22 |
+
from ding.config.example.DQN import supported_env_cfg
|
23 |
+
from ding.config.example.DQN import supported_env
|
24 |
+
|
25 |
+
|
26 |
+
class DQNAgent:
|
27 |
+
"""
|
28 |
+
Overview:
|
29 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm Deep Q-Learning(DQN).
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.dqn import DQNAgent
|
41 |
+
>>> print(DQNAgent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for DQN algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of DQN algorithm, which should be an instance of class \
|
71 |
+
:class:`ding.model.DQN`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of DQN algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/DQN/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with DQN algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = DQNAgent(env_id='LunarLander-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = DQNAgent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLander-v2')
|
91 |
+
>>> agent = DQNAgent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = DQN(**cfg.policy.model)
|
94 |
+
>>> agent = DQNAgent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = DQNAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
DQNAgent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
DQNAgent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": DQNPolicy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=DQNPolicy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = DQN(**self.cfg.policy.model)
|
140 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
141 |
+
self.policy = DQNPolicy(self.cfg.policy, model=model)
|
142 |
+
if policy_state_dict is not None:
|
143 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
144 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
145 |
+
|
146 |
+
def train(
|
147 |
+
self,
|
148 |
+
step: int = int(1e7),
|
149 |
+
collector_env_num: int = None,
|
150 |
+
evaluator_env_num: int = None,
|
151 |
+
n_iter_save_ckpt: int = 1000,
|
152 |
+
context: Optional[str] = None,
|
153 |
+
debug: bool = False,
|
154 |
+
wandb_sweep: bool = False,
|
155 |
+
) -> TrainingReturn:
|
156 |
+
"""
|
157 |
+
Overview:
|
158 |
+
Train the agent with DQN algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
159 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
160 |
+
recorded and saved by wandb.
|
161 |
+
Arguments:
|
162 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
163 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
164 |
+
If not specified, it will be set according to the configuration.
|
165 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
166 |
+
If not specified, it will be set according to the configuration.
|
167 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
168 |
+
Default to 1000.
|
169 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
170 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
171 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
172 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
173 |
+
subprocess environment manager will be used.
|
174 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
175 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
176 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
177 |
+
Returns:
|
178 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
179 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
180 |
+
"""
|
181 |
+
|
182 |
+
if debug:
|
183 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
184 |
+
logging.debug(self.policy._model)
|
185 |
+
# define env and policy
|
186 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
187 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
188 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
189 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
190 |
+
|
191 |
+
with task.start(ctx=OnlineRLContext()):
|
192 |
+
task.use(
|
193 |
+
interaction_evaluator(
|
194 |
+
self.cfg,
|
195 |
+
self.policy.eval_mode,
|
196 |
+
evaluator_env,
|
197 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
198 |
+
)
|
199 |
+
)
|
200 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
201 |
+
task.use(eps_greedy_handler(self.cfg))
|
202 |
+
task.use(
|
203 |
+
StepCollector(
|
204 |
+
self.cfg,
|
205 |
+
self.policy.collect_mode,
|
206 |
+
collector_env,
|
207 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
208 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
209 |
+
)
|
210 |
+
)
|
211 |
+
if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
|
212 |
+
task.use(nstep_reward_enhancer(self.cfg))
|
213 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
214 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
215 |
+
task.use(
|
216 |
+
wandb_online_logger(
|
217 |
+
metric_list=self.policy._monitor_vars_learn(),
|
218 |
+
model=self.policy._model,
|
219 |
+
anonymous=True,
|
220 |
+
project_name=self.exp_name,
|
221 |
+
wandb_sweep=wandb_sweep,
|
222 |
+
)
|
223 |
+
)
|
224 |
+
task.use(termination_checker(max_env_step=step))
|
225 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
226 |
+
task.run()
|
227 |
+
|
228 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
229 |
+
|
230 |
+
def deploy(
|
231 |
+
self,
|
232 |
+
enable_save_replay: bool = False,
|
233 |
+
concatenate_all_replay: bool = False,
|
234 |
+
replay_save_path: str = None,
|
235 |
+
seed: Optional[Union[int, List]] = None,
|
236 |
+
debug: bool = False
|
237 |
+
) -> EvalReturn:
|
238 |
+
"""
|
239 |
+
Overview:
|
240 |
+
Deploy the agent with DQN algorithm by interacting with the environment, during which the replay video \
|
241 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
242 |
+
Arguments:
|
243 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
244 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
245 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
246 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
247 |
+
the replay video of each episode will be saved separately.
|
248 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
249 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
250 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
251 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
252 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
253 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
254 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
255 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
256 |
+
subprocess environment manager will be used.
|
257 |
+
Returns:
|
258 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
259 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
260 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
261 |
+
"""
|
262 |
+
|
263 |
+
if debug:
|
264 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
265 |
+
# define env and policy
|
266 |
+
env = self.env.clone(caller='evaluator')
|
267 |
+
|
268 |
+
if seed is not None and isinstance(seed, int):
|
269 |
+
seeds = [seed]
|
270 |
+
elif seed is not None and isinstance(seed, list):
|
271 |
+
seeds = seed
|
272 |
+
else:
|
273 |
+
seeds = [self.seed]
|
274 |
+
|
275 |
+
returns = []
|
276 |
+
images = []
|
277 |
+
if enable_save_replay:
|
278 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
279 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
280 |
+
else:
|
281 |
+
logging.warning('No video would be generated during the deploy.')
|
282 |
+
if concatenate_all_replay:
|
283 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
284 |
+
concatenate_all_replay = False
|
285 |
+
|
286 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
287 |
+
|
288 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
|
289 |
+
|
290 |
+
def _forward(obs):
|
291 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
292 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
293 |
+
if cuda and torch.cuda.is_available():
|
294 |
+
obs = obs.cuda()
|
295 |
+
action = forward_fn(obs)["action"]
|
296 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
297 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
298 |
+
return action
|
299 |
+
|
300 |
+
return _forward
|
301 |
+
|
302 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
303 |
+
|
304 |
+
# reset first to make sure the env is in the initial state
|
305 |
+
# env will be reset again in the main loop
|
306 |
+
env.reset()
|
307 |
+
|
308 |
+
for seed in seeds:
|
309 |
+
env.seed(seed, dynamic_seed=False)
|
310 |
+
return_ = 0.
|
311 |
+
step = 0
|
312 |
+
obs = env.reset()
|
313 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
314 |
+
while True:
|
315 |
+
action = forward_fn(obs)
|
316 |
+
obs, rew, done, info = env.step(action)
|
317 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
318 |
+
return_ += rew
|
319 |
+
step += 1
|
320 |
+
if done:
|
321 |
+
break
|
322 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
323 |
+
returns.append(return_)
|
324 |
+
|
325 |
+
env.close()
|
326 |
+
|
327 |
+
if concatenate_all_replay:
|
328 |
+
images = np.concatenate(images, axis=0)
|
329 |
+
import imageio
|
330 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
331 |
+
|
332 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
333 |
+
|
334 |
+
def collect_data(
|
335 |
+
self,
|
336 |
+
env_num: int = 8,
|
337 |
+
save_data_path: Optional[str] = None,
|
338 |
+
n_sample: Optional[int] = None,
|
339 |
+
n_episode: Optional[int] = None,
|
340 |
+
context: Optional[str] = None,
|
341 |
+
debug: bool = False
|
342 |
+
) -> None:
|
343 |
+
"""
|
344 |
+
Overview:
|
345 |
+
Collect data with DQN algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
346 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
347 |
+
``exp_name/demo_data``.
|
348 |
+
Arguments:
|
349 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
350 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
351 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
352 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
353 |
+
If not specified, ``n_episode`` must be specified.
|
354 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
355 |
+
If not specified, ``n_sample`` must be specified.
|
356 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
357 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
358 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
359 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
360 |
+
subprocess environment manager will be used.
|
361 |
+
"""
|
362 |
+
|
363 |
+
if debug:
|
364 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
365 |
+
if n_episode is not None:
|
366 |
+
raise NotImplementedError
|
367 |
+
# define env and policy
|
368 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
369 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
370 |
+
|
371 |
+
if save_data_path is None:
|
372 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
373 |
+
|
374 |
+
# main execution task
|
375 |
+
with task.start(ctx=OnlineRLContext()):
|
376 |
+
task.use(
|
377 |
+
StepCollector(
|
378 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
379 |
+
)
|
380 |
+
)
|
381 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
382 |
+
task.run(max_step=1)
|
383 |
+
logging.info(
|
384 |
+
f'DQN collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
385 |
+
)
|
386 |
+
|
387 |
+
def batch_evaluate(
|
388 |
+
self,
|
389 |
+
env_num: int = 4,
|
390 |
+
n_evaluator_episode: int = 4,
|
391 |
+
context: Optional[str] = None,
|
392 |
+
debug: bool = False
|
393 |
+
) -> EvalReturn:
|
394 |
+
"""
|
395 |
+
Overview:
|
396 |
+
Evaluate the agent with DQN algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
397 |
+
environments. The evaluation result will be returned.
|
398 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
399 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
400 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
401 |
+
Arguments:
|
402 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
403 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
404 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
405 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
406 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
407 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
408 |
+
subprocess environment manager will be used.
|
409 |
+
Returns:
|
410 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
411 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
412 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
413 |
+
"""
|
414 |
+
|
415 |
+
if debug:
|
416 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
417 |
+
# define env and policy
|
418 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
419 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
420 |
+
|
421 |
+
# reset first to make sure the env is in the initial state
|
422 |
+
# env will be reset again in the main loop
|
423 |
+
env.launch()
|
424 |
+
env.reset()
|
425 |
+
|
426 |
+
evaluate_cfg = self.cfg
|
427 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
428 |
+
|
429 |
+
# main execution task
|
430 |
+
with task.start(ctx=OnlineRLContext()):
|
431 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
432 |
+
task.run(max_step=1)
|
433 |
+
|
434 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
435 |
+
|
436 |
+
@property
|
437 |
+
def best(self) -> 'DQNAgent':
|
438 |
+
"""
|
439 |
+
Overview:
|
440 |
+
Load the best model from the checkpoint directory, \
|
441 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
442 |
+
The return value is the agent with the best model.
|
443 |
+
Returns:
|
444 |
+
- (:obj:`DQNAgent`): The agent with the best model.
|
445 |
+
Examples:
|
446 |
+
>>> agent = DQNAgent(env_id='LunarLander-v2')
|
447 |
+
>>> agent.train()
|
448 |
+
>>> agent = agent.best
|
449 |
+
|
450 |
+
.. note::
|
451 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
452 |
+
model will be replaced by the best model.
|
453 |
+
"""
|
454 |
+
|
455 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
456 |
+
# Load best model if it exists
|
457 |
+
if os.path.exists(best_model_file_path):
|
458 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
459 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
460 |
+
return self
|
DI-engine/ding/bonus/model.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Optional
|
2 |
+
from easydict import EasyDict
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import treetensor.torch as ttorch
|
6 |
+
from copy import deepcopy
|
7 |
+
from ding.utils import SequenceType, squeeze
|
8 |
+
from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
|
9 |
+
FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
|
10 |
+
from ding.torch_utils import MLP, fc_block
|
11 |
+
|
12 |
+
|
13 |
+
class DiscretePolicyHead(nn.Module):
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
hidden_size: int,
|
18 |
+
output_size: int,
|
19 |
+
layer_num: int = 1,
|
20 |
+
activation: Optional[nn.Module] = nn.ReLU(),
|
21 |
+
norm_type: Optional[str] = None,
|
22 |
+
) -> None:
|
23 |
+
super(DiscretePolicyHead, self).__init__()
|
24 |
+
self.main = nn.Sequential(
|
25 |
+
MLP(
|
26 |
+
hidden_size,
|
27 |
+
hidden_size,
|
28 |
+
hidden_size,
|
29 |
+
layer_num,
|
30 |
+
layer_fn=nn.Linear,
|
31 |
+
activation=activation,
|
32 |
+
norm_type=norm_type
|
33 |
+
), fc_block(hidden_size, output_size)
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
37 |
+
return self.main(x)
|
38 |
+
|
39 |
+
|
40 |
+
class PPOFModel(nn.Module):
|
41 |
+
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
obs_shape: Union[int, SequenceType],
|
46 |
+
action_shape: Union[int, SequenceType, EasyDict],
|
47 |
+
action_space: str = 'discrete',
|
48 |
+
share_encoder: bool = True,
|
49 |
+
encoder_hidden_size_list: SequenceType = [128, 128, 64],
|
50 |
+
actor_head_hidden_size: int = 64,
|
51 |
+
actor_head_layer_num: int = 1,
|
52 |
+
critic_head_hidden_size: int = 64,
|
53 |
+
critic_head_layer_num: int = 1,
|
54 |
+
activation: Optional[nn.Module] = nn.ReLU(),
|
55 |
+
norm_type: Optional[str] = None,
|
56 |
+
sigma_type: Optional[str] = 'independent',
|
57 |
+
fixed_sigma_value: Optional[int] = 0.3,
|
58 |
+
bound_type: Optional[str] = None,
|
59 |
+
encoder: Optional[torch.nn.Module] = None,
|
60 |
+
popart_head=False,
|
61 |
+
) -> None:
|
62 |
+
super(PPOFModel, self).__init__()
|
63 |
+
obs_shape = squeeze(obs_shape)
|
64 |
+
action_shape = squeeze(action_shape)
|
65 |
+
self.obs_shape, self.action_shape = obs_shape, action_shape
|
66 |
+
self.share_encoder = share_encoder
|
67 |
+
|
68 |
+
# Encoder Type
|
69 |
+
def new_encoder(outsize):
|
70 |
+
if isinstance(obs_shape, int) or len(obs_shape) == 1:
|
71 |
+
return FCEncoder(
|
72 |
+
obs_shape=obs_shape,
|
73 |
+
hidden_size_list=encoder_hidden_size_list,
|
74 |
+
activation=activation,
|
75 |
+
norm_type=norm_type
|
76 |
+
)
|
77 |
+
elif len(obs_shape) == 3:
|
78 |
+
return ConvEncoder(
|
79 |
+
obs_shape=obs_shape,
|
80 |
+
hidden_size_list=encoder_hidden_size_list,
|
81 |
+
activation=activation,
|
82 |
+
norm_type=norm_type
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
raise RuntimeError(
|
86 |
+
"not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
|
87 |
+
format(obs_shape)
|
88 |
+
)
|
89 |
+
|
90 |
+
if self.share_encoder:
|
91 |
+
assert actor_head_hidden_size == critic_head_hidden_size, \
|
92 |
+
"actor and critic network head should have same size."
|
93 |
+
if encoder:
|
94 |
+
if isinstance(encoder, torch.nn.Module):
|
95 |
+
self.encoder = encoder
|
96 |
+
else:
|
97 |
+
raise ValueError("illegal encoder instance.")
|
98 |
+
else:
|
99 |
+
self.encoder = new_encoder(actor_head_hidden_size)
|
100 |
+
else:
|
101 |
+
if encoder:
|
102 |
+
if isinstance(encoder, torch.nn.Module):
|
103 |
+
self.actor_encoder = encoder
|
104 |
+
self.critic_encoder = deepcopy(encoder)
|
105 |
+
else:
|
106 |
+
raise ValueError("illegal encoder instance.")
|
107 |
+
else:
|
108 |
+
self.actor_encoder = new_encoder(actor_head_hidden_size)
|
109 |
+
self.critic_encoder = new_encoder(critic_head_hidden_size)
|
110 |
+
|
111 |
+
# Head Type
|
112 |
+
if not popart_head:
|
113 |
+
self.critic_head = RegressionHead(
|
114 |
+
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
self.critic_head = PopArtVHead(
|
118 |
+
critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
|
119 |
+
)
|
120 |
+
|
121 |
+
self.action_space = action_space
|
122 |
+
assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
|
123 |
+
if self.action_space == 'continuous':
|
124 |
+
self.multi_head = False
|
125 |
+
self.actor_head = ReparameterizationHead(
|
126 |
+
actor_head_hidden_size,
|
127 |
+
action_shape,
|
128 |
+
actor_head_layer_num,
|
129 |
+
sigma_type=sigma_type,
|
130 |
+
activation=activation,
|
131 |
+
norm_type=norm_type,
|
132 |
+
bound_type=bound_type
|
133 |
+
)
|
134 |
+
elif self.action_space == 'discrete':
|
135 |
+
actor_head_cls = DiscretePolicyHead
|
136 |
+
multi_head = not isinstance(action_shape, int)
|
137 |
+
self.multi_head = multi_head
|
138 |
+
if multi_head:
|
139 |
+
self.actor_head = MultiHead(
|
140 |
+
actor_head_cls,
|
141 |
+
actor_head_hidden_size,
|
142 |
+
action_shape,
|
143 |
+
layer_num=actor_head_layer_num,
|
144 |
+
activation=activation,
|
145 |
+
norm_type=norm_type
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
self.actor_head = actor_head_cls(
|
149 |
+
actor_head_hidden_size,
|
150 |
+
action_shape,
|
151 |
+
actor_head_layer_num,
|
152 |
+
activation=activation,
|
153 |
+
norm_type=norm_type
|
154 |
+
)
|
155 |
+
elif self.action_space == 'hybrid': # HPPO
|
156 |
+
# hybrid action space: action_type(discrete) + action_args(continuous),
|
157 |
+
# such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
|
158 |
+
action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
|
159 |
+
action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
|
160 |
+
actor_action_args = ReparameterizationHead(
|
161 |
+
actor_head_hidden_size,
|
162 |
+
action_shape.action_args_shape,
|
163 |
+
actor_head_layer_num,
|
164 |
+
sigma_type=sigma_type,
|
165 |
+
fixed_sigma_value=fixed_sigma_value,
|
166 |
+
activation=activation,
|
167 |
+
norm_type=norm_type,
|
168 |
+
bound_type=bound_type,
|
169 |
+
)
|
170 |
+
actor_action_type = DiscretePolicyHead(
|
171 |
+
actor_head_hidden_size,
|
172 |
+
action_shape.action_type_shape,
|
173 |
+
actor_head_layer_num,
|
174 |
+
activation=activation,
|
175 |
+
norm_type=norm_type,
|
176 |
+
)
|
177 |
+
self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
|
178 |
+
|
179 |
+
# must use list, not nn.ModuleList
|
180 |
+
if self.share_encoder:
|
181 |
+
self.actor = [self.encoder, self.actor_head]
|
182 |
+
self.critic = [self.encoder, self.critic_head]
|
183 |
+
else:
|
184 |
+
self.actor = [self.actor_encoder, self.actor_head]
|
185 |
+
self.critic = [self.critic_encoder, self.critic_head]
|
186 |
+
# Convenient for calling some apis (e.g. self.critic.parameters()),
|
187 |
+
# but may cause misunderstanding when `print(self)`
|
188 |
+
self.actor = nn.ModuleList(self.actor)
|
189 |
+
self.critic = nn.ModuleList(self.critic)
|
190 |
+
|
191 |
+
def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
|
192 |
+
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
|
193 |
+
return getattr(self, mode)(inputs)
|
194 |
+
|
195 |
+
def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
|
196 |
+
if self.share_encoder:
|
197 |
+
x = self.encoder(x)
|
198 |
+
else:
|
199 |
+
x = self.actor_encoder(x)
|
200 |
+
|
201 |
+
if self.action_space == 'discrete':
|
202 |
+
return self.actor_head(x)
|
203 |
+
elif self.action_space == 'continuous':
|
204 |
+
x = self.actor_head(x) # mu, sigma
|
205 |
+
return ttorch.as_tensor(x)
|
206 |
+
elif self.action_space == 'hybrid':
|
207 |
+
action_type = self.actor_head[0](x)
|
208 |
+
action_args = self.actor_head[1](x)
|
209 |
+
return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})
|
210 |
+
|
211 |
+
def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
|
212 |
+
if self.share_encoder:
|
213 |
+
x = self.encoder(x)
|
214 |
+
else:
|
215 |
+
x = self.critic_encoder(x)
|
216 |
+
x = self.critic_head(x)
|
217 |
+
return x
|
218 |
+
|
219 |
+
def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
|
220 |
+
if self.share_encoder:
|
221 |
+
actor_embedding = critic_embedding = self.encoder(x)
|
222 |
+
else:
|
223 |
+
actor_embedding = self.actor_encoder(x)
|
224 |
+
critic_embedding = self.critic_encoder(x)
|
225 |
+
|
226 |
+
value = self.critic_head(critic_embedding)
|
227 |
+
|
228 |
+
if self.action_space == 'discrete':
|
229 |
+
logit = self.actor_head(actor_embedding)
|
230 |
+
return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
|
231 |
+
elif self.action_space == 'continuous':
|
232 |
+
x = self.actor_head(actor_embedding)
|
233 |
+
return ttorch.as_tensor({'logit': x, 'value': value['pred']})
|
234 |
+
elif self.action_space == 'hybrid':
|
235 |
+
action_type = self.actor_head[0](actor_embedding)
|
236 |
+
action_args = self.actor_head[1](actor_embedding)
|
237 |
+
return ttorch.as_tensor(
|
238 |
+
{
|
239 |
+
'logit': {
|
240 |
+
'action_type': action_type,
|
241 |
+
'action_args': action_args
|
242 |
+
},
|
243 |
+
'value': value['pred']
|
244 |
+
}
|
245 |
+
)
|
DI-engine/ding/bonus/pg.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, trainer, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
|
11 |
+
montecarlo_return_estimator, final_ctx_saver, EpisodeCollector
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import PGPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import PG
|
19 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
20 |
+
from ding.config.example.PG import supported_env_cfg
|
21 |
+
from ding.config.example.PG import supported_env
|
22 |
+
|
23 |
+
|
24 |
+
class PGAgent:
|
25 |
+
"""
|
26 |
+
Overview:
|
27 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm Policy Gradient(PG).
|
28 |
+
For more information about the system design of RL agent, please refer to \
|
29 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
30 |
+
Interface:
|
31 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
32 |
+
"""
|
33 |
+
supported_env_list = list(supported_env_cfg.keys())
|
34 |
+
"""
|
35 |
+
Overview:
|
36 |
+
List of supported envs.
|
37 |
+
Examples:
|
38 |
+
>>> from ding.bonus.pg import PGAgent
|
39 |
+
>>> print(PGAgent.supported_env_list)
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
env_id: str = None,
|
45 |
+
env: BaseEnv = None,
|
46 |
+
seed: int = 0,
|
47 |
+
exp_name: str = None,
|
48 |
+
model: Optional[torch.nn.Module] = None,
|
49 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
50 |
+
policy_state_dict: str = None,
|
51 |
+
) -> None:
|
52 |
+
"""
|
53 |
+
Overview:
|
54 |
+
Initialize agent for PG algorithm.
|
55 |
+
Arguments:
|
56 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
57 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
58 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
59 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
60 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
61 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
62 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
63 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
64 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
65 |
+
Default to 0.
|
66 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
67 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
68 |
+
- model (:obj:`torch.nn.Module`): The model of PG algorithm, which should be an instance of class \
|
69 |
+
:class:`ding.model.PG`. \
|
70 |
+
If not specified, a default model will be generated according to the configuration.
|
71 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of PG algorithm, which is a dict. \
|
72 |
+
Default to None. If not specified, the default configuration will be used. \
|
73 |
+
The default configuration can be found in ``ding/config/example/PG/gym_lunarlander_v2.py``.
|
74 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
75 |
+
If specified, the policy will be loaded from this file. Default to None.
|
76 |
+
|
77 |
+
.. note::
|
78 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
79 |
+
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
|
80 |
+
and we want to train an agent with PG algorithm with default configuration. \
|
81 |
+
Then we can initialize the agent in the following ways:
|
82 |
+
>>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
|
83 |
+
or, if we want can specify the env_id in the configuration:
|
84 |
+
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
|
85 |
+
>>> agent = PGAgent(cfg=cfg)
|
86 |
+
There are also other arguments to specify the agent when initializing.
|
87 |
+
For example, if we want to specify the environment instance:
|
88 |
+
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
|
89 |
+
>>> agent = PGAgent(cfg=cfg, env=env)
|
90 |
+
or, if we want to specify the model:
|
91 |
+
>>> model = PG(**cfg.policy.model)
|
92 |
+
>>> agent = PGAgent(cfg=cfg, model=model)
|
93 |
+
or, if we want to reload the policy from a saved policy state dict:
|
94 |
+
>>> agent = PGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
|
95 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
96 |
+
"""
|
97 |
+
|
98 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
99 |
+
|
100 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
101 |
+
cfg = EasyDict(cfg)
|
102 |
+
|
103 |
+
if env_id is not None:
|
104 |
+
assert env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
|
105 |
+
PGAgent.supported_env_list
|
106 |
+
)
|
107 |
+
if cfg is None:
|
108 |
+
cfg = supported_env_cfg[env_id]
|
109 |
+
else:
|
110 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
111 |
+
else:
|
112 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
113 |
+
assert cfg.env.env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
|
114 |
+
PGAgent.supported_env_list
|
115 |
+
)
|
116 |
+
default_policy_config = EasyDict({"policy": PGPolicy.default_config()})
|
117 |
+
default_policy_config.update(cfg)
|
118 |
+
cfg = default_policy_config
|
119 |
+
|
120 |
+
if exp_name is not None:
|
121 |
+
cfg.exp_name = exp_name
|
122 |
+
self.cfg = compile_config(cfg, policy=PGPolicy)
|
123 |
+
self.exp_name = self.cfg.exp_name
|
124 |
+
if env is None:
|
125 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
126 |
+
else:
|
127 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
128 |
+
self.env = env
|
129 |
+
|
130 |
+
logging.getLogger().setLevel(logging.INFO)
|
131 |
+
self.seed = seed
|
132 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
133 |
+
if not os.path.exists(self.exp_name):
|
134 |
+
os.makedirs(self.exp_name)
|
135 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
136 |
+
if model is None:
|
137 |
+
model = PG(**self.cfg.policy.model)
|
138 |
+
self.policy = PGPolicy(self.cfg.policy, model=model)
|
139 |
+
if policy_state_dict is not None:
|
140 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
141 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
142 |
+
|
143 |
+
def train(
|
144 |
+
self,
|
145 |
+
step: int = int(1e7),
|
146 |
+
collector_env_num: int = None,
|
147 |
+
evaluator_env_num: int = None,
|
148 |
+
n_iter_save_ckpt: int = 1000,
|
149 |
+
context: Optional[str] = None,
|
150 |
+
debug: bool = False,
|
151 |
+
wandb_sweep: bool = False,
|
152 |
+
) -> TrainingReturn:
|
153 |
+
"""
|
154 |
+
Overview:
|
155 |
+
Train the agent with PG algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
156 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
157 |
+
recorded and saved by wandb.
|
158 |
+
Arguments:
|
159 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
160 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
161 |
+
If not specified, it will be set according to the configuration.
|
162 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
163 |
+
If not specified, it will be set according to the configuration.
|
164 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
165 |
+
Default to 1000.
|
166 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
167 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
168 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
169 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
170 |
+
subprocess environment manager will be used.
|
171 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
172 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
173 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
174 |
+
Returns:
|
175 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
176 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
177 |
+
"""
|
178 |
+
|
179 |
+
if debug:
|
180 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
181 |
+
logging.debug(self.policy._model)
|
182 |
+
# define env and policy
|
183 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
184 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
185 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
186 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
187 |
+
|
188 |
+
with task.start(ctx=OnlineRLContext()):
|
189 |
+
task.use(
|
190 |
+
interaction_evaluator(
|
191 |
+
self.cfg,
|
192 |
+
self.policy.eval_mode,
|
193 |
+
evaluator_env,
|
194 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
195 |
+
)
|
196 |
+
)
|
197 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
198 |
+
task.use(EpisodeCollector(self.cfg, self.policy.collect_mode, collector_env))
|
199 |
+
task.use(montecarlo_return_estimator(self.policy))
|
200 |
+
task.use(trainer(self.cfg, self.policy.learn_mode))
|
201 |
+
task.use(
|
202 |
+
wandb_online_logger(
|
203 |
+
metric_list=self.policy._monitor_vars_learn(),
|
204 |
+
model=self.policy._model,
|
205 |
+
anonymous=True,
|
206 |
+
project_name=self.exp_name,
|
207 |
+
wandb_sweep=wandb_sweep,
|
208 |
+
)
|
209 |
+
)
|
210 |
+
task.use(termination_checker(max_env_step=step))
|
211 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
212 |
+
task.run()
|
213 |
+
|
214 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
215 |
+
|
216 |
+
def deploy(
|
217 |
+
self,
|
218 |
+
enable_save_replay: bool = False,
|
219 |
+
concatenate_all_replay: bool = False,
|
220 |
+
replay_save_path: str = None,
|
221 |
+
seed: Optional[Union[int, List]] = None,
|
222 |
+
debug: bool = False
|
223 |
+
) -> EvalReturn:
|
224 |
+
"""
|
225 |
+
Overview:
|
226 |
+
Deploy the agent with PG algorithm by interacting with the environment, during which the replay video \
|
227 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
228 |
+
Arguments:
|
229 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
230 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
231 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
232 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
233 |
+
the replay video of each episode will be saved separately.
|
234 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
235 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
236 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
237 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
238 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
239 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
240 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
241 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
242 |
+
subprocess environment manager will be used.
|
243 |
+
Returns:
|
244 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
245 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
246 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
247 |
+
"""
|
248 |
+
|
249 |
+
if debug:
|
250 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
251 |
+
# define env and policy
|
252 |
+
env = self.env.clone(caller='evaluator')
|
253 |
+
|
254 |
+
if seed is not None and isinstance(seed, int):
|
255 |
+
seeds = [seed]
|
256 |
+
elif seed is not None and isinstance(seed, list):
|
257 |
+
seeds = seed
|
258 |
+
else:
|
259 |
+
seeds = [self.seed]
|
260 |
+
|
261 |
+
returns = []
|
262 |
+
images = []
|
263 |
+
if enable_save_replay:
|
264 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
265 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
266 |
+
else:
|
267 |
+
logging.warning('No video would be generated during the deploy.')
|
268 |
+
if concatenate_all_replay:
|
269 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
270 |
+
concatenate_all_replay = False
|
271 |
+
|
272 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
273 |
+
|
274 |
+
def _forward(obs):
|
275 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
276 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
277 |
+
if cuda and torch.cuda.is_available():
|
278 |
+
obs = obs.cuda()
|
279 |
+
output = forward_fn(obs)
|
280 |
+
if self.policy._cfg.deterministic_eval:
|
281 |
+
if self.policy._cfg.action_space == 'discrete':
|
282 |
+
output['action'] = output['logit'].argmax(dim=-1)
|
283 |
+
elif self.policy._cfg.action_space == 'continuous':
|
284 |
+
output['action'] = output['logit']['mu']
|
285 |
+
else:
|
286 |
+
raise KeyError("invalid action_space: {}".format(self.policy._cfg.action_space))
|
287 |
+
else:
|
288 |
+
output['action'] = output['dist'].sample()
|
289 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
290 |
+
action = output['action'].squeeze(0).detach().cpu().numpy()
|
291 |
+
return action
|
292 |
+
|
293 |
+
return _forward
|
294 |
+
|
295 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
296 |
+
|
297 |
+
# reset first to make sure the env is in the initial state
|
298 |
+
# env will be reset again in the main loop
|
299 |
+
env.reset()
|
300 |
+
|
301 |
+
for seed in seeds:
|
302 |
+
env.seed(seed, dynamic_seed=False)
|
303 |
+
return_ = 0.
|
304 |
+
step = 0
|
305 |
+
obs = env.reset()
|
306 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
307 |
+
while True:
|
308 |
+
action = forward_fn(obs)
|
309 |
+
obs, rew, done, info = env.step(action)
|
310 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
311 |
+
return_ += rew
|
312 |
+
step += 1
|
313 |
+
if done:
|
314 |
+
break
|
315 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
316 |
+
returns.append(return_)
|
317 |
+
|
318 |
+
env.close()
|
319 |
+
|
320 |
+
if concatenate_all_replay:
|
321 |
+
images = np.concatenate(images, axis=0)
|
322 |
+
import imageio
|
323 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
324 |
+
|
325 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
326 |
+
|
327 |
+
def collect_data(
|
328 |
+
self,
|
329 |
+
env_num: int = 8,
|
330 |
+
save_data_path: Optional[str] = None,
|
331 |
+
n_sample: Optional[int] = None,
|
332 |
+
n_episode: Optional[int] = None,
|
333 |
+
context: Optional[str] = None,
|
334 |
+
debug: bool = False
|
335 |
+
) -> None:
|
336 |
+
"""
|
337 |
+
Overview:
|
338 |
+
Collect data with PG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
339 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
340 |
+
``exp_name/demo_data``.
|
341 |
+
Arguments:
|
342 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
343 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
344 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
345 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
346 |
+
If not specified, ``n_episode`` must be specified.
|
347 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
348 |
+
If not specified, ``n_sample`` must be specified.
|
349 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
350 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
351 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
352 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
353 |
+
subprocess environment manager will be used.
|
354 |
+
"""
|
355 |
+
|
356 |
+
if debug:
|
357 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
358 |
+
if n_episode is not None:
|
359 |
+
raise NotImplementedError
|
360 |
+
# define env and policy
|
361 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
362 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
363 |
+
|
364 |
+
if save_data_path is None:
|
365 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
366 |
+
|
367 |
+
# main execution task
|
368 |
+
with task.start(ctx=OnlineRLContext()):
|
369 |
+
task.use(
|
370 |
+
StepCollector(
|
371 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
372 |
+
)
|
373 |
+
)
|
374 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
375 |
+
task.run(max_step=1)
|
376 |
+
logging.info(
|
377 |
+
f'PG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
378 |
+
)
|
379 |
+
|
380 |
+
def batch_evaluate(
|
381 |
+
self,
|
382 |
+
env_num: int = 4,
|
383 |
+
n_evaluator_episode: int = 4,
|
384 |
+
context: Optional[str] = None,
|
385 |
+
debug: bool = False
|
386 |
+
) -> EvalReturn:
|
387 |
+
"""
|
388 |
+
Overview:
|
389 |
+
Evaluate the agent with PG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
390 |
+
environments. The evaluation result will be returned.
|
391 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
392 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
393 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
394 |
+
Arguments:
|
395 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
396 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
397 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
398 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
399 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
400 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
401 |
+
subprocess environment manager will be used.
|
402 |
+
Returns:
|
403 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
404 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
405 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
406 |
+
"""
|
407 |
+
|
408 |
+
if debug:
|
409 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
410 |
+
# define env and policy
|
411 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
412 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
413 |
+
|
414 |
+
# reset first to make sure the env is in the initial state
|
415 |
+
# env will be reset again in the main loop
|
416 |
+
env.launch()
|
417 |
+
env.reset()
|
418 |
+
|
419 |
+
evaluate_cfg = self.cfg
|
420 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
421 |
+
|
422 |
+
# main execution task
|
423 |
+
with task.start(ctx=OnlineRLContext()):
|
424 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
425 |
+
task.run(max_step=1)
|
426 |
+
|
427 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
428 |
+
|
429 |
+
@property
|
430 |
+
def best(self) -> 'PGAgent':
|
431 |
+
"""
|
432 |
+
Overview:
|
433 |
+
Load the best model from the checkpoint directory, \
|
434 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
435 |
+
The return value is the agent with the best model.
|
436 |
+
Returns:
|
437 |
+
- (:obj:`PGAgent`): The agent with the best model.
|
438 |
+
Examples:
|
439 |
+
>>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
|
440 |
+
>>> agent.train()
|
441 |
+
>>> agent = agent.best
|
442 |
+
|
443 |
+
.. note::
|
444 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
445 |
+
model will be replaced by the best model.
|
446 |
+
"""
|
447 |
+
|
448 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
449 |
+
# Load best model if it exists
|
450 |
+
if os.path.exists(best_model_file_path):
|
451 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
452 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
453 |
+
return self
|
DI-engine/ding/bonus/ppo_offpolicy.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, final_ctx_saver, OffPolicyLearner, StepCollector, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, gae_estimator
|
11 |
+
from ding.envs import BaseEnv
|
12 |
+
from ding.envs import setup_ding_env_manager
|
13 |
+
from ding.policy import PPOOffPolicy
|
14 |
+
from ding.utils import set_pkg_seed
|
15 |
+
from ding.utils import get_env_fps, render
|
16 |
+
from ding.config import save_config_py, compile_config
|
17 |
+
from ding.model import VAC
|
18 |
+
from ding.model import model_wrap
|
19 |
+
from ding.data import DequeBuffer
|
20 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
21 |
+
from ding.config.example.PPOOffPolicy import supported_env_cfg
|
22 |
+
from ding.config.example.PPOOffPolicy import supported_env
|
23 |
+
|
24 |
+
|
25 |
+
class PPOOffPolicyAgent:
|
26 |
+
"""
|
27 |
+
Overview:
|
28 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
29 |
+
Proximal Policy Optimization(PPO) in an off-policy style.
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.ppo_offpolicy import PPOOffPolicyAgent
|
41 |
+
>>> print(PPOOffPolicyAgent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for PPO (offpolicy) algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of PPO (offpolicy) algorithm, \
|
71 |
+
which should be an instance of class :class:`ding.model.VAC`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of PPO (offpolicy) algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/PPO (offpolicy)/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with PPO (offpolicy) algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = PPOOffPolicyAgent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLander-v2')
|
91 |
+
>>> agent = PPOOffPolicyAgent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = VAC(**cfg.policy.model)
|
94 |
+
>>> agent = PPOOffPolicyAgent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = PPOOffPolicyAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
PPOOffPolicyAgent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
PPOOffPolicyAgent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": PPOOffPolicy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=PPOOffPolicy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = VAC(**self.cfg.policy.model)
|
140 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
141 |
+
self.policy = PPOOffPolicy(self.cfg.policy, model=model)
|
142 |
+
if policy_state_dict is not None:
|
143 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
144 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
145 |
+
|
146 |
+
def train(
|
147 |
+
self,
|
148 |
+
step: int = int(1e7),
|
149 |
+
collector_env_num: int = None,
|
150 |
+
evaluator_env_num: int = None,
|
151 |
+
n_iter_save_ckpt: int = 1000,
|
152 |
+
context: Optional[str] = None,
|
153 |
+
debug: bool = False,
|
154 |
+
wandb_sweep: bool = False,
|
155 |
+
) -> TrainingReturn:
|
156 |
+
"""
|
157 |
+
Overview:
|
158 |
+
Train the agent with PPO (offpolicy) algorithm for ``step`` iterations with ``collector_env_num`` \
|
159 |
+
collector environments and ``evaluator_env_num`` evaluator environments. \
|
160 |
+
Information during training will be recorded and saved by wandb.
|
161 |
+
Arguments:
|
162 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
163 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
164 |
+
If not specified, it will be set according to the configuration.
|
165 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
166 |
+
If not specified, it will be set according to the configuration.
|
167 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
168 |
+
Default to 1000.
|
169 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
170 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
171 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
172 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
173 |
+
subprocess environment manager will be used.
|
174 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
175 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
176 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
177 |
+
Returns:
|
178 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
179 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
180 |
+
"""
|
181 |
+
|
182 |
+
if debug:
|
183 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
184 |
+
logging.debug(self.policy._model)
|
185 |
+
# define env and policy
|
186 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
187 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
188 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
189 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
190 |
+
|
191 |
+
with task.start(ctx=OnlineRLContext()):
|
192 |
+
task.use(
|
193 |
+
interaction_evaluator(
|
194 |
+
self.cfg,
|
195 |
+
self.policy.eval_mode,
|
196 |
+
evaluator_env,
|
197 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
198 |
+
)
|
199 |
+
)
|
200 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
201 |
+
task.use(
|
202 |
+
StepCollector(
|
203 |
+
self.cfg,
|
204 |
+
self.policy.collect_mode,
|
205 |
+
collector_env,
|
206 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
207 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
208 |
+
)
|
209 |
+
)
|
210 |
+
task.use(gae_estimator(self.cfg, self.policy.collect_mode, self.buffer_))
|
211 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
212 |
+
task.use(
|
213 |
+
wandb_online_logger(
|
214 |
+
cfg=self.cfg.wandb_logger,
|
215 |
+
exp_config=self.cfg,
|
216 |
+
metric_list=self.policy._monitor_vars_learn(),
|
217 |
+
model=self.policy._model,
|
218 |
+
anonymous=True,
|
219 |
+
project_name=self.exp_name,
|
220 |
+
wandb_sweep=wandb_sweep,
|
221 |
+
)
|
222 |
+
)
|
223 |
+
task.use(termination_checker(max_env_step=step))
|
224 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
225 |
+
task.run()
|
226 |
+
|
227 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
228 |
+
|
229 |
+
def deploy(
|
230 |
+
self,
|
231 |
+
enable_save_replay: bool = False,
|
232 |
+
concatenate_all_replay: bool = False,
|
233 |
+
replay_save_path: str = None,
|
234 |
+
seed: Optional[Union[int, List]] = None,
|
235 |
+
debug: bool = False
|
236 |
+
) -> EvalReturn:
|
237 |
+
"""
|
238 |
+
Overview:
|
239 |
+
Deploy the agent with PPO (offpolicy) algorithm by interacting with the environment, \
|
240 |
+
during which the replay video can be saved if ``enable_save_replay`` is True. \
|
241 |
+
The evaluation result will be returned.
|
242 |
+
Arguments:
|
243 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
244 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
245 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
246 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
247 |
+
the replay video of each episode will be saved separately.
|
248 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
249 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
250 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
251 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
252 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
253 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
254 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
255 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
256 |
+
subprocess environment manager will be used.
|
257 |
+
Returns:
|
258 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
259 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
260 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
261 |
+
"""
|
262 |
+
|
263 |
+
if debug:
|
264 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
265 |
+
# define env and policy
|
266 |
+
env = self.env.clone(caller='evaluator')
|
267 |
+
|
268 |
+
if seed is not None and isinstance(seed, int):
|
269 |
+
seeds = [seed]
|
270 |
+
elif seed is not None and isinstance(seed, list):
|
271 |
+
seeds = seed
|
272 |
+
else:
|
273 |
+
seeds = [self.seed]
|
274 |
+
|
275 |
+
returns = []
|
276 |
+
images = []
|
277 |
+
if enable_save_replay:
|
278 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
279 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
280 |
+
else:
|
281 |
+
logging.warning('No video would be generated during the deploy.')
|
282 |
+
if concatenate_all_replay:
|
283 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
284 |
+
concatenate_all_replay = False
|
285 |
+
|
286 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
287 |
+
|
288 |
+
if self.cfg.policy.action_space == 'discrete':
|
289 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
|
290 |
+
elif self.cfg.policy.action_space == 'continuous':
|
291 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
|
292 |
+
elif self.cfg.policy.action_space == 'hybrid':
|
293 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='hybrid_deterministic_argmax_sample').forward
|
294 |
+
elif self.cfg.policy.action_space == 'general':
|
295 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
|
296 |
+
else:
|
297 |
+
raise NotImplementedError
|
298 |
+
|
299 |
+
def _forward(obs):
|
300 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
301 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
302 |
+
if cuda and torch.cuda.is_available():
|
303 |
+
obs = obs.cuda()
|
304 |
+
action = forward_fn(obs, mode='compute_actor')["action"]
|
305 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
306 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
307 |
+
return action
|
308 |
+
|
309 |
+
return _forward
|
310 |
+
|
311 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
312 |
+
|
313 |
+
# reset first to make sure the env is in the initial state
|
314 |
+
# env will be reset again in the main loop
|
315 |
+
env.reset()
|
316 |
+
|
317 |
+
for seed in seeds:
|
318 |
+
env.seed(seed, dynamic_seed=False)
|
319 |
+
return_ = 0.
|
320 |
+
step = 0
|
321 |
+
obs = env.reset()
|
322 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
323 |
+
while True:
|
324 |
+
action = forward_fn(obs)
|
325 |
+
obs, rew, done, info = env.step(action)
|
326 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
327 |
+
return_ += rew
|
328 |
+
step += 1
|
329 |
+
if done:
|
330 |
+
break
|
331 |
+
logging.info(f'PPO (offpolicy) deploy is finished, final episode return with {step} steps is: {return_}')
|
332 |
+
returns.append(return_)
|
333 |
+
|
334 |
+
env.close()
|
335 |
+
|
336 |
+
if concatenate_all_replay:
|
337 |
+
images = np.concatenate(images, axis=0)
|
338 |
+
import imageio
|
339 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
340 |
+
|
341 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
342 |
+
|
343 |
+
def collect_data(
|
344 |
+
self,
|
345 |
+
env_num: int = 8,
|
346 |
+
save_data_path: Optional[str] = None,
|
347 |
+
n_sample: Optional[int] = None,
|
348 |
+
n_episode: Optional[int] = None,
|
349 |
+
context: Optional[str] = None,
|
350 |
+
debug: bool = False
|
351 |
+
) -> None:
|
352 |
+
"""
|
353 |
+
Overview:
|
354 |
+
Collect data with PPO (offpolicy) algorithm for ``n_episode`` episodes \
|
355 |
+
with ``env_num`` collector environments. \
|
356 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
357 |
+
``exp_name/demo_data``.
|
358 |
+
Arguments:
|
359 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
360 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
361 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
362 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
363 |
+
If not specified, ``n_episode`` must be specified.
|
364 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
365 |
+
If not specified, ``n_sample`` must be specified.
|
366 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
367 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
368 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
369 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
370 |
+
subprocess environment manager will be used.
|
371 |
+
"""
|
372 |
+
|
373 |
+
if debug:
|
374 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
375 |
+
if n_episode is not None:
|
376 |
+
raise NotImplementedError
|
377 |
+
# define env and policy
|
378 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
379 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
380 |
+
|
381 |
+
if save_data_path is None:
|
382 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
383 |
+
|
384 |
+
# main execution task
|
385 |
+
with task.start(ctx=OnlineRLContext()):
|
386 |
+
task.use(
|
387 |
+
StepCollector(
|
388 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
389 |
+
)
|
390 |
+
)
|
391 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
392 |
+
task.run(max_step=1)
|
393 |
+
logging.info(
|
394 |
+
f'PPOOffPolicy collecting is finished, more than {n_sample} \
|
395 |
+
samples are collected and saved in `{save_data_path}`'
|
396 |
+
)
|
397 |
+
|
398 |
+
def batch_evaluate(
|
399 |
+
self,
|
400 |
+
env_num: int = 4,
|
401 |
+
n_evaluator_episode: int = 4,
|
402 |
+
context: Optional[str] = None,
|
403 |
+
debug: bool = False
|
404 |
+
) -> EvalReturn:
|
405 |
+
"""
|
406 |
+
Overview:
|
407 |
+
Evaluate the agent with PPO (offpolicy) algorithm for ``n_evaluator_episode`` episodes \
|
408 |
+
with ``env_num`` evaluator environments. The evaluation result will be returned.
|
409 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
410 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
411 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
412 |
+
Arguments:
|
413 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
414 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
415 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
416 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
417 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
418 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
419 |
+
subprocess environment manager will be used.
|
420 |
+
Returns:
|
421 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
422 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
423 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
424 |
+
"""
|
425 |
+
|
426 |
+
if debug:
|
427 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
428 |
+
# define env and policy
|
429 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
430 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
431 |
+
|
432 |
+
# reset first to make sure the env is in the initial state
|
433 |
+
# env will be reset again in the main loop
|
434 |
+
env.launch()
|
435 |
+
env.reset()
|
436 |
+
|
437 |
+
evaluate_cfg = self.cfg
|
438 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
439 |
+
|
440 |
+
# main execution task
|
441 |
+
with task.start(ctx=OnlineRLContext()):
|
442 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
443 |
+
task.run(max_step=1)
|
444 |
+
|
445 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
446 |
+
|
447 |
+
@property
|
448 |
+
def best(self) -> 'PPOOffPolicyAgent':
|
449 |
+
"""
|
450 |
+
Overview:
|
451 |
+
Load the best model from the checkpoint directory, \
|
452 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
453 |
+
The return value is the agent with the best model.
|
454 |
+
Returns:
|
455 |
+
- (:obj:`PPOOffPolicyAgent`): The agent with the best model.
|
456 |
+
Examples:
|
457 |
+
>>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
|
458 |
+
>>> agent.train()
|
459 |
+
>>> agent.best
|
460 |
+
|
461 |
+
.. note::
|
462 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
463 |
+
model will be replaced by the best model.
|
464 |
+
"""
|
465 |
+
|
466 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
467 |
+
# Load best model if it exists
|
468 |
+
if os.path.exists(best_model_file_path):
|
469 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
470 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
471 |
+
return self
|
DI-engine/ding/bonus/ppof.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
from functools import partial
|
5 |
+
import os
|
6 |
+
import gym
|
7 |
+
import gymnasium
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from ding.framework import task, OnlineRLContext
|
11 |
+
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
|
12 |
+
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
|
13 |
+
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
|
14 |
+
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py
|
18 |
+
from .model import PPOFModel
|
19 |
+
from .config import get_instance_config, get_instance_env, get_hybrid_shape
|
20 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
21 |
+
|
22 |
+
|
23 |
+
class PPOF:
|
24 |
+
"""
|
25 |
+
Overview:
|
26 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
27 |
+
Proximal Policy Optimization(PPO).
|
28 |
+
For more information about the system design of RL agent, please refer to \
|
29 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
30 |
+
Interface:
|
31 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
32 |
+
"""
|
33 |
+
|
34 |
+
supported_env_list = [
|
35 |
+
# common
|
36 |
+
'LunarLander-v2',
|
37 |
+
'LunarLanderContinuous-v2',
|
38 |
+
'BipedalWalker-v3',
|
39 |
+
'Pendulum-v1',
|
40 |
+
'acrobot',
|
41 |
+
# ch2: action
|
42 |
+
'rocket_landing',
|
43 |
+
'drone_fly',
|
44 |
+
'hybrid_moving',
|
45 |
+
# ch3: obs
|
46 |
+
'evogym_carrier',
|
47 |
+
'mario',
|
48 |
+
'di_sheep',
|
49 |
+
'procgen_bigfish',
|
50 |
+
# ch4: reward
|
51 |
+
'minigrid_fourroom',
|
52 |
+
'metadrive',
|
53 |
+
# atari
|
54 |
+
'BowlingNoFrameskip-v4',
|
55 |
+
'BreakoutNoFrameskip-v4',
|
56 |
+
'GopherNoFrameskip-v4'
|
57 |
+
'KangarooNoFrameskip-v4',
|
58 |
+
'PongNoFrameskip-v4',
|
59 |
+
'QbertNoFrameskip-v4',
|
60 |
+
'SpaceInvadersNoFrameskip-v4',
|
61 |
+
# mujoco
|
62 |
+
'Hopper-v3',
|
63 |
+
'HalfCheetah-v3',
|
64 |
+
'Walker2d-v3',
|
65 |
+
]
|
66 |
+
"""
|
67 |
+
Overview:
|
68 |
+
List of supported envs.
|
69 |
+
Examples:
|
70 |
+
>>> from ding.bonus.ppof import PPOF
|
71 |
+
>>> print(PPOF.supported_env_list)
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
env_id: str = None,
|
77 |
+
env: BaseEnv = None,
|
78 |
+
seed: int = 0,
|
79 |
+
exp_name: str = None,
|
80 |
+
model: Optional[torch.nn.Module] = None,
|
81 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
82 |
+
policy_state_dict: str = None
|
83 |
+
) -> None:
|
84 |
+
"""
|
85 |
+
Overview:
|
86 |
+
Initialize agent for PPO algorithm.
|
87 |
+
Arguments:
|
88 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
89 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg`` must be specified. \
|
90 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg`` will be ignored. \
|
91 |
+
``env_id`` should be one of the supported envs, which can be found in ``PPOF.supported_env_list``.
|
92 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
93 |
+
If ``env`` is not specified, ``env_id`` or ``cfg.env_id`` must be specified. \
|
94 |
+
``env_id`` or ``cfg.env_id`` will be used to create environment instance. \
|
95 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env_id`` will be ignored.
|
96 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
97 |
+
Default to 0.
|
98 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
99 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
100 |
+
- model (:obj:`torch.nn.Module`): The model of PPO algorithm, which should be an instance of class \
|
101 |
+
``ding.model.PPOFModel``. \
|
102 |
+
If not specified, a default model will be generated according to the configuration.
|
103 |
+
- cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO algorithm, which is a dict. \
|
104 |
+
Default to None. If not specified, the default configuration will be used.
|
105 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
106 |
+
If specified, the policy will be loaded from this file. Default to None.
|
107 |
+
|
108 |
+
.. note::
|
109 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
110 |
+
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
|
111 |
+
and we want to train an agent with PPO algorithm with default configuration. \
|
112 |
+
Then we can initialize the agent in the following ways:
|
113 |
+
>>> agent = PPOF(env_id='LunarLander-v2')
|
114 |
+
or, if we want can specify the env_id in the configuration:
|
115 |
+
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
|
116 |
+
>>> agent = PPOF(cfg=cfg)
|
117 |
+
There are also other arguments to specify the agent when initializing.
|
118 |
+
For example, if we want to specify the environment instance:
|
119 |
+
>>> env = CustomizedEnv('LunarLander-v2')
|
120 |
+
>>> agent = PPOF(cfg=cfg, env=env)
|
121 |
+
or, if we want to specify the model:
|
122 |
+
>>> model = VAC(**cfg.policy.model)
|
123 |
+
>>> agent = PPOF(cfg=cfg, model=model)
|
124 |
+
or, if we want to reload the policy from a saved policy state dict:
|
125 |
+
>>> agent = PPOF(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
|
126 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
127 |
+
"""
|
128 |
+
|
129 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
130 |
+
|
131 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
132 |
+
cfg = EasyDict(cfg)
|
133 |
+
|
134 |
+
if env_id is not None:
|
135 |
+
assert env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(PPOF.supported_env_list)
|
136 |
+
if cfg is None:
|
137 |
+
cfg = get_instance_config(env_id, algorithm="PPOF")
|
138 |
+
|
139 |
+
if not hasattr(cfg, "env_id"):
|
140 |
+
cfg.env_id = env_id
|
141 |
+
assert cfg.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
142 |
+
else:
|
143 |
+
assert hasattr(cfg, "env_id"), "Please specify env_id in cfg."
|
144 |
+
assert cfg.env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(
|
145 |
+
PPOF.supported_env_list
|
146 |
+
)
|
147 |
+
|
148 |
+
if exp_name is not None:
|
149 |
+
cfg.exp_name = exp_name
|
150 |
+
elif not hasattr(cfg, "exp_name"):
|
151 |
+
cfg.exp_name = "{}-{}".format(cfg.env_id, "PPO")
|
152 |
+
self.cfg = cfg
|
153 |
+
self.exp_name = self.cfg.exp_name
|
154 |
+
|
155 |
+
if env is None:
|
156 |
+
self.env = get_instance_env(self.cfg.env_id)
|
157 |
+
else:
|
158 |
+
self.env = env
|
159 |
+
|
160 |
+
logging.getLogger().setLevel(logging.INFO)
|
161 |
+
self.seed = seed
|
162 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.cuda)
|
163 |
+
|
164 |
+
if not os.path.exists(self.exp_name):
|
165 |
+
os.makedirs(self.exp_name)
|
166 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
167 |
+
|
168 |
+
action_space = self.env.action_space
|
169 |
+
if isinstance(action_space, (gym.spaces.Discrete, gymnasium.spaces.Discrete)):
|
170 |
+
action_shape = int(action_space.n)
|
171 |
+
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
|
172 |
+
action_shape = get_hybrid_shape(action_space)
|
173 |
+
else:
|
174 |
+
action_shape = action_space.shape
|
175 |
+
|
176 |
+
# Three types of value normalization is supported currently
|
177 |
+
assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline']
|
178 |
+
if model is None:
|
179 |
+
if self.cfg.value_norm != 'popart':
|
180 |
+
model = PPOFModel(
|
181 |
+
self.env.observation_space.shape,
|
182 |
+
action_shape,
|
183 |
+
action_space=self.cfg.action_space,
|
184 |
+
**self.cfg.model
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
model = PPOFModel(
|
188 |
+
self.env.observation_space.shape,
|
189 |
+
action_shape,
|
190 |
+
action_space=self.cfg.action_space,
|
191 |
+
popart_head=True,
|
192 |
+
**self.cfg.model
|
193 |
+
)
|
194 |
+
self.policy = PPOFPolicy(self.cfg, model=model)
|
195 |
+
if policy_state_dict is not None:
|
196 |
+
self.policy.load_state_dict(policy_state_dict)
|
197 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
198 |
+
|
199 |
+
def train(
|
200 |
+
self,
|
201 |
+
step: int = int(1e7),
|
202 |
+
collector_env_num: int = 4,
|
203 |
+
evaluator_env_num: int = 4,
|
204 |
+
n_iter_log_show: int = 500,
|
205 |
+
n_iter_save_ckpt: int = 1000,
|
206 |
+
context: Optional[str] = None,
|
207 |
+
reward_model: Optional[str] = None,
|
208 |
+
debug: bool = False,
|
209 |
+
wandb_sweep: bool = False,
|
210 |
+
) -> TrainingReturn:
|
211 |
+
"""
|
212 |
+
Overview:
|
213 |
+
Train the agent with PPO algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
214 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
215 |
+
recorded and saved by wandb.
|
216 |
+
Arguments:
|
217 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
218 |
+
- collector_env_num (:obj:`int`): The number of collector environments. Default to 4.
|
219 |
+
- evaluator_env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
220 |
+
- n_iter_log_show (:obj:`int`): The frequency of logging every training iteration. Default to 500.
|
221 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
222 |
+
Default to 1000.
|
223 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
224 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
225 |
+
- reward_model (:obj:`str`): The reward model name. Default to None. This argument is not supported yet.
|
226 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
227 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
228 |
+
subprocess environment manager will be used.
|
229 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
230 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
231 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
232 |
+
Returns:
|
233 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
234 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
235 |
+
"""
|
236 |
+
|
237 |
+
if debug:
|
238 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
239 |
+
logging.debug(self.policy._model)
|
240 |
+
# define env and policy
|
241 |
+
collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
|
242 |
+
evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
|
243 |
+
|
244 |
+
if reward_model is not None:
|
245 |
+
# self.reward_model = create_reward_model(reward_model, self.cfg.reward_model)
|
246 |
+
pass
|
247 |
+
|
248 |
+
with task.start(ctx=OnlineRLContext()):
|
249 |
+
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
|
250 |
+
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
251 |
+
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
|
252 |
+
task.use(ppof_adv_estimator(self.policy))
|
253 |
+
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
|
254 |
+
task.use(
|
255 |
+
wandb_online_logger(
|
256 |
+
metric_list=self.policy.monitor_vars(),
|
257 |
+
model=self.policy._model,
|
258 |
+
anonymous=True,
|
259 |
+
project_name=self.exp_name,
|
260 |
+
wandb_sweep=wandb_sweep,
|
261 |
+
)
|
262 |
+
)
|
263 |
+
task.use(termination_checker(max_env_step=step))
|
264 |
+
task.run()
|
265 |
+
|
266 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
267 |
+
|
268 |
+
def deploy(
|
269 |
+
self,
|
270 |
+
enable_save_replay: bool = False,
|
271 |
+
concatenate_all_replay: bool = False,
|
272 |
+
replay_save_path: str = None,
|
273 |
+
seed: Optional[Union[int, List]] = None,
|
274 |
+
debug: bool = False
|
275 |
+
) -> EvalReturn:
|
276 |
+
"""
|
277 |
+
Overview:
|
278 |
+
Deploy the agent with PPO algorithm by interacting with the environment, during which the replay video \
|
279 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
280 |
+
Arguments:
|
281 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
282 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
283 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
284 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
285 |
+
the replay video of each episode will be saved separately.
|
286 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
287 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
288 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
289 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
290 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
291 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
292 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
293 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
294 |
+
subprocess environment manager will be used.
|
295 |
+
Returns:
|
296 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
297 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
298 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
299 |
+
"""
|
300 |
+
|
301 |
+
if debug:
|
302 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
303 |
+
# define env and policy
|
304 |
+
env = self.env.clone(caller='evaluator')
|
305 |
+
|
306 |
+
if seed is not None and isinstance(seed, int):
|
307 |
+
seeds = [seed]
|
308 |
+
elif seed is not None and isinstance(seed, list):
|
309 |
+
seeds = seed
|
310 |
+
else:
|
311 |
+
seeds = [self.seed]
|
312 |
+
|
313 |
+
returns = []
|
314 |
+
images = []
|
315 |
+
if enable_save_replay:
|
316 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
317 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
318 |
+
else:
|
319 |
+
logging.warning('No video would be generated during the deploy.')
|
320 |
+
if concatenate_all_replay:
|
321 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
322 |
+
concatenate_all_replay = False
|
323 |
+
|
324 |
+
forward_fn = single_env_forward_wrapper_ttorch(self.policy.eval, self.cfg.cuda)
|
325 |
+
|
326 |
+
# reset first to make sure the env is in the initial state
|
327 |
+
# env will be reset again in the main loop
|
328 |
+
env.reset()
|
329 |
+
|
330 |
+
for seed in seeds:
|
331 |
+
env.seed(seed, dynamic_seed=False)
|
332 |
+
return_ = 0.
|
333 |
+
step = 0
|
334 |
+
obs = env.reset()
|
335 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
336 |
+
while True:
|
337 |
+
action = forward_fn(obs)
|
338 |
+
obs, rew, done, info = env.step(action)
|
339 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
340 |
+
return_ += rew
|
341 |
+
step += 1
|
342 |
+
if done:
|
343 |
+
break
|
344 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
345 |
+
returns.append(return_)
|
346 |
+
|
347 |
+
env.close()
|
348 |
+
|
349 |
+
if concatenate_all_replay:
|
350 |
+
images = np.concatenate(images, axis=0)
|
351 |
+
import imageio
|
352 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
353 |
+
|
354 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
355 |
+
|
356 |
+
def collect_data(
|
357 |
+
self,
|
358 |
+
env_num: int = 8,
|
359 |
+
save_data_path: Optional[str] = None,
|
360 |
+
n_sample: Optional[int] = None,
|
361 |
+
n_episode: Optional[int] = None,
|
362 |
+
context: Optional[str] = None,
|
363 |
+
debug: bool = False
|
364 |
+
) -> None:
|
365 |
+
"""
|
366 |
+
Overview:
|
367 |
+
Collect data with PPO algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
368 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
369 |
+
``exp_name/demo_data``.
|
370 |
+
Arguments:
|
371 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
372 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
373 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
374 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
375 |
+
If not specified, ``n_episode`` must be specified.
|
376 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
377 |
+
If not specified, ``n_sample`` must be specified.
|
378 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
379 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
380 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
381 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
382 |
+
subprocess environment manager will be used.
|
383 |
+
"""
|
384 |
+
|
385 |
+
if debug:
|
386 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
387 |
+
if n_episode is not None:
|
388 |
+
raise NotImplementedError
|
389 |
+
# define env and policy
|
390 |
+
env = self._setup_env_manager(env_num, context, debug, 'collector')
|
391 |
+
if save_data_path is None:
|
392 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
393 |
+
|
394 |
+
# main execution task
|
395 |
+
with task.start(ctx=OnlineRLContext()):
|
396 |
+
task.use(PPOFStepCollector(self.seed, self.policy, env, n_sample))
|
397 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
398 |
+
task.run(max_step=1)
|
399 |
+
logging.info(
|
400 |
+
f'PPOF collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
401 |
+
)
|
402 |
+
|
403 |
+
def batch_evaluate(
|
404 |
+
self,
|
405 |
+
env_num: int = 4,
|
406 |
+
n_evaluator_episode: int = 4,
|
407 |
+
context: Optional[str] = None,
|
408 |
+
debug: bool = False,
|
409 |
+
) -> EvalReturn:
|
410 |
+
"""
|
411 |
+
Overview:
|
412 |
+
Evaluate the agent with PPO algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
413 |
+
environments. The evaluation result will be returned.
|
414 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
415 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
416 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
417 |
+
Arguments:
|
418 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
419 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
420 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
421 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
422 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
423 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
424 |
+
subprocess environment manager will be used.
|
425 |
+
Returns:
|
426 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
427 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
428 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
429 |
+
"""
|
430 |
+
|
431 |
+
if debug:
|
432 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
433 |
+
# define env and policy
|
434 |
+
env = self._setup_env_manager(env_num, context, debug, 'evaluator')
|
435 |
+
|
436 |
+
# reset first to make sure the env is in the initial state
|
437 |
+
# env will be reset again in the main loop
|
438 |
+
env.launch()
|
439 |
+
env.reset()
|
440 |
+
|
441 |
+
# main execution task
|
442 |
+
with task.start(ctx=OnlineRLContext()):
|
443 |
+
task.use(interaction_evaluator_ttorch(
|
444 |
+
self.seed,
|
445 |
+
self.policy,
|
446 |
+
env,
|
447 |
+
n_evaluator_episode,
|
448 |
+
))
|
449 |
+
task.run(max_step=1)
|
450 |
+
|
451 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
452 |
+
|
453 |
+
def _setup_env_manager(
|
454 |
+
self,
|
455 |
+
env_num: int,
|
456 |
+
context: Optional[str] = None,
|
457 |
+
debug: bool = False,
|
458 |
+
caller: str = 'collector'
|
459 |
+
) -> BaseEnvManagerV2:
|
460 |
+
"""
|
461 |
+
Overview:
|
462 |
+
Setup the environment manager. The environment manager is used to manage multiple environments.
|
463 |
+
Arguments:
|
464 |
+
- env_num (:obj:`int`): The number of environments.
|
465 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
466 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
467 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
468 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
469 |
+
subprocess environment manager will be used.
|
470 |
+
- caller (:obj:`str`): The caller of the environment manager. Default to 'collector'.
|
471 |
+
Returns:
|
472 |
+
- (:obj:`BaseEnvManagerV2`): The environment manager.
|
473 |
+
"""
|
474 |
+
assert caller in ['evaluator', 'collector']
|
475 |
+
if debug:
|
476 |
+
env_cls = BaseEnvManagerV2
|
477 |
+
manager_cfg = env_cls.default_config()
|
478 |
+
else:
|
479 |
+
env_cls = SubprocessEnvManagerV2
|
480 |
+
manager_cfg = env_cls.default_config()
|
481 |
+
if context is not None:
|
482 |
+
manager_cfg.context = context
|
483 |
+
return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)
|
484 |
+
|
485 |
+
@property
|
486 |
+
def best(self) -> 'PPOF':
|
487 |
+
"""
|
488 |
+
Overview:
|
489 |
+
Load the best model from the checkpoint directory, \
|
490 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
491 |
+
The return value is the agent with the best model.
|
492 |
+
Returns:
|
493 |
+
- (:obj:`PPOF`): The agent with the best model.
|
494 |
+
Examples:
|
495 |
+
>>> agent = PPOF(env_id='LunarLander-v2')
|
496 |
+
>>> agent.train()
|
497 |
+
>>> agent = agent.best()
|
498 |
+
|
499 |
+
.. note::
|
500 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
501 |
+
model will be replaced by the best model.
|
502 |
+
"""
|
503 |
+
|
504 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
505 |
+
# Load best model if it exists
|
506 |
+
if os.path.exists(best_model_file_path):
|
507 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
508 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
509 |
+
return self
|
DI-engine/ding/bonus/sac.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import SACPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import ContinuousQAC
|
19 |
+
from ding.model import model_wrap
|
20 |
+
from ding.data import DequeBuffer
|
21 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
22 |
+
from ding.config.example.SAC import supported_env_cfg
|
23 |
+
from ding.config.example.SAC import supported_env
|
24 |
+
|
25 |
+
|
26 |
+
class SACAgent:
|
27 |
+
"""
|
28 |
+
Overview:
|
29 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
30 |
+
Soft Actor-Critic(SAC).
|
31 |
+
For more information about the system design of RL agent, please refer to \
|
32 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
33 |
+
Interface:
|
34 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
35 |
+
"""
|
36 |
+
supported_env_list = list(supported_env_cfg.keys())
|
37 |
+
"""
|
38 |
+
Overview:
|
39 |
+
List of supported envs.
|
40 |
+
Examples:
|
41 |
+
>>> from ding.bonus.sac import SACAgent
|
42 |
+
>>> print(SACAgent.supported_env_list)
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
env_id: str = None,
|
48 |
+
env: BaseEnv = None,
|
49 |
+
seed: int = 0,
|
50 |
+
exp_name: str = None,
|
51 |
+
model: Optional[torch.nn.Module] = None,
|
52 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
53 |
+
policy_state_dict: str = None,
|
54 |
+
) -> None:
|
55 |
+
"""
|
56 |
+
Overview:
|
57 |
+
Initialize agent for SAC algorithm.
|
58 |
+
Arguments:
|
59 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
60 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
61 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
62 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
63 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
64 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
65 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
66 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
67 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
68 |
+
Default to 0.
|
69 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
70 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
71 |
+
- model (:obj:`torch.nn.Module`): The model of SAC algorithm, which should be an instance of class \
|
72 |
+
:class:`ding.model.ContinuousQAC`. \
|
73 |
+
If not specified, a default model will be generated according to the configuration.
|
74 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of SAC algorithm, which is a dict. \
|
75 |
+
Default to None. If not specified, the default configuration will be used. \
|
76 |
+
The default configuration can be found in ``ding/config/example/SAC/gym_lunarlander_v2.py``.
|
77 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
78 |
+
If specified, the policy will be loaded from this file. Default to None.
|
79 |
+
|
80 |
+
.. note::
|
81 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
82 |
+
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
|
83 |
+
and we want to train an agent with SAC algorithm with default configuration. \
|
84 |
+
Then we can initialize the agent in the following ways:
|
85 |
+
>>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
|
86 |
+
or, if we want can specify the env_id in the configuration:
|
87 |
+
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
|
88 |
+
>>> agent = SACAgent(cfg=cfg)
|
89 |
+
There are also other arguments to specify the agent when initializing.
|
90 |
+
For example, if we want to specify the environment instance:
|
91 |
+
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
|
92 |
+
>>> agent = SACAgent(cfg=cfg, env=env)
|
93 |
+
or, if we want to specify the model:
|
94 |
+
>>> model = ContinuousQAC(**cfg.policy.model)
|
95 |
+
>>> agent = SACAgent(cfg=cfg, model=model)
|
96 |
+
or, if we want to reload the policy from a saved policy state dict:
|
97 |
+
>>> agent = SACAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
|
98 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
99 |
+
"""
|
100 |
+
|
101 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
102 |
+
|
103 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
104 |
+
cfg = EasyDict(cfg)
|
105 |
+
|
106 |
+
if env_id is not None:
|
107 |
+
assert env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
|
108 |
+
SACAgent.supported_env_list
|
109 |
+
)
|
110 |
+
if cfg is None:
|
111 |
+
cfg = supported_env_cfg[env_id]
|
112 |
+
else:
|
113 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
114 |
+
else:
|
115 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
116 |
+
assert cfg.env.env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
|
117 |
+
SACAgent.supported_env_list
|
118 |
+
)
|
119 |
+
default_policy_config = EasyDict({"policy": SACPolicy.default_config()})
|
120 |
+
default_policy_config.update(cfg)
|
121 |
+
cfg = default_policy_config
|
122 |
+
|
123 |
+
if exp_name is not None:
|
124 |
+
cfg.exp_name = exp_name
|
125 |
+
self.cfg = compile_config(cfg, policy=SACPolicy)
|
126 |
+
self.exp_name = self.cfg.exp_name
|
127 |
+
if env is None:
|
128 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
129 |
+
else:
|
130 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
131 |
+
self.env = env
|
132 |
+
|
133 |
+
logging.getLogger().setLevel(logging.INFO)
|
134 |
+
self.seed = seed
|
135 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
136 |
+
if not os.path.exists(self.exp_name):
|
137 |
+
os.makedirs(self.exp_name)
|
138 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
139 |
+
if model is None:
|
140 |
+
model = ContinuousQAC(**self.cfg.policy.model)
|
141 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
142 |
+
self.policy = SACPolicy(self.cfg.policy, model=model)
|
143 |
+
if policy_state_dict is not None:
|
144 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
145 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
146 |
+
|
147 |
+
def train(
|
148 |
+
self,
|
149 |
+
step: int = int(1e7),
|
150 |
+
collector_env_num: int = None,
|
151 |
+
evaluator_env_num: int = None,
|
152 |
+
n_iter_save_ckpt: int = 1000,
|
153 |
+
context: Optional[str] = None,
|
154 |
+
debug: bool = False,
|
155 |
+
wandb_sweep: bool = False,
|
156 |
+
) -> TrainingReturn:
|
157 |
+
"""
|
158 |
+
Overview:
|
159 |
+
Train the agent with SAC algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
160 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
161 |
+
recorded and saved by wandb.
|
162 |
+
Arguments:
|
163 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
164 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
165 |
+
If not specified, it will be set according to the configuration.
|
166 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
167 |
+
If not specified, it will be set according to the configuration.
|
168 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
169 |
+
Default to 1000.
|
170 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
171 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
172 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
173 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
174 |
+
subprocess environment manager will be used.
|
175 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
176 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
177 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
178 |
+
Returns:
|
179 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
180 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
181 |
+
"""
|
182 |
+
|
183 |
+
if debug:
|
184 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
185 |
+
logging.debug(self.policy._model)
|
186 |
+
# define env and policy
|
187 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
188 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
189 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
190 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
191 |
+
|
192 |
+
with task.start(ctx=OnlineRLContext()):
|
193 |
+
task.use(
|
194 |
+
interaction_evaluator(
|
195 |
+
self.cfg,
|
196 |
+
self.policy.eval_mode,
|
197 |
+
evaluator_env,
|
198 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
199 |
+
)
|
200 |
+
)
|
201 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
202 |
+
task.use(
|
203 |
+
StepCollector(
|
204 |
+
self.cfg,
|
205 |
+
self.policy.collect_mode,
|
206 |
+
collector_env,
|
207 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
208 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
209 |
+
)
|
210 |
+
)
|
211 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
212 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
213 |
+
task.use(
|
214 |
+
wandb_online_logger(
|
215 |
+
metric_list=self.policy._monitor_vars_learn(),
|
216 |
+
model=self.policy._model,
|
217 |
+
anonymous=True,
|
218 |
+
project_name=self.exp_name,
|
219 |
+
wandb_sweep=wandb_sweep,
|
220 |
+
)
|
221 |
+
)
|
222 |
+
task.use(termination_checker(max_env_step=step))
|
223 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
224 |
+
task.run()
|
225 |
+
|
226 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
227 |
+
|
228 |
+
def deploy(
|
229 |
+
self,
|
230 |
+
enable_save_replay: bool = False,
|
231 |
+
concatenate_all_replay: bool = False,
|
232 |
+
replay_save_path: str = None,
|
233 |
+
seed: Optional[Union[int, List]] = None,
|
234 |
+
debug: bool = False
|
235 |
+
) -> EvalReturn:
|
236 |
+
"""
|
237 |
+
Overview:
|
238 |
+
Deploy the agent with SAC algorithm by interacting with the environment, during which the replay video \
|
239 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
240 |
+
Arguments:
|
241 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
242 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
243 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
244 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
245 |
+
the replay video of each episode will be saved separately.
|
246 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
247 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
248 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
249 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
250 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
251 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
252 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
253 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
254 |
+
subprocess environment manager will be used.
|
255 |
+
Returns:
|
256 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
257 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
258 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
259 |
+
"""
|
260 |
+
|
261 |
+
if debug:
|
262 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
263 |
+
# define env and policy
|
264 |
+
env = self.env.clone(caller='evaluator')
|
265 |
+
|
266 |
+
if seed is not None and isinstance(seed, int):
|
267 |
+
seeds = [seed]
|
268 |
+
elif seed is not None and isinstance(seed, list):
|
269 |
+
seeds = seed
|
270 |
+
else:
|
271 |
+
seeds = [self.seed]
|
272 |
+
|
273 |
+
returns = []
|
274 |
+
images = []
|
275 |
+
if enable_save_replay:
|
276 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
277 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
278 |
+
else:
|
279 |
+
logging.warning('No video would be generated during the deploy.')
|
280 |
+
if concatenate_all_replay:
|
281 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
282 |
+
concatenate_all_replay = False
|
283 |
+
|
284 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
285 |
+
|
286 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
|
287 |
+
|
288 |
+
def _forward(obs):
|
289 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
290 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
291 |
+
if cuda and torch.cuda.is_available():
|
292 |
+
obs = obs.cuda()
|
293 |
+
(mu, sigma) = forward_fn(obs, mode='compute_actor')['logit']
|
294 |
+
action = torch.tanh(mu).detach().cpu().numpy()[0] # deterministic_eval
|
295 |
+
return action
|
296 |
+
|
297 |
+
return _forward
|
298 |
+
|
299 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
300 |
+
|
301 |
+
# reset first to make sure the env is in the initial state
|
302 |
+
# env will be reset again in the main loop
|
303 |
+
env.reset()
|
304 |
+
|
305 |
+
for seed in seeds:
|
306 |
+
env.seed(seed, dynamic_seed=False)
|
307 |
+
return_ = 0.
|
308 |
+
step = 0
|
309 |
+
obs = env.reset()
|
310 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
311 |
+
while True:
|
312 |
+
action = forward_fn(obs)
|
313 |
+
obs, rew, done, info = env.step(action)
|
314 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
315 |
+
return_ += rew
|
316 |
+
step += 1
|
317 |
+
if done:
|
318 |
+
break
|
319 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
320 |
+
returns.append(return_)
|
321 |
+
|
322 |
+
env.close()
|
323 |
+
|
324 |
+
if concatenate_all_replay:
|
325 |
+
images = np.concatenate(images, axis=0)
|
326 |
+
import imageio
|
327 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
328 |
+
|
329 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
330 |
+
|
331 |
+
def collect_data(
|
332 |
+
self,
|
333 |
+
env_num: int = 8,
|
334 |
+
save_data_path: Optional[str] = None,
|
335 |
+
n_sample: Optional[int] = None,
|
336 |
+
n_episode: Optional[int] = None,
|
337 |
+
context: Optional[str] = None,
|
338 |
+
debug: bool = False
|
339 |
+
) -> None:
|
340 |
+
"""
|
341 |
+
Overview:
|
342 |
+
Collect data with SAC algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
343 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
344 |
+
``exp_name/demo_data``.
|
345 |
+
Arguments:
|
346 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
347 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
348 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
349 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
350 |
+
If not specified, ``n_episode`` must be specified.
|
351 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
352 |
+
If not specified, ``n_sample`` must be specified.
|
353 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
354 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
355 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
356 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
357 |
+
subprocess environment manager will be used.
|
358 |
+
"""
|
359 |
+
|
360 |
+
if debug:
|
361 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
362 |
+
if n_episode is not None:
|
363 |
+
raise NotImplementedError
|
364 |
+
# define env and policy
|
365 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
366 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
367 |
+
|
368 |
+
if save_data_path is None:
|
369 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
370 |
+
|
371 |
+
# main execution task
|
372 |
+
with task.start(ctx=OnlineRLContext()):
|
373 |
+
task.use(
|
374 |
+
StepCollector(
|
375 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
376 |
+
)
|
377 |
+
)
|
378 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
379 |
+
task.run(max_step=1)
|
380 |
+
logging.info(
|
381 |
+
f'SAC collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
382 |
+
)
|
383 |
+
|
384 |
+
def batch_evaluate(
|
385 |
+
self,
|
386 |
+
env_num: int = 4,
|
387 |
+
n_evaluator_episode: int = 4,
|
388 |
+
context: Optional[str] = None,
|
389 |
+
debug: bool = False
|
390 |
+
) -> EvalReturn:
|
391 |
+
"""
|
392 |
+
Overview:
|
393 |
+
Evaluate the agent with SAC algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
394 |
+
environments. The evaluation result will be returned.
|
395 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
396 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
397 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
398 |
+
Arguments:
|
399 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
400 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
401 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
402 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
403 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
404 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
405 |
+
subprocess environment manager will be used.
|
406 |
+
Returns:
|
407 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
408 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
409 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
410 |
+
"""
|
411 |
+
|
412 |
+
if debug:
|
413 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
414 |
+
# define env and policy
|
415 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
416 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
417 |
+
|
418 |
+
# reset first to make sure the env is in the initial state
|
419 |
+
# env will be reset again in the main loop
|
420 |
+
env.launch()
|
421 |
+
env.reset()
|
422 |
+
|
423 |
+
evaluate_cfg = self.cfg
|
424 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
425 |
+
|
426 |
+
# main execution task
|
427 |
+
with task.start(ctx=OnlineRLContext()):
|
428 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
429 |
+
task.run(max_step=1)
|
430 |
+
|
431 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
432 |
+
|
433 |
+
@property
|
434 |
+
def best(self) -> 'SACAgent':
|
435 |
+
"""
|
436 |
+
Overview:
|
437 |
+
Load the best model from the checkpoint directory, \
|
438 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
439 |
+
The return value is the agent with the best model.
|
440 |
+
Returns:
|
441 |
+
- (:obj:`SACAgent`): The agent with the best model.
|
442 |
+
Examples:
|
443 |
+
>>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
|
444 |
+
>>> agent.train()
|
445 |
+
>>> agent = agent.best
|
446 |
+
|
447 |
+
.. note::
|
448 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
449 |
+
model will be replaced by the best model.
|
450 |
+
"""
|
451 |
+
|
452 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
453 |
+
# Load best model if it exists
|
454 |
+
if os.path.exists(best_model_file_path):
|
455 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
456 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
457 |
+
return self
|
DI-engine/ding/bonus/sql.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import SQLPolicy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import DQN
|
19 |
+
from ding.model import model_wrap
|
20 |
+
from ding.data import DequeBuffer
|
21 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
22 |
+
from ding.config.example.SQL import supported_env_cfg
|
23 |
+
from ding.config.example.SQL import supported_env
|
24 |
+
|
25 |
+
|
26 |
+
class SQLAgent:
|
27 |
+
"""
|
28 |
+
Overview:
|
29 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
30 |
+
Soft Q-Learning(SQL).
|
31 |
+
For more information about the system design of RL agent, please refer to \
|
32 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
33 |
+
Interface:
|
34 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
35 |
+
"""
|
36 |
+
supported_env_list = list(supported_env_cfg.keys())
|
37 |
+
"""
|
38 |
+
Overview:
|
39 |
+
List of supported envs.
|
40 |
+
Examples:
|
41 |
+
>>> from ding.bonus.sql import SQLAgent
|
42 |
+
>>> print(SQLAgent.supported_env_list)
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
env_id: str = None,
|
48 |
+
env: BaseEnv = None,
|
49 |
+
seed: int = 0,
|
50 |
+
exp_name: str = None,
|
51 |
+
model: Optional[torch.nn.Module] = None,
|
52 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
53 |
+
policy_state_dict: str = None,
|
54 |
+
) -> None:
|
55 |
+
"""
|
56 |
+
Overview:
|
57 |
+
Initialize agent for SQL algorithm.
|
58 |
+
Arguments:
|
59 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
60 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
61 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
62 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
63 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
64 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
65 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
66 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
67 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
68 |
+
Default to 0.
|
69 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
70 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
71 |
+
- model (:obj:`torch.nn.Module`): The model of SQL algorithm, which should be an instance of class \
|
72 |
+
:class:`ding.model.DQN`. \
|
73 |
+
If not specified, a default model will be generated according to the configuration.
|
74 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of SQL algorithm, which is a dict. \
|
75 |
+
Default to None. If not specified, the default configuration will be used. \
|
76 |
+
The default configuration can be found in ``ding/config/example/SQL/gym_lunarlander_v2.py``.
|
77 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
78 |
+
If specified, the policy will be loaded from this file. Default to None.
|
79 |
+
|
80 |
+
.. note::
|
81 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
82 |
+
For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
|
83 |
+
and we want to train an agent with SQL algorithm with default configuration. \
|
84 |
+
Then we can initialize the agent in the following ways:
|
85 |
+
>>> agent = SQLAgent(env_id='LunarLander-v2')
|
86 |
+
or, if we want can specify the env_id in the configuration:
|
87 |
+
>>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
|
88 |
+
>>> agent = SQLAgent(cfg=cfg)
|
89 |
+
There are also other arguments to specify the agent when initializing.
|
90 |
+
For example, if we want to specify the environment instance:
|
91 |
+
>>> env = CustomizedEnv('LunarLander-v2')
|
92 |
+
>>> agent = SQLAgent(cfg=cfg, env=env)
|
93 |
+
or, if we want to specify the model:
|
94 |
+
>>> model = DQN(**cfg.policy.model)
|
95 |
+
>>> agent = SQLAgent(cfg=cfg, model=model)
|
96 |
+
or, if we want to reload the policy from a saved policy state dict:
|
97 |
+
>>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
|
98 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
99 |
+
"""
|
100 |
+
|
101 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
102 |
+
|
103 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
104 |
+
cfg = EasyDict(cfg)
|
105 |
+
|
106 |
+
if env_id is not None:
|
107 |
+
assert env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
|
108 |
+
SQLAgent.supported_env_list
|
109 |
+
)
|
110 |
+
if cfg is None:
|
111 |
+
cfg = supported_env_cfg[env_id]
|
112 |
+
else:
|
113 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
114 |
+
else:
|
115 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
116 |
+
assert cfg.env.env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
|
117 |
+
SQLAgent.supported_env_list
|
118 |
+
)
|
119 |
+
default_policy_config = EasyDict({"policy": SQLPolicy.default_config()})
|
120 |
+
default_policy_config.update(cfg)
|
121 |
+
cfg = default_policy_config
|
122 |
+
|
123 |
+
if exp_name is not None:
|
124 |
+
cfg.exp_name = exp_name
|
125 |
+
self.cfg = compile_config(cfg, policy=SQLPolicy)
|
126 |
+
self.exp_name = self.cfg.exp_name
|
127 |
+
if env is None:
|
128 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
129 |
+
else:
|
130 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
131 |
+
self.env = env
|
132 |
+
|
133 |
+
logging.getLogger().setLevel(logging.INFO)
|
134 |
+
self.seed = seed
|
135 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
136 |
+
if not os.path.exists(self.exp_name):
|
137 |
+
os.makedirs(self.exp_name)
|
138 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
139 |
+
if model is None:
|
140 |
+
model = DQN(**self.cfg.policy.model)
|
141 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
142 |
+
self.policy = SQLPolicy(self.cfg.policy, model=model)
|
143 |
+
if policy_state_dict is not None:
|
144 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
145 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
146 |
+
|
147 |
+
def train(
|
148 |
+
self,
|
149 |
+
step: int = int(1e7),
|
150 |
+
collector_env_num: int = None,
|
151 |
+
evaluator_env_num: int = None,
|
152 |
+
n_iter_save_ckpt: int = 1000,
|
153 |
+
context: Optional[str] = None,
|
154 |
+
debug: bool = False,
|
155 |
+
wandb_sweep: bool = False,
|
156 |
+
) -> TrainingReturn:
|
157 |
+
"""
|
158 |
+
Overview:
|
159 |
+
Train the agent with SQL algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
160 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
161 |
+
recorded and saved by wandb.
|
162 |
+
Arguments:
|
163 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
164 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
165 |
+
If not specified, it will be set according to the configuration.
|
166 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
167 |
+
If not specified, it will be set according to the configuration.
|
168 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
169 |
+
Default to 1000.
|
170 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
171 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
172 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
173 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
174 |
+
subprocess environment manager will be used.
|
175 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
176 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
177 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
178 |
+
Returns:
|
179 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
180 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
181 |
+
"""
|
182 |
+
|
183 |
+
if debug:
|
184 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
185 |
+
logging.debug(self.policy._model)
|
186 |
+
# define env and policy
|
187 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
188 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
189 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
190 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
191 |
+
|
192 |
+
with task.start(ctx=OnlineRLContext()):
|
193 |
+
task.use(
|
194 |
+
interaction_evaluator(
|
195 |
+
self.cfg,
|
196 |
+
self.policy.eval_mode,
|
197 |
+
evaluator_env,
|
198 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
199 |
+
)
|
200 |
+
)
|
201 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
202 |
+
task.use(eps_greedy_handler(self.cfg))
|
203 |
+
task.use(
|
204 |
+
StepCollector(
|
205 |
+
self.cfg,
|
206 |
+
self.policy.collect_mode,
|
207 |
+
collector_env,
|
208 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
209 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
210 |
+
)
|
211 |
+
)
|
212 |
+
if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
|
213 |
+
task.use(nstep_reward_enhancer(self.cfg))
|
214 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
215 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
216 |
+
task.use(
|
217 |
+
wandb_online_logger(
|
218 |
+
metric_list=self.policy._monitor_vars_learn(),
|
219 |
+
model=self.policy._model,
|
220 |
+
anonymous=True,
|
221 |
+
project_name=self.exp_name,
|
222 |
+
wandb_sweep=wandb_sweep,
|
223 |
+
)
|
224 |
+
)
|
225 |
+
task.use(termination_checker(max_env_step=step))
|
226 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
227 |
+
task.run()
|
228 |
+
|
229 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
230 |
+
|
231 |
+
def deploy(
|
232 |
+
self,
|
233 |
+
enable_save_replay: bool = False,
|
234 |
+
concatenate_all_replay: bool = False,
|
235 |
+
replay_save_path: str = None,
|
236 |
+
seed: Optional[Union[int, List]] = None,
|
237 |
+
debug: bool = False
|
238 |
+
) -> EvalReturn:
|
239 |
+
"""
|
240 |
+
Overview:
|
241 |
+
Deploy the agent with SQL algorithm by interacting with the environment, during which the replay video \
|
242 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
243 |
+
Arguments:
|
244 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
245 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
246 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
247 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
248 |
+
the replay video of each episode will be saved separately.
|
249 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
250 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
251 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
252 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
253 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
254 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
255 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
256 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
257 |
+
subprocess environment manager will be used.
|
258 |
+
Returns:
|
259 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
260 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
261 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
262 |
+
"""
|
263 |
+
|
264 |
+
if debug:
|
265 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
266 |
+
# define env and policy
|
267 |
+
env = self.env.clone(caller='evaluator')
|
268 |
+
|
269 |
+
if seed is not None and isinstance(seed, int):
|
270 |
+
seeds = [seed]
|
271 |
+
elif seed is not None and isinstance(seed, list):
|
272 |
+
seeds = seed
|
273 |
+
else:
|
274 |
+
seeds = [self.seed]
|
275 |
+
|
276 |
+
returns = []
|
277 |
+
images = []
|
278 |
+
if enable_save_replay:
|
279 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
280 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
281 |
+
else:
|
282 |
+
logging.warning('No video would be generated during the deploy.')
|
283 |
+
if concatenate_all_replay:
|
284 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
285 |
+
concatenate_all_replay = False
|
286 |
+
|
287 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
288 |
+
|
289 |
+
forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
|
290 |
+
|
291 |
+
def _forward(obs):
|
292 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
293 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
294 |
+
if cuda and torch.cuda.is_available():
|
295 |
+
obs = obs.cuda()
|
296 |
+
action = forward_fn(obs)["action"]
|
297 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
298 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
299 |
+
return action
|
300 |
+
|
301 |
+
return _forward
|
302 |
+
|
303 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
304 |
+
|
305 |
+
# reset first to make sure the env is in the initial state
|
306 |
+
# env will be reset again in the main loop
|
307 |
+
env.reset()
|
308 |
+
|
309 |
+
for seed in seeds:
|
310 |
+
env.seed(seed, dynamic_seed=False)
|
311 |
+
return_ = 0.
|
312 |
+
step = 0
|
313 |
+
obs = env.reset()
|
314 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
315 |
+
while True:
|
316 |
+
action = forward_fn(obs)
|
317 |
+
obs, rew, done, info = env.step(action)
|
318 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
319 |
+
return_ += rew
|
320 |
+
step += 1
|
321 |
+
if done:
|
322 |
+
break
|
323 |
+
logging.info(f'SQL deploy is finished, final episode return with {step} steps is: {return_}')
|
324 |
+
returns.append(return_)
|
325 |
+
|
326 |
+
env.close()
|
327 |
+
|
328 |
+
if concatenate_all_replay:
|
329 |
+
images = np.concatenate(images, axis=0)
|
330 |
+
import imageio
|
331 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
332 |
+
|
333 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
334 |
+
|
335 |
+
def collect_data(
|
336 |
+
self,
|
337 |
+
env_num: int = 8,
|
338 |
+
save_data_path: Optional[str] = None,
|
339 |
+
n_sample: Optional[int] = None,
|
340 |
+
n_episode: Optional[int] = None,
|
341 |
+
context: Optional[str] = None,
|
342 |
+
debug: bool = False
|
343 |
+
) -> None:
|
344 |
+
"""
|
345 |
+
Overview:
|
346 |
+
Collect data with SQL algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
347 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
348 |
+
``exp_name/demo_data``.
|
349 |
+
Arguments:
|
350 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
351 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
352 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
353 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
354 |
+
If not specified, ``n_episode`` must be specified.
|
355 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
356 |
+
If not specified, ``n_sample`` must be specified.
|
357 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
358 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
359 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
360 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
361 |
+
subprocess environment manager will be used.
|
362 |
+
"""
|
363 |
+
|
364 |
+
if debug:
|
365 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
366 |
+
if n_episode is not None:
|
367 |
+
raise NotImplementedError
|
368 |
+
# define env and policy
|
369 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
370 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
371 |
+
|
372 |
+
if save_data_path is None:
|
373 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
374 |
+
|
375 |
+
# main execution task
|
376 |
+
with task.start(ctx=OnlineRLContext()):
|
377 |
+
task.use(
|
378 |
+
StepCollector(
|
379 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
380 |
+
)
|
381 |
+
)
|
382 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
383 |
+
task.run(max_step=1)
|
384 |
+
logging.info(
|
385 |
+
f'SQL collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
386 |
+
)
|
387 |
+
|
388 |
+
def batch_evaluate(
|
389 |
+
self,
|
390 |
+
env_num: int = 4,
|
391 |
+
n_evaluator_episode: int = 4,
|
392 |
+
context: Optional[str] = None,
|
393 |
+
debug: bool = False
|
394 |
+
) -> EvalReturn:
|
395 |
+
"""
|
396 |
+
Overview:
|
397 |
+
Evaluate the agent with SQL algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
398 |
+
environments. The evaluation result will be returned.
|
399 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
400 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
401 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
402 |
+
Arguments:
|
403 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
404 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
405 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
406 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
407 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
408 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
409 |
+
subprocess environment manager will be used.
|
410 |
+
Returns:
|
411 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
412 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
413 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
414 |
+
"""
|
415 |
+
|
416 |
+
if debug:
|
417 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
418 |
+
# define env and policy
|
419 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
420 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
421 |
+
|
422 |
+
# reset first to make sure the env is in the initial state
|
423 |
+
# env will be reset again in the main loop
|
424 |
+
env.launch()
|
425 |
+
env.reset()
|
426 |
+
|
427 |
+
evaluate_cfg = self.cfg
|
428 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
429 |
+
|
430 |
+
# main execution task
|
431 |
+
with task.start(ctx=OnlineRLContext()):
|
432 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
433 |
+
task.run(max_step=1)
|
434 |
+
|
435 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
436 |
+
|
437 |
+
@property
|
438 |
+
def best(self) -> 'SQLAgent':
|
439 |
+
"""
|
440 |
+
Overview:
|
441 |
+
Load the best model from the checkpoint directory, \
|
442 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
443 |
+
The return value is the agent with the best model.
|
444 |
+
Returns:
|
445 |
+
- (:obj:`SQLAgent`): The agent with the best model.
|
446 |
+
Examples:
|
447 |
+
>>> agent = SQLAgent(env_id='LunarLander-v2')
|
448 |
+
>>> agent.train()
|
449 |
+
>>> agent = agent.best
|
450 |
+
|
451 |
+
.. note::
|
452 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
453 |
+
model will be replaced by the best model.
|
454 |
+
"""
|
455 |
+
|
456 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
457 |
+
# Load best model if it exists
|
458 |
+
if os.path.exists(best_model_file_path):
|
459 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
460 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
461 |
+
return self
|
DI-engine/ding/bonus/td3.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union, List
|
2 |
+
from ditk import logging
|
3 |
+
from easydict import EasyDict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import treetensor.torch as ttorch
|
8 |
+
from ding.framework import task, OnlineRLContext
|
9 |
+
from ding.framework.middleware import CkptSaver, \
|
10 |
+
wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
|
11 |
+
OffPolicyLearner, final_ctx_saver
|
12 |
+
from ding.envs import BaseEnv
|
13 |
+
from ding.envs import setup_ding_env_manager
|
14 |
+
from ding.policy import TD3Policy
|
15 |
+
from ding.utils import set_pkg_seed
|
16 |
+
from ding.utils import get_env_fps, render
|
17 |
+
from ding.config import save_config_py, compile_config
|
18 |
+
from ding.model import ContinuousQAC
|
19 |
+
from ding.data import DequeBuffer
|
20 |
+
from ding.bonus.common import TrainingReturn, EvalReturn
|
21 |
+
from ding.config.example.TD3 import supported_env_cfg
|
22 |
+
from ding.config.example.TD3 import supported_env
|
23 |
+
|
24 |
+
|
25 |
+
class TD3Agent:
|
26 |
+
"""
|
27 |
+
Overview:
|
28 |
+
Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
|
29 |
+
Twin Delayed Deep Deterministic Policy Gradient(TD3).
|
30 |
+
For more information about the system design of RL agent, please refer to \
|
31 |
+
<https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
|
32 |
+
Interface:
|
33 |
+
``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
|
34 |
+
"""
|
35 |
+
supported_env_list = list(supported_env_cfg.keys())
|
36 |
+
"""
|
37 |
+
Overview:
|
38 |
+
List of supported envs.
|
39 |
+
Examples:
|
40 |
+
>>> from ding.bonus.td3 import TD3Agent
|
41 |
+
>>> print(TD3Agent.supported_env_list)
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
env_id: str = None,
|
47 |
+
env: BaseEnv = None,
|
48 |
+
seed: int = 0,
|
49 |
+
exp_name: str = None,
|
50 |
+
model: Optional[torch.nn.Module] = None,
|
51 |
+
cfg: Optional[Union[EasyDict, dict]] = None,
|
52 |
+
policy_state_dict: str = None,
|
53 |
+
) -> None:
|
54 |
+
"""
|
55 |
+
Overview:
|
56 |
+
Initialize agent for TD3 algorithm.
|
57 |
+
Arguments:
|
58 |
+
- env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
|
59 |
+
If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
|
60 |
+
If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
|
61 |
+
``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
|
62 |
+
- env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
|
63 |
+
If ``env`` is not specified, `env_id`` or ``cfg.env.env_id`` must be specified. \
|
64 |
+
``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
|
65 |
+
If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
|
66 |
+
- seed (:obj:`int`): The random seed, which is set before running the program. \
|
67 |
+
Default to 0.
|
68 |
+
- exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
|
69 |
+
log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
|
70 |
+
- model (:obj:`torch.nn.Module`): The model of TD3 algorithm, which should be an instance of class \
|
71 |
+
:class:`ding.model.ContinuousQAC`. \
|
72 |
+
If not specified, a default model will be generated according to the configuration.
|
73 |
+
- cfg (:obj:Union[EasyDict, dict]): The configuration of TD3 algorithm, which is a dict. \
|
74 |
+
Default to None. If not specified, the default configuration will be used. \
|
75 |
+
The default configuration can be found in ``ding/config/example/TD3/gym_lunarlander_v2.py``.
|
76 |
+
- policy_state_dict (:obj:`str`): The path of policy state dict saved by PyTorch a in local file. \
|
77 |
+
If specified, the policy will be loaded from this file. Default to None.
|
78 |
+
|
79 |
+
.. note::
|
80 |
+
An RL Agent Instance can be initialized in two basic ways. \
|
81 |
+
For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
|
82 |
+
and we want to train an agent with TD3 algorithm with default configuration. \
|
83 |
+
Then we can initialize the agent in the following ways:
|
84 |
+
>>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
|
85 |
+
or, if we want can specify the env_id in the configuration:
|
86 |
+
>>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
|
87 |
+
>>> agent = TD3Agent(cfg=cfg)
|
88 |
+
There are also other arguments to specify the agent when initializing.
|
89 |
+
For example, if we want to specify the environment instance:
|
90 |
+
>>> env = CustomizedEnv('LunarLanderContinuous-v2')
|
91 |
+
>>> agent = TD3Agent(cfg=cfg, env=env)
|
92 |
+
or, if we want to specify the model:
|
93 |
+
>>> model = ContinuousQAC(**cfg.policy.model)
|
94 |
+
>>> agent = TD3Agent(cfg=cfg, model=model)
|
95 |
+
or, if we want to reload the policy from a saved policy state dict:
|
96 |
+
>>> agent = TD3Agent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
|
97 |
+
Make sure that the configuration is consistent with the saved policy state dict.
|
98 |
+
"""
|
99 |
+
|
100 |
+
assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
|
101 |
+
|
102 |
+
if cfg is not None and not isinstance(cfg, EasyDict):
|
103 |
+
cfg = EasyDict(cfg)
|
104 |
+
|
105 |
+
if env_id is not None:
|
106 |
+
assert env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
|
107 |
+
TD3Agent.supported_env_list
|
108 |
+
)
|
109 |
+
if cfg is None:
|
110 |
+
cfg = supported_env_cfg[env_id]
|
111 |
+
else:
|
112 |
+
assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
|
113 |
+
else:
|
114 |
+
assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
|
115 |
+
assert cfg.env.env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
|
116 |
+
TD3Agent.supported_env_list
|
117 |
+
)
|
118 |
+
default_policy_config = EasyDict({"policy": TD3Policy.default_config()})
|
119 |
+
default_policy_config.update(cfg)
|
120 |
+
cfg = default_policy_config
|
121 |
+
|
122 |
+
if exp_name is not None:
|
123 |
+
cfg.exp_name = exp_name
|
124 |
+
self.cfg = compile_config(cfg, policy=TD3Policy)
|
125 |
+
self.exp_name = self.cfg.exp_name
|
126 |
+
if env is None:
|
127 |
+
self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
|
128 |
+
else:
|
129 |
+
assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
|
130 |
+
self.env = env
|
131 |
+
|
132 |
+
logging.getLogger().setLevel(logging.INFO)
|
133 |
+
self.seed = seed
|
134 |
+
set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
|
135 |
+
if not os.path.exists(self.exp_name):
|
136 |
+
os.makedirs(self.exp_name)
|
137 |
+
save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
|
138 |
+
if model is None:
|
139 |
+
model = ContinuousQAC(**self.cfg.policy.model)
|
140 |
+
self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
|
141 |
+
self.policy = TD3Policy(self.cfg.policy, model=model)
|
142 |
+
if policy_state_dict is not None:
|
143 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
144 |
+
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
|
145 |
+
|
146 |
+
def train(
|
147 |
+
self,
|
148 |
+
step: int = int(1e7),
|
149 |
+
collector_env_num: int = None,
|
150 |
+
evaluator_env_num: int = None,
|
151 |
+
n_iter_save_ckpt: int = 1000,
|
152 |
+
context: Optional[str] = None,
|
153 |
+
debug: bool = False,
|
154 |
+
wandb_sweep: bool = False,
|
155 |
+
) -> TrainingReturn:
|
156 |
+
"""
|
157 |
+
Overview:
|
158 |
+
Train the agent with TD3 algorithm for ``step`` iterations with ``collector_env_num`` collector \
|
159 |
+
environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
|
160 |
+
recorded and saved by wandb.
|
161 |
+
Arguments:
|
162 |
+
- step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
|
163 |
+
- collector_env_num (:obj:`int`): The collector environment number. Default to None. \
|
164 |
+
If not specified, it will be set according to the configuration.
|
165 |
+
- evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
|
166 |
+
If not specified, it will be set according to the configuration.
|
167 |
+
- n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
|
168 |
+
Default to 1000.
|
169 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
170 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
171 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
172 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
173 |
+
subprocess environment manager will be used.
|
174 |
+
- wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
|
175 |
+
which is a hyper-parameter optimization process for seeking the best configurations. \
|
176 |
+
Default to False. If True, the wandb sweep id will be used as the experiment name.
|
177 |
+
Returns:
|
178 |
+
- (:obj:`TrainingReturn`): The training result, of which the attributions are:
|
179 |
+
- wandb_url (:obj:`str`): The weight & biases (wandb) project url of the trainning experiment.
|
180 |
+
"""
|
181 |
+
|
182 |
+
if debug:
|
183 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
184 |
+
logging.debug(self.policy._model)
|
185 |
+
# define env and policy
|
186 |
+
collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
|
187 |
+
evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
|
188 |
+
collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
|
189 |
+
evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
|
190 |
+
|
191 |
+
with task.start(ctx=OnlineRLContext()):
|
192 |
+
task.use(
|
193 |
+
interaction_evaluator(
|
194 |
+
self.cfg,
|
195 |
+
self.policy.eval_mode,
|
196 |
+
evaluator_env,
|
197 |
+
render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
|
198 |
+
)
|
199 |
+
)
|
200 |
+
task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
|
201 |
+
task.use(
|
202 |
+
StepCollector(
|
203 |
+
self.cfg,
|
204 |
+
self.policy.collect_mode,
|
205 |
+
collector_env,
|
206 |
+
random_collect_size=self.cfg.policy.random_collect_size
|
207 |
+
if hasattr(self.cfg.policy, 'random_collect_size') else 0,
|
208 |
+
)
|
209 |
+
)
|
210 |
+
task.use(data_pusher(self.cfg, self.buffer_))
|
211 |
+
task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
|
212 |
+
task.use(
|
213 |
+
wandb_online_logger(
|
214 |
+
metric_list=self.policy._monitor_vars_learn(),
|
215 |
+
model=self.policy._model,
|
216 |
+
anonymous=True,
|
217 |
+
project_name=self.exp_name,
|
218 |
+
wandb_sweep=wandb_sweep,
|
219 |
+
)
|
220 |
+
)
|
221 |
+
task.use(termination_checker(max_env_step=step))
|
222 |
+
task.use(final_ctx_saver(name=self.exp_name))
|
223 |
+
task.run()
|
224 |
+
|
225 |
+
return TrainingReturn(wandb_url=task.ctx.wandb_url)
|
226 |
+
|
227 |
+
def deploy(
|
228 |
+
self,
|
229 |
+
enable_save_replay: bool = False,
|
230 |
+
concatenate_all_replay: bool = False,
|
231 |
+
replay_save_path: str = None,
|
232 |
+
seed: Optional[Union[int, List]] = None,
|
233 |
+
debug: bool = False
|
234 |
+
) -> EvalReturn:
|
235 |
+
"""
|
236 |
+
Overview:
|
237 |
+
Deploy the agent with TD3 algorithm by interacting with the environment, during which the replay video \
|
238 |
+
can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
|
239 |
+
Arguments:
|
240 |
+
- enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
|
241 |
+
- concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
|
242 |
+
Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
|
243 |
+
If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
|
244 |
+
the replay video of each episode will be saved separately.
|
245 |
+
- replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
|
246 |
+
If not specified, the video will be saved in ``exp_name/videos``.
|
247 |
+
- seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
|
248 |
+
Default to None. If not specified, ``self.seed`` will be used. \
|
249 |
+
If ``seed`` is an integer, the agent will be deployed once. \
|
250 |
+
If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
|
251 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
252 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
253 |
+
subprocess environment manager will be used.
|
254 |
+
Returns:
|
255 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
256 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
257 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
258 |
+
"""
|
259 |
+
|
260 |
+
if debug:
|
261 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
262 |
+
# define env and policy
|
263 |
+
env = self.env.clone(caller='evaluator')
|
264 |
+
|
265 |
+
if seed is not None and isinstance(seed, int):
|
266 |
+
seeds = [seed]
|
267 |
+
elif seed is not None and isinstance(seed, list):
|
268 |
+
seeds = seed
|
269 |
+
else:
|
270 |
+
seeds = [self.seed]
|
271 |
+
|
272 |
+
returns = []
|
273 |
+
images = []
|
274 |
+
if enable_save_replay:
|
275 |
+
replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
|
276 |
+
env.enable_save_replay(replay_path=replay_save_path)
|
277 |
+
else:
|
278 |
+
logging.warning('No video would be generated during the deploy.')
|
279 |
+
if concatenate_all_replay:
|
280 |
+
logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
|
281 |
+
concatenate_all_replay = False
|
282 |
+
|
283 |
+
def single_env_forward_wrapper(forward_fn, cuda=True):
|
284 |
+
|
285 |
+
def _forward(obs):
|
286 |
+
# unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
|
287 |
+
obs = ttorch.as_tensor(obs).unsqueeze(0)
|
288 |
+
if cuda and torch.cuda.is_available():
|
289 |
+
obs = obs.cuda()
|
290 |
+
action = forward_fn(obs, mode='compute_actor')["action"]
|
291 |
+
# squeeze means delete batch dim, i.e. (1, A) -> (A, )
|
292 |
+
action = action.squeeze(0).detach().cpu().numpy()
|
293 |
+
return action
|
294 |
+
|
295 |
+
return _forward
|
296 |
+
|
297 |
+
forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
|
298 |
+
|
299 |
+
# reset first to make sure the env is in the initial state
|
300 |
+
# env will be reset again in the main loop
|
301 |
+
env.reset()
|
302 |
+
|
303 |
+
for seed in seeds:
|
304 |
+
env.seed(seed, dynamic_seed=False)
|
305 |
+
return_ = 0.
|
306 |
+
step = 0
|
307 |
+
obs = env.reset()
|
308 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
309 |
+
while True:
|
310 |
+
action = forward_fn(obs)
|
311 |
+
obs, rew, done, info = env.step(action)
|
312 |
+
images.append(render(env)[None]) if concatenate_all_replay else None
|
313 |
+
return_ += rew
|
314 |
+
step += 1
|
315 |
+
if done:
|
316 |
+
break
|
317 |
+
logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
|
318 |
+
returns.append(return_)
|
319 |
+
|
320 |
+
env.close()
|
321 |
+
|
322 |
+
if concatenate_all_replay:
|
323 |
+
images = np.concatenate(images, axis=0)
|
324 |
+
import imageio
|
325 |
+
imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
|
326 |
+
|
327 |
+
return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
|
328 |
+
|
329 |
+
def collect_data(
|
330 |
+
self,
|
331 |
+
env_num: int = 8,
|
332 |
+
save_data_path: Optional[str] = None,
|
333 |
+
n_sample: Optional[int] = None,
|
334 |
+
n_episode: Optional[int] = None,
|
335 |
+
context: Optional[str] = None,
|
336 |
+
debug: bool = False
|
337 |
+
) -> None:
|
338 |
+
"""
|
339 |
+
Overview:
|
340 |
+
Collect data with TD3 algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
|
341 |
+
The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
|
342 |
+
``exp_name/demo_data``.
|
343 |
+
Arguments:
|
344 |
+
- env_num (:obj:`int`): The number of collector environments. Default to 8.
|
345 |
+
- save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
|
346 |
+
If not specified, the data will be saved in ``exp_name/demo_data``.
|
347 |
+
- n_sample (:obj:`int`): The number of samples to collect. Default to None. \
|
348 |
+
If not specified, ``n_episode`` must be specified.
|
349 |
+
- n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
|
350 |
+
If not specified, ``n_sample`` must be specified.
|
351 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
352 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
353 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
354 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
355 |
+
subprocess environment manager will be used.
|
356 |
+
"""
|
357 |
+
|
358 |
+
if debug:
|
359 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
360 |
+
if n_episode is not None:
|
361 |
+
raise NotImplementedError
|
362 |
+
# define env and policy
|
363 |
+
env_num = env_num if env_num else self.cfg.env.collector_env_num
|
364 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
|
365 |
+
|
366 |
+
if save_data_path is None:
|
367 |
+
save_data_path = os.path.join(self.exp_name, 'demo_data')
|
368 |
+
|
369 |
+
# main execution task
|
370 |
+
with task.start(ctx=OnlineRLContext()):
|
371 |
+
task.use(
|
372 |
+
StepCollector(
|
373 |
+
self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
|
374 |
+
)
|
375 |
+
)
|
376 |
+
task.use(offline_data_saver(save_data_path, data_type='hdf5'))
|
377 |
+
task.run(max_step=1)
|
378 |
+
logging.info(
|
379 |
+
f'TD3 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
|
380 |
+
)
|
381 |
+
|
382 |
+
def batch_evaluate(
|
383 |
+
self,
|
384 |
+
env_num: int = 4,
|
385 |
+
n_evaluator_episode: int = 4,
|
386 |
+
context: Optional[str] = None,
|
387 |
+
debug: bool = False
|
388 |
+
) -> EvalReturn:
|
389 |
+
"""
|
390 |
+
Overview:
|
391 |
+
Evaluate the agent with TD3 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
|
392 |
+
environments. The evaluation result will be returned.
|
393 |
+
The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
|
394 |
+
multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
|
395 |
+
will only create one evaluator environment to evaluate the agent and save the replay video.
|
396 |
+
Arguments:
|
397 |
+
- env_num (:obj:`int`): The number of evaluator environments. Default to 4.
|
398 |
+
- n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
|
399 |
+
- context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
|
400 |
+
It can be specified as ``spawn``, ``fork`` or ``forkserver``.
|
401 |
+
- debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
|
402 |
+
If set True, base environment manager will be used for easy debugging. Otherwise, \
|
403 |
+
subprocess environment manager will be used.
|
404 |
+
Returns:
|
405 |
+
- (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
|
406 |
+
- eval_value (:obj:`np.float32`): The mean of evaluation return.
|
407 |
+
- eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
|
408 |
+
"""
|
409 |
+
|
410 |
+
if debug:
|
411 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
412 |
+
# define env and policy
|
413 |
+
env_num = env_num if env_num else self.cfg.env.evaluator_env_num
|
414 |
+
env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
|
415 |
+
|
416 |
+
# reset first to make sure the env is in the initial state
|
417 |
+
# env will be reset again in the main loop
|
418 |
+
env.launch()
|
419 |
+
env.reset()
|
420 |
+
|
421 |
+
evaluate_cfg = self.cfg
|
422 |
+
evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
|
423 |
+
|
424 |
+
# main execution task
|
425 |
+
with task.start(ctx=OnlineRLContext()):
|
426 |
+
task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
|
427 |
+
task.run(max_step=1)
|
428 |
+
|
429 |
+
return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
|
430 |
+
|
431 |
+
@property
|
432 |
+
def best(self) -> 'TD3Agent':
|
433 |
+
"""
|
434 |
+
Overview:
|
435 |
+
Load the best model from the checkpoint directory, \
|
436 |
+
which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
|
437 |
+
The return value is the agent with the best model.
|
438 |
+
Returns:
|
439 |
+
- (:obj:`TD3Agent`): The agent with the best model.
|
440 |
+
Examples:
|
441 |
+
>>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
|
442 |
+
>>> agent.train()
|
443 |
+
>>> agent.best
|
444 |
+
|
445 |
+
.. note::
|
446 |
+
The best model is the model with the highest evaluation return. If this method is called, the current \
|
447 |
+
model will be replaced by the best model.
|
448 |
+
"""
|
449 |
+
|
450 |
+
best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
|
451 |
+
# Load best model if it exists
|
452 |
+
if os.path.exists(best_model_file_path):
|
453 |
+
policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
|
454 |
+
self.policy.learn_mode.load_state_dict(policy_state_dict)
|
455 |
+
return self
|
DI-engine/ding/compatibility.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def torch_ge_131():
|
5 |
+
return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
|
6 |
+
|
7 |
+
|
8 |
+
def torch_ge_180():
|
9 |
+
return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180
|
DI-engine/ding/config/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \
|
2 |
+
read_config_with_system, save_config_py
|
3 |
+
from .utils import parallel_transform, parallel_transform_slurm
|
4 |
+
from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3
|
DI-engine/ding/config/config.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import yaml
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
import tempfile
|
9 |
+
import subprocess
|
10 |
+
import datetime
|
11 |
+
from importlib import import_module
|
12 |
+
from typing import Optional, Tuple
|
13 |
+
from easydict import EasyDict
|
14 |
+
from copy import deepcopy
|
15 |
+
|
16 |
+
from ding.utils import deep_merge_dicts, get_rank
|
17 |
+
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
|
18 |
+
from ding.policy import get_policy_cls
|
19 |
+
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
|
20 |
+
AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \
|
21 |
+
get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator
|
22 |
+
from ding.reward_model import get_reward_model_cls
|
23 |
+
from ding.world_model import get_world_model_cls
|
24 |
+
from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
|
25 |
+
|
26 |
+
|
27 |
+
class Config(object):
|
28 |
+
r"""
|
29 |
+
Overview:
|
30 |
+
Base class for config.
|
31 |
+
Interface:
|
32 |
+
__init__, file_to_dict
|
33 |
+
Property:
|
34 |
+
cfg_dict
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
cfg_dict: Optional[dict] = None,
|
40 |
+
cfg_text: Optional[str] = None,
|
41 |
+
filename: Optional[str] = None
|
42 |
+
) -> None:
|
43 |
+
"""
|
44 |
+
Overview:
|
45 |
+
Init method. Create config including dict type config and text type config.
|
46 |
+
Arguments:
|
47 |
+
- cfg_dict (:obj:`Optional[dict]`): dict type config
|
48 |
+
- cfg_text (:obj:`Optional[str]`): text type config
|
49 |
+
- filename (:obj:`Optional[str]`): config file name
|
50 |
+
"""
|
51 |
+
if cfg_dict is None:
|
52 |
+
cfg_dict = {}
|
53 |
+
if not isinstance(cfg_dict, dict):
|
54 |
+
raise TypeError("invalid type for cfg_dict: {}".format(type(cfg_dict)))
|
55 |
+
self._cfg_dict = cfg_dict
|
56 |
+
if cfg_text:
|
57 |
+
text = cfg_text
|
58 |
+
elif filename:
|
59 |
+
with open(filename, 'r') as f:
|
60 |
+
text = f.read()
|
61 |
+
else:
|
62 |
+
text = '.'
|
63 |
+
self._text = text
|
64 |
+
self._filename = filename
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def file_to_dict(filename: str) -> 'Config': # noqa
|
68 |
+
"""
|
69 |
+
Overview:
|
70 |
+
Read config file and create config.
|
71 |
+
Arguments:
|
72 |
+
- filename (:obj:`Optional[str]`): config file name.
|
73 |
+
Returns:
|
74 |
+
- cfg_dict (:obj:`Config`): config class
|
75 |
+
"""
|
76 |
+
cfg_dict, cfg_text = Config._file_to_dict(filename)
|
77 |
+
return Config(cfg_dict, cfg_text, filename=filename)
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def _file_to_dict(filename: str) -> Tuple[dict, str]:
|
81 |
+
"""
|
82 |
+
Overview:
|
83 |
+
Read config file and convert the config file to dict type config and text type config.
|
84 |
+
Arguments:
|
85 |
+
- filename (:obj:`Optional[str]`): config file name.
|
86 |
+
Returns:
|
87 |
+
- cfg_dict (:obj:`Optional[dict]`): dict type config
|
88 |
+
- cfg_text (:obj:`Optional[str]`): text type config
|
89 |
+
"""
|
90 |
+
filename = osp.abspath(osp.expanduser(filename))
|
91 |
+
# TODO check exist
|
92 |
+
# TODO check suffix
|
93 |
+
ext_name = osp.splitext(filename)[-1]
|
94 |
+
with tempfile.TemporaryDirectory() as temp_config_dir:
|
95 |
+
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=ext_name)
|
96 |
+
temp_config_name = osp.basename(temp_config_file.name)
|
97 |
+
temp_config_file.close()
|
98 |
+
shutil.copyfile(filename, temp_config_file.name)
|
99 |
+
|
100 |
+
temp_module_name = osp.splitext(temp_config_name)[0]
|
101 |
+
sys.path.insert(0, temp_config_dir)
|
102 |
+
# TODO validate py syntax
|
103 |
+
module = import_module(temp_module_name)
|
104 |
+
cfg_dict = {k: v for k, v in module.__dict__.items() if not k.startswith('_')}
|
105 |
+
del sys.modules[temp_module_name]
|
106 |
+
sys.path.pop(0)
|
107 |
+
|
108 |
+
cfg_text = filename + '\n'
|
109 |
+
with open(filename, 'r') as f:
|
110 |
+
cfg_text += f.read()
|
111 |
+
|
112 |
+
return cfg_dict, cfg_text
|
113 |
+
|
114 |
+
@property
|
115 |
+
def cfg_dict(self) -> dict:
|
116 |
+
return self._cfg_dict
|
117 |
+
|
118 |
+
|
119 |
+
def read_config_yaml(path: str) -> EasyDict:
|
120 |
+
"""
|
121 |
+
Overview:
|
122 |
+
read configuration from path
|
123 |
+
Arguments:
|
124 |
+
- path (:obj:`str`): Path of source yaml
|
125 |
+
Returns:
|
126 |
+
- (:obj:`EasyDict`): Config data from this file with dict type
|
127 |
+
"""
|
128 |
+
with open(path, "r") as f:
|
129 |
+
config_ = yaml.safe_load(f)
|
130 |
+
|
131 |
+
return EasyDict(config_)
|
132 |
+
|
133 |
+
|
134 |
+
def save_config_yaml(config_: dict, path: str) -> None:
|
135 |
+
"""
|
136 |
+
Overview:
|
137 |
+
save configuration to path
|
138 |
+
Arguments:
|
139 |
+
- config (:obj:`dict`): Config dict
|
140 |
+
- path (:obj:`str`): Path of target yaml
|
141 |
+
"""
|
142 |
+
config_string = json.dumps(config_)
|
143 |
+
with open(path, "w") as f:
|
144 |
+
yaml.safe_dump(json.loads(config_string), f)
|
145 |
+
|
146 |
+
|
147 |
+
def save_config_py(config_: dict, path: str) -> None:
|
148 |
+
"""
|
149 |
+
Overview:
|
150 |
+
save configuration to python file
|
151 |
+
Arguments:
|
152 |
+
- config (:obj:`dict`): Config dict
|
153 |
+
- path (:obj:`str`): Path of target yaml
|
154 |
+
"""
|
155 |
+
# config_string = json.dumps(config_, indent=4)
|
156 |
+
config_string = str(config_)
|
157 |
+
from yapf.yapflib.yapf_api import FormatCode
|
158 |
+
config_string, _ = FormatCode(config_string)
|
159 |
+
config_string = config_string.replace('inf,', 'float("inf"),')
|
160 |
+
with open(path, "w") as f:
|
161 |
+
f.write('exp_config = ' + config_string)
|
162 |
+
|
163 |
+
|
164 |
+
def read_config_directly(path: str) -> dict:
|
165 |
+
"""
|
166 |
+
Overview:
|
167 |
+
Read configuration from a file path(now only support python file) and directly return results.
|
168 |
+
Arguments:
|
169 |
+
- path (:obj:`str`): Path of configuration file
|
170 |
+
Returns:
|
171 |
+
- cfg (:obj:`Tuple[dict, dict]`): Configuration dict.
|
172 |
+
"""
|
173 |
+
suffix = path.split('.')[-1]
|
174 |
+
if suffix == 'py':
|
175 |
+
return Config.file_to_dict(path).cfg_dict
|
176 |
+
else:
|
177 |
+
raise KeyError("invalid config file suffix: {}".format(suffix))
|
178 |
+
|
179 |
+
|
180 |
+
def read_config(path: str) -> Tuple[dict, dict]:
|
181 |
+
"""
|
182 |
+
Overview:
|
183 |
+
Read configuration from a file path(now only suport python file). And select some proper parts.
|
184 |
+
Arguments:
|
185 |
+
- path (:obj:`str`): Path of configuration file
|
186 |
+
Returns:
|
187 |
+
- cfg (:obj:`Tuple[dict, dict]`): A collection(tuple) of configuration dict, divided into `main_config` and \
|
188 |
+
`create_cfg` two parts.
|
189 |
+
"""
|
190 |
+
suffix = path.split('.')[-1]
|
191 |
+
if suffix == 'py':
|
192 |
+
cfg = Config.file_to_dict(path).cfg_dict
|
193 |
+
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
|
194 |
+
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
|
195 |
+
return cfg['main_config'], cfg['create_config']
|
196 |
+
else:
|
197 |
+
raise KeyError("invalid config file suffix: {}".format(suffix))
|
198 |
+
|
199 |
+
|
200 |
+
def read_config_with_system(path: str) -> Tuple[dict, dict, dict]:
|
201 |
+
"""
|
202 |
+
Overview:
|
203 |
+
Read configuration from a file path(now only suport python file). And select some proper parts
|
204 |
+
Arguments:
|
205 |
+
- path (:obj:`str`): Path of configuration file
|
206 |
+
Returns:
|
207 |
+
- cfg (:obj:`Tuple[dict, dict]`): A collection(tuple) of configuration dict, divided into `main_config`, \
|
208 |
+
`create_cfg` and `system_config` three parts.
|
209 |
+
"""
|
210 |
+
suffix = path.split('.')[-1]
|
211 |
+
if suffix == 'py':
|
212 |
+
cfg = Config.file_to_dict(path).cfg_dict
|
213 |
+
assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
|
214 |
+
assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
|
215 |
+
assert "system_config" in cfg, "Please make sure a 'system_config' variable is declared in config python file!"
|
216 |
+
return cfg['main_config'], cfg['create_config'], cfg['system_config']
|
217 |
+
else:
|
218 |
+
raise KeyError("invalid config file suffix: {}".format(suffix))
|
219 |
+
|
220 |
+
|
221 |
+
def save_config(config_: dict, path: str, type_: str = 'py', save_formatted: bool = False) -> None:
|
222 |
+
"""
|
223 |
+
Overview:
|
224 |
+
save configuration to python file or yaml file
|
225 |
+
Arguments:
|
226 |
+
- config (:obj:`dict`): Config dict
|
227 |
+
- path (:obj:`str`): Path of target yaml or target python file
|
228 |
+
- type (:obj:`str`): If type is ``yaml`` , save configuration to yaml file. If type is ``py`` , save\
|
229 |
+
configuration to python file.
|
230 |
+
- save_formatted (:obj:`bool`): If save_formatted is true, save formatted config to path.\
|
231 |
+
Formatted config can be read by serial_pipeline directly.
|
232 |
+
"""
|
233 |
+
assert type_ in ['yaml', 'py'], type_
|
234 |
+
if type_ == 'yaml':
|
235 |
+
save_config_yaml(config_, path)
|
236 |
+
elif type_ == 'py':
|
237 |
+
save_config_py(config_, path)
|
238 |
+
if save_formatted:
|
239 |
+
formated_path = osp.join(osp.dirname(path), 'formatted_' + osp.basename(path))
|
240 |
+
save_config_formatted(config_, formated_path)
|
241 |
+
|
242 |
+
|
243 |
+
def compile_buffer_config(policy_cfg: EasyDict, user_cfg: EasyDict, buffer_cls: 'IBuffer') -> EasyDict: # noqa
|
244 |
+
|
245 |
+
def _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls):
|
246 |
+
|
247 |
+
if buffer_cls is None:
|
248 |
+
assert 'type' in policy_buffer_cfg, "please indicate buffer type in create_cfg"
|
249 |
+
buffer_cls = get_buffer_cls(policy_buffer_cfg)
|
250 |
+
buffer_cfg = deep_merge_dicts(buffer_cls.default_config(), policy_buffer_cfg)
|
251 |
+
buffer_cfg = deep_merge_dicts(buffer_cfg, user_buffer_cfg)
|
252 |
+
return buffer_cfg
|
253 |
+
|
254 |
+
policy_multi_buffer = policy_cfg.other.replay_buffer.get('multi_buffer', False)
|
255 |
+
user_multi_buffer = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('multi_buffer', False)
|
256 |
+
assert not user_multi_buffer or user_multi_buffer == policy_multi_buffer, "For multi_buffer, \
|
257 |
+
user_cfg({}) and policy_cfg({}) must be in accordance".format(user_multi_buffer, policy_multi_buffer)
|
258 |
+
multi_buffer = policy_multi_buffer
|
259 |
+
if not multi_buffer:
|
260 |
+
policy_buffer_cfg = policy_cfg.other.replay_buffer
|
261 |
+
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {})
|
262 |
+
return _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls)
|
263 |
+
else:
|
264 |
+
return_cfg = EasyDict()
|
265 |
+
for buffer_name in policy_cfg.other.replay_buffer: # Only traverse keys in policy_cfg
|
266 |
+
if buffer_name == 'multi_buffer':
|
267 |
+
continue
|
268 |
+
policy_buffer_cfg = policy_cfg.other.replay_buffer[buffer_name]
|
269 |
+
user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('buffer_name', {})
|
270 |
+
if buffer_cls is None:
|
271 |
+
return_cfg[buffer_name] = _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, None)
|
272 |
+
else:
|
273 |
+
return_cfg[buffer_name] = _compile_buffer_config(
|
274 |
+
policy_buffer_cfg, user_buffer_cfg, buffer_cls[buffer_name]
|
275 |
+
)
|
276 |
+
return_cfg[buffer_name].name = buffer_name
|
277 |
+
return return_cfg
|
278 |
+
|
279 |
+
|
280 |
+
def compile_collector_config(
|
281 |
+
policy_cfg: EasyDict,
|
282 |
+
user_cfg: EasyDict,
|
283 |
+
collector_cls: 'ISerialCollector' # noqa
|
284 |
+
) -> EasyDict:
|
285 |
+
policy_collector_cfg = policy_cfg.collect.collector
|
286 |
+
user_collector_cfg = user_cfg.policy.get('collect', {}).get('collector', {})
|
287 |
+
# step1: get collector class
|
288 |
+
# two cases: create cfg merged in policy_cfg, collector class, and class has higher priority
|
289 |
+
if collector_cls is None:
|
290 |
+
assert 'type' in policy_collector_cfg, "please indicate collector type in create_cfg"
|
291 |
+
# use type to get collector_cls
|
292 |
+
collector_cls = get_serial_collector_cls(policy_collector_cfg)
|
293 |
+
# step2: policy collector cfg merge to collector cfg
|
294 |
+
collector_cfg = deep_merge_dicts(collector_cls.default_config(), policy_collector_cfg)
|
295 |
+
# step3: user collector cfg merge to the step2 config
|
296 |
+
collector_cfg = deep_merge_dicts(collector_cfg, user_collector_cfg)
|
297 |
+
|
298 |
+
return collector_cfg
|
299 |
+
|
300 |
+
|
301 |
+
policy_config_template = dict(
|
302 |
+
model=dict(),
|
303 |
+
learn=dict(learner=dict()),
|
304 |
+
collect=dict(collector=dict()),
|
305 |
+
eval=dict(evaluator=dict()),
|
306 |
+
other=dict(replay_buffer=dict()),
|
307 |
+
)
|
308 |
+
policy_config_template = EasyDict(policy_config_template)
|
309 |
+
env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4)
|
310 |
+
env_config_template = EasyDict(env_config_template)
|
311 |
+
|
312 |
+
|
313 |
+
def save_project_state(exp_name: str) -> None:
|
314 |
+
|
315 |
+
def _fn(cmd: str):
|
316 |
+
return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
|
317 |
+
|
318 |
+
if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
|
319 |
+
short_sha = _fn("git describe --always")
|
320 |
+
log = _fn("git log --stat -n 5")
|
321 |
+
diff = _fn("git diff")
|
322 |
+
with open(os.path.join(exp_name, "git_log.txt"), "w", encoding='utf-8') as f:
|
323 |
+
f.write(short_sha + '\n\n' + log)
|
324 |
+
with open(os.path.join(exp_name, "git_diff.txt"), "w", encoding='utf-8') as f:
|
325 |
+
f.write(diff)
|
326 |
+
|
327 |
+
|
328 |
+
def compile_config(
|
329 |
+
cfg: EasyDict,
|
330 |
+
env_manager: type = None,
|
331 |
+
policy: type = None,
|
332 |
+
learner: type = BaseLearner,
|
333 |
+
collector: type = None,
|
334 |
+
evaluator: type = InteractionSerialEvaluator,
|
335 |
+
buffer: type = None,
|
336 |
+
env: type = None,
|
337 |
+
reward_model: type = None,
|
338 |
+
world_model: type = None,
|
339 |
+
seed: int = 0,
|
340 |
+
auto: bool = False,
|
341 |
+
create_cfg: dict = None,
|
342 |
+
save_cfg: bool = True,
|
343 |
+
save_path: str = 'total_config.py',
|
344 |
+
renew_dir: bool = True,
|
345 |
+
) -> EasyDict:
|
346 |
+
"""
|
347 |
+
Overview:
|
348 |
+
Combine the input config information with other input information.
|
349 |
+
Compile config to make it easy to be called by other programs
|
350 |
+
Arguments:
|
351 |
+
- cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline
|
352 |
+
- env_manager (:obj:`type`): Env_manager class which is to be used in the following pipeline
|
353 |
+
- policy (:obj:`type`): Policy class which is to be used in the following pipeline
|
354 |
+
- learner (:obj:`type`): Input learner class, defaults to BaseLearner
|
355 |
+
- collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector
|
356 |
+
- evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator
|
357 |
+
- buffer (:obj:`type`): Input buffer class, defaults to IBuffer
|
358 |
+
- env (:obj:`type`): Environment class which is to be used in the following pipeline
|
359 |
+
- reward_model (:obj:`type`): Reward model class which aims to offer various and valuable reward
|
360 |
+
- seed (:obj:`int`): Random number seed
|
361 |
+
- auto (:obj:`bool`): Compile create_config dict or not
|
362 |
+
- create_cfg (:obj:`dict`): Input create config dict
|
363 |
+
- save_cfg (:obj:`bool`): Save config or not
|
364 |
+
- save_path (:obj:`str`): Path of saving file
|
365 |
+
- renew_dir (:obj:`bool`): Whether to new a directory for saving config.
|
366 |
+
Returns:
|
367 |
+
- cfg (:obj:`EasyDict`): Config after compiling
|
368 |
+
"""
|
369 |
+
cfg, create_cfg = deepcopy(cfg), deepcopy(create_cfg)
|
370 |
+
if auto:
|
371 |
+
assert create_cfg is not None
|
372 |
+
# for compatibility
|
373 |
+
if 'collector' not in create_cfg:
|
374 |
+
create_cfg.collector = EasyDict(dict(type='sample'))
|
375 |
+
if 'replay_buffer' not in create_cfg:
|
376 |
+
create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
|
377 |
+
buffer = AdvancedReplayBuffer
|
378 |
+
if env is None:
|
379 |
+
if 'env' in create_cfg:
|
380 |
+
env = get_env_cls(create_cfg.env)
|
381 |
+
else:
|
382 |
+
env = None
|
383 |
+
create_cfg.env = {'type': 'ding_env_wrapper_generated'}
|
384 |
+
if env_manager is None:
|
385 |
+
env_manager = get_env_manager_cls(create_cfg.env_manager)
|
386 |
+
if policy is None:
|
387 |
+
policy = get_policy_cls(create_cfg.policy)
|
388 |
+
if 'default_config' in dir(env):
|
389 |
+
env_config = env.default_config()
|
390 |
+
else:
|
391 |
+
env_config = EasyDict() # env does not have default_config
|
392 |
+
env_config = deep_merge_dicts(env_config_template, env_config)
|
393 |
+
env_config.update(create_cfg.env)
|
394 |
+
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
|
395 |
+
env_config.manager.update(create_cfg.env_manager)
|
396 |
+
policy_config = policy.default_config()
|
397 |
+
policy_config = deep_merge_dicts(policy_config_template, policy_config)
|
398 |
+
policy_config.update(create_cfg.policy)
|
399 |
+
policy_config.collect.collector.update(create_cfg.collector)
|
400 |
+
if 'evaluator' in create_cfg:
|
401 |
+
policy_config.eval.evaluator.update(create_cfg.evaluator)
|
402 |
+
policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
|
403 |
+
|
404 |
+
policy_config.other.commander = BaseSerialCommander.default_config()
|
405 |
+
if 'reward_model' in create_cfg:
|
406 |
+
reward_model = get_reward_model_cls(create_cfg.reward_model)
|
407 |
+
reward_model_config = reward_model.default_config()
|
408 |
+
else:
|
409 |
+
reward_model_config = EasyDict()
|
410 |
+
if 'world_model' in create_cfg:
|
411 |
+
world_model = get_world_model_cls(create_cfg.world_model)
|
412 |
+
world_model_config = world_model.default_config()
|
413 |
+
world_model_config.update(create_cfg.world_model)
|
414 |
+
else:
|
415 |
+
world_model_config = EasyDict()
|
416 |
+
else:
|
417 |
+
if 'default_config' in dir(env):
|
418 |
+
env_config = env.default_config()
|
419 |
+
else:
|
420 |
+
env_config = EasyDict() # env does not have default_config
|
421 |
+
env_config = deep_merge_dicts(env_config_template, env_config)
|
422 |
+
if env_manager is None:
|
423 |
+
env_manager = BaseEnvManager # for compatibility
|
424 |
+
env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
|
425 |
+
policy_config = policy.default_config()
|
426 |
+
policy_config = deep_merge_dicts(policy_config_template, policy_config)
|
427 |
+
if reward_model is None:
|
428 |
+
reward_model_config = EasyDict()
|
429 |
+
else:
|
430 |
+
reward_model_config = reward_model.default_config()
|
431 |
+
if world_model is None:
|
432 |
+
world_model_config = EasyDict()
|
433 |
+
else:
|
434 |
+
world_model_config = world_model.default_config()
|
435 |
+
world_model_config.update(create_cfg.world_model)
|
436 |
+
policy_config.learn.learner = deep_merge_dicts(
|
437 |
+
learner.default_config(),
|
438 |
+
policy_config.learn.learner,
|
439 |
+
)
|
440 |
+
if create_cfg is not None or collector is not None:
|
441 |
+
policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
|
442 |
+
if evaluator:
|
443 |
+
policy_config.eval.evaluator = deep_merge_dicts(
|
444 |
+
evaluator.default_config(),
|
445 |
+
policy_config.eval.evaluator,
|
446 |
+
)
|
447 |
+
if create_cfg is not None or buffer is not None:
|
448 |
+
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
|
449 |
+
default_config = EasyDict({'env': env_config, 'policy': policy_config})
|
450 |
+
if len(reward_model_config) > 0:
|
451 |
+
default_config['reward_model'] = reward_model_config
|
452 |
+
if len(world_model_config) > 0:
|
453 |
+
default_config['world_model'] = world_model_config
|
454 |
+
cfg = deep_merge_dicts(default_config, cfg)
|
455 |
+
if 'unroll_len' in cfg.policy:
|
456 |
+
cfg.policy.collect.unroll_len = cfg.policy.unroll_len
|
457 |
+
cfg.seed = seed
|
458 |
+
# check important key in config
|
459 |
+
if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
|
460 |
+
cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
|
461 |
+
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
|
462 |
+
if 'exp_name' not in cfg:
|
463 |
+
cfg.exp_name = 'default_experiment'
|
464 |
+
if save_cfg and get_rank() == 0:
|
465 |
+
if os.path.exists(cfg.exp_name) and renew_dir:
|
466 |
+
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
|
467 |
+
try:
|
468 |
+
os.makedirs(cfg.exp_name)
|
469 |
+
except FileExistsError:
|
470 |
+
pass
|
471 |
+
save_project_state(cfg.exp_name)
|
472 |
+
save_path = os.path.join(cfg.exp_name, save_path)
|
473 |
+
save_config(cfg, save_path, save_formatted=True)
|
474 |
+
return cfg
|
475 |
+
|
476 |
+
|
477 |
+
def compile_config_parallel(
|
478 |
+
cfg: EasyDict,
|
479 |
+
create_cfg: EasyDict,
|
480 |
+
system_cfg: EasyDict,
|
481 |
+
seed: int = 0,
|
482 |
+
save_cfg: bool = True,
|
483 |
+
save_path: str = 'total_config.py',
|
484 |
+
platform: str = 'local',
|
485 |
+
coordinator_host: Optional[str] = None,
|
486 |
+
learner_host: Optional[str] = None,
|
487 |
+
collector_host: Optional[str] = None,
|
488 |
+
coordinator_port: Optional[int] = None,
|
489 |
+
learner_port: Optional[int] = None,
|
490 |
+
collector_port: Optional[int] = None,
|
491 |
+
) -> EasyDict:
|
492 |
+
"""
|
493 |
+
Overview:
|
494 |
+
Combine the input parallel mode configuration information with other input information. Compile config\
|
495 |
+
to make it easy to be called by other programs
|
496 |
+
Arguments:
|
497 |
+
- cfg (:obj:`EasyDict`): Input main config dict
|
498 |
+
- create_cfg (:obj:`dict`): Input create config dict, including type parameters, such as environment type
|
499 |
+
- system_cfg (:obj:`dict`): Input system config dict, including system parameters, such as file path,\
|
500 |
+
communication mode, use multiple GPUs or not
|
501 |
+
- seed (:obj:`int`): Random number seed
|
502 |
+
- save_cfg (:obj:`bool`): Save config or not
|
503 |
+
- save_path (:obj:`str`): Path of saving file
|
504 |
+
- platform (:obj:`str`): Where to run the program, 'local' or 'slurm'
|
505 |
+
- coordinator_host (:obj:`Optional[str]`): Input coordinator's host when platform is slurm
|
506 |
+
- learner_host (:obj:`Optional[str]`): Input learner's host when platform is slurm
|
507 |
+
- collector_host (:obj:`Optional[str]`): Input collector's host when platform is slurm
|
508 |
+
Returns:
|
509 |
+
- cfg (:obj:`EasyDict`): Config after compiling
|
510 |
+
"""
|
511 |
+
# for compatibility
|
512 |
+
if 'replay_buffer' not in create_cfg:
|
513 |
+
create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
|
514 |
+
# env
|
515 |
+
env = get_env_cls(create_cfg.env)
|
516 |
+
if 'default_config' in dir(env):
|
517 |
+
env_config = env.default_config()
|
518 |
+
else:
|
519 |
+
env_config = EasyDict() # env does not have default_config
|
520 |
+
env_config = deep_merge_dicts(env_config_template, env_config)
|
521 |
+
env_config.update(create_cfg.env)
|
522 |
+
|
523 |
+
env_manager = get_env_manager_cls(create_cfg.env_manager)
|
524 |
+
env_config.manager = env_manager.default_config()
|
525 |
+
env_config.manager.update(create_cfg.env_manager)
|
526 |
+
|
527 |
+
# policy
|
528 |
+
policy = get_policy_cls(create_cfg.policy)
|
529 |
+
policy_config = policy.default_config()
|
530 |
+
policy_config = deep_merge_dicts(policy_config_template, policy_config)
|
531 |
+
cfg.policy.update(create_cfg.policy)
|
532 |
+
|
533 |
+
collector = get_parallel_collector_cls(create_cfg.collector)
|
534 |
+
policy_config.collect.collector = collector.default_config()
|
535 |
+
policy_config.collect.collector.update(create_cfg.collector)
|
536 |
+
policy_config.learn.learner = BaseLearner.default_config()
|
537 |
+
policy_config.learn.learner.update(create_cfg.learner)
|
538 |
+
commander = get_parallel_commander_cls(create_cfg.commander)
|
539 |
+
policy_config.other.commander = commander.default_config()
|
540 |
+
policy_config.other.commander.update(create_cfg.commander)
|
541 |
+
policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
|
542 |
+
policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, None)
|
543 |
+
|
544 |
+
default_config = EasyDict({'env': env_config, 'policy': policy_config})
|
545 |
+
cfg = deep_merge_dicts(default_config, cfg)
|
546 |
+
|
547 |
+
cfg.policy.other.commander.path_policy = system_cfg.path_policy # league may use 'path_policy'
|
548 |
+
|
549 |
+
# system
|
550 |
+
for k in ['comm_learner', 'comm_collector']:
|
551 |
+
system_cfg[k] = create_cfg[k]
|
552 |
+
if platform == 'local':
|
553 |
+
cfg = parallel_transform(EasyDict({'main': cfg, 'system': system_cfg}))
|
554 |
+
elif platform == 'slurm':
|
555 |
+
cfg = parallel_transform_slurm(
|
556 |
+
EasyDict({
|
557 |
+
'main': cfg,
|
558 |
+
'system': system_cfg
|
559 |
+
}), coordinator_host, learner_host, collector_host
|
560 |
+
)
|
561 |
+
elif platform == 'k8s':
|
562 |
+
cfg = parallel_transform_k8s(
|
563 |
+
EasyDict({
|
564 |
+
'main': cfg,
|
565 |
+
'system': system_cfg
|
566 |
+
}),
|
567 |
+
coordinator_port=coordinator_port,
|
568 |
+
learner_port=learner_port,
|
569 |
+
collector_port=collector_port
|
570 |
+
)
|
571 |
+
else:
|
572 |
+
raise KeyError("not support platform type: {}".format(platform))
|
573 |
+
cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
|
574 |
+
# seed
|
575 |
+
cfg.seed = seed
|
576 |
+
|
577 |
+
if save_cfg:
|
578 |
+
save_config(cfg, save_path)
|
579 |
+
return cfg
|
DI-engine/ding/config/example/A2C/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
from . import gym_bipedalwalker_v3
|
3 |
+
from . import gym_lunarlander_v2
|
4 |
+
|
5 |
+
supported_env_cfg = {
|
6 |
+
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
|
7 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
|
8 |
+
}
|
9 |
+
|
10 |
+
supported_env_cfg = EasyDict(supported_env_cfg)
|
11 |
+
|
12 |
+
supported_env = {
|
13 |
+
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
|
14 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
|
15 |
+
}
|
16 |
+
|
17 |
+
supported_env = EasyDict(supported_env)
|
DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='Bipedalwalker-v3-A2C',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='BipedalWalker-v3',
|
9 |
+
collector_env_num=8,
|
10 |
+
evaluator_env_num=8,
|
11 |
+
n_evaluator_episode=8,
|
12 |
+
act_scale=True,
|
13 |
+
rew_clip=True,
|
14 |
+
),
|
15 |
+
policy=dict(
|
16 |
+
cuda=True,
|
17 |
+
action_space='continuous',
|
18 |
+
model=dict(
|
19 |
+
action_space='continuous',
|
20 |
+
obs_shape=24,
|
21 |
+
action_shape=4,
|
22 |
+
),
|
23 |
+
learn=dict(
|
24 |
+
batch_size=64,
|
25 |
+
learning_rate=0.0003,
|
26 |
+
value_weight=0.7,
|
27 |
+
entropy_weight=0.0005,
|
28 |
+
discount_factor=0.99,
|
29 |
+
adv_norm=True,
|
30 |
+
),
|
31 |
+
collect=dict(
|
32 |
+
n_sample=64,
|
33 |
+
discount_factor=0.99,
|
34 |
+
),
|
35 |
+
),
|
36 |
+
wandb_logger=dict(
|
37 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
38 |
+
),
|
39 |
+
)
|
40 |
+
|
41 |
+
cfg = EasyDict(cfg)
|
42 |
+
|
43 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='LunarLander-v2-A2C',
|
6 |
+
env=dict(
|
7 |
+
collector_env_num=8,
|
8 |
+
evaluator_env_num=8,
|
9 |
+
env_id='LunarLander-v2',
|
10 |
+
n_evaluator_episode=8,
|
11 |
+
stop_value=260,
|
12 |
+
),
|
13 |
+
policy=dict(
|
14 |
+
cuda=True,
|
15 |
+
model=dict(
|
16 |
+
obs_shape=8,
|
17 |
+
action_shape=4,
|
18 |
+
),
|
19 |
+
learn=dict(
|
20 |
+
batch_size=64,
|
21 |
+
learning_rate=3e-4,
|
22 |
+
entropy_weight=0.001,
|
23 |
+
adv_norm=True,
|
24 |
+
),
|
25 |
+
collect=dict(
|
26 |
+
n_sample=64,
|
27 |
+
discount_factor=0.99,
|
28 |
+
gae_lambda=0.95,
|
29 |
+
),
|
30 |
+
),
|
31 |
+
wandb_logger=dict(
|
32 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
33 |
+
),
|
34 |
+
)
|
35 |
+
|
36 |
+
cfg = EasyDict(cfg)
|
37 |
+
|
38 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/C51/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
from . import gym_lunarlander_v2
|
3 |
+
from . import gym_pongnoframeskip_v4
|
4 |
+
from . import gym_qbertnoframeskip_v4
|
5 |
+
from . import gym_spaceInvadersnoframeskip_v4
|
6 |
+
|
7 |
+
supported_env_cfg = {
|
8 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
|
9 |
+
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
|
10 |
+
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
|
11 |
+
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
|
12 |
+
}
|
13 |
+
|
14 |
+
supported_env_cfg = EasyDict(supported_env_cfg)
|
15 |
+
|
16 |
+
supported_env = {
|
17 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
|
18 |
+
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
|
19 |
+
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
|
20 |
+
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
|
21 |
+
}
|
22 |
+
|
23 |
+
supported_env = EasyDict(supported_env)
|
DI-engine/ding/config/example/C51/gym_lunarlander_v2.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='lunarlander_c51',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
collector_env_num=8,
|
9 |
+
evaluator_env_num=8,
|
10 |
+
env_id='LunarLander-v2',
|
11 |
+
n_evaluator_episode=8,
|
12 |
+
stop_value=260,
|
13 |
+
),
|
14 |
+
policy=dict(
|
15 |
+
cuda=False,
|
16 |
+
model=dict(
|
17 |
+
obs_shape=8,
|
18 |
+
action_shape=4,
|
19 |
+
encoder_hidden_size_list=[512, 64],
|
20 |
+
v_min=-30,
|
21 |
+
v_max=30,
|
22 |
+
n_atom=51,
|
23 |
+
),
|
24 |
+
discount_factor=0.99,
|
25 |
+
nstep=3,
|
26 |
+
learn=dict(
|
27 |
+
update_per_collect=10,
|
28 |
+
batch_size=64,
|
29 |
+
learning_rate=0.001,
|
30 |
+
target_update_freq=100,
|
31 |
+
),
|
32 |
+
collect=dict(
|
33 |
+
n_sample=64,
|
34 |
+
unroll_len=1,
|
35 |
+
),
|
36 |
+
other=dict(
|
37 |
+
eps=dict(
|
38 |
+
type='exp',
|
39 |
+
start=0.95,
|
40 |
+
end=0.1,
|
41 |
+
decay=50000,
|
42 |
+
), replay_buffer=dict(replay_buffer_size=100000, )
|
43 |
+
),
|
44 |
+
),
|
45 |
+
wandb_logger=dict(
|
46 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
47 |
+
),
|
48 |
+
)
|
49 |
+
|
50 |
+
cfg = EasyDict(cfg)
|
51 |
+
|
52 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='PongNoFrameskip-v4-C51',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
collector_env_num=8,
|
9 |
+
evaluator_env_num=8,
|
10 |
+
n_evaluator_episode=8,
|
11 |
+
stop_value=30,
|
12 |
+
env_id='PongNoFrameskip-v4',
|
13 |
+
frame_stack=4,
|
14 |
+
env_wrapper='atari_default',
|
15 |
+
),
|
16 |
+
policy=dict(
|
17 |
+
cuda=True,
|
18 |
+
priority=False,
|
19 |
+
model=dict(
|
20 |
+
obs_shape=[4, 84, 84],
|
21 |
+
action_shape=6,
|
22 |
+
encoder_hidden_size_list=[128, 128, 512],
|
23 |
+
v_min=-10,
|
24 |
+
v_max=10,
|
25 |
+
n_atom=51,
|
26 |
+
),
|
27 |
+
nstep=3,
|
28 |
+
discount_factor=0.99,
|
29 |
+
learn=dict(
|
30 |
+
update_per_collect=10,
|
31 |
+
batch_size=32,
|
32 |
+
learning_rate=0.0001,
|
33 |
+
target_update_freq=500,
|
34 |
+
),
|
35 |
+
collect=dict(n_sample=100, ),
|
36 |
+
eval=dict(evaluator=dict(eval_freq=4000, )),
|
37 |
+
other=dict(
|
38 |
+
eps=dict(
|
39 |
+
type='exp',
|
40 |
+
start=1.,
|
41 |
+
end=0.05,
|
42 |
+
decay=250000,
|
43 |
+
),
|
44 |
+
replay_buffer=dict(replay_buffer_size=100000, ),
|
45 |
+
),
|
46 |
+
),
|
47 |
+
wandb_logger=dict(
|
48 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
49 |
+
),
|
50 |
+
)
|
51 |
+
|
52 |
+
cfg = EasyDict(cfg)
|
53 |
+
|
54 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='QbertNoFrameskip-v4-C51',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
collector_env_num=8,
|
9 |
+
evaluator_env_num=8,
|
10 |
+
n_evaluator_episode=8,
|
11 |
+
stop_value=30000,
|
12 |
+
env_id='QbertNoFrameskip-v4',
|
13 |
+
frame_stack=4,
|
14 |
+
env_wrapper='atari_default',
|
15 |
+
),
|
16 |
+
policy=dict(
|
17 |
+
cuda=True,
|
18 |
+
priority=True,
|
19 |
+
model=dict(
|
20 |
+
obs_shape=[4, 84, 84],
|
21 |
+
action_shape=6,
|
22 |
+
encoder_hidden_size_list=[128, 128, 512],
|
23 |
+
v_min=-10,
|
24 |
+
v_max=10,
|
25 |
+
n_atom=51,
|
26 |
+
),
|
27 |
+
nstep=3,
|
28 |
+
discount_factor=0.99,
|
29 |
+
learn=dict(
|
30 |
+
update_per_collect=10,
|
31 |
+
batch_size=32,
|
32 |
+
learning_rate=0.0001,
|
33 |
+
target_update_freq=500,
|
34 |
+
),
|
35 |
+
collect=dict(n_sample=100, ),
|
36 |
+
eval=dict(evaluator=dict(eval_freq=4000, )),
|
37 |
+
other=dict(
|
38 |
+
eps=dict(
|
39 |
+
type='exp',
|
40 |
+
start=1.,
|
41 |
+
end=0.05,
|
42 |
+
decay=1000000,
|
43 |
+
),
|
44 |
+
replay_buffer=dict(replay_buffer_size=400000, ),
|
45 |
+
),
|
46 |
+
),
|
47 |
+
wandb_logger=dict(
|
48 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
49 |
+
),
|
50 |
+
)
|
51 |
+
|
52 |
+
cfg = EasyDict(cfg)
|
53 |
+
|
54 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='SpaceInvadersNoFrameskip-v4-C51',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
collector_env_num=8,
|
9 |
+
evaluator_env_num=8,
|
10 |
+
n_evaluator_episode=8,
|
11 |
+
stop_value=10000000000,
|
12 |
+
env_id='SpaceInvadersNoFrameskip-v4',
|
13 |
+
frame_stack=4,
|
14 |
+
env_wrapper='atari_default',
|
15 |
+
),
|
16 |
+
policy=dict(
|
17 |
+
cuda=True,
|
18 |
+
priority=False,
|
19 |
+
model=dict(
|
20 |
+
obs_shape=[4, 84, 84],
|
21 |
+
action_shape=6,
|
22 |
+
encoder_hidden_size_list=[128, 128, 512],
|
23 |
+
v_min=-10,
|
24 |
+
v_max=10,
|
25 |
+
n_atom=51,
|
26 |
+
),
|
27 |
+
nstep=3,
|
28 |
+
discount_factor=0.99,
|
29 |
+
learn=dict(
|
30 |
+
update_per_collect=10,
|
31 |
+
batch_size=32,
|
32 |
+
learning_rate=0.0001,
|
33 |
+
target_update_freq=500,
|
34 |
+
),
|
35 |
+
collect=dict(n_sample=100, ),
|
36 |
+
eval=dict(evaluator=dict(eval_freq=4000, )),
|
37 |
+
other=dict(
|
38 |
+
eps=dict(
|
39 |
+
type='exp',
|
40 |
+
start=1.,
|
41 |
+
end=0.05,
|
42 |
+
decay=1000000,
|
43 |
+
),
|
44 |
+
replay_buffer=dict(replay_buffer_size=400000, ),
|
45 |
+
),
|
46 |
+
),
|
47 |
+
wandb_logger=dict(
|
48 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
49 |
+
),
|
50 |
+
)
|
51 |
+
|
52 |
+
cfg = EasyDict(cfg)
|
53 |
+
|
54 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DDPG/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
from . import gym_bipedalwalker_v3
|
3 |
+
from . import gym_halfcheetah_v3
|
4 |
+
from . import gym_hopper_v3
|
5 |
+
from . import gym_lunarlandercontinuous_v2
|
6 |
+
from . import gym_pendulum_v1
|
7 |
+
from . import gym_walker2d_v3
|
8 |
+
|
9 |
+
supported_env_cfg = {
|
10 |
+
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
|
11 |
+
gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
|
12 |
+
gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
|
13 |
+
gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
|
14 |
+
gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
|
15 |
+
gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
|
16 |
+
}
|
17 |
+
|
18 |
+
supported_env_cfg = EasyDict(supported_env_cfg)
|
19 |
+
|
20 |
+
supported_env = {
|
21 |
+
gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
|
22 |
+
gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
|
23 |
+
gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
|
24 |
+
gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
|
25 |
+
gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
|
26 |
+
gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
|
27 |
+
}
|
28 |
+
|
29 |
+
supported_env = EasyDict(supported_env)
|
DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='Bipedalwalker-v3-DDPG',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='BipedalWalker-v3',
|
9 |
+
collector_env_num=8,
|
10 |
+
evaluator_env_num=5,
|
11 |
+
n_evaluator_episode=5,
|
12 |
+
act_scale=True,
|
13 |
+
rew_clip=True,
|
14 |
+
),
|
15 |
+
policy=dict(
|
16 |
+
cuda=True,
|
17 |
+
random_collect_size=10000,
|
18 |
+
model=dict(
|
19 |
+
obs_shape=24,
|
20 |
+
action_shape=4,
|
21 |
+
twin_critic=False,
|
22 |
+
action_space='regression',
|
23 |
+
actor_head_hidden_size=400,
|
24 |
+
critic_head_hidden_size=400,
|
25 |
+
),
|
26 |
+
learn=dict(
|
27 |
+
update_per_collect=64,
|
28 |
+
batch_size=256,
|
29 |
+
learning_rate_actor=0.0003,
|
30 |
+
learning_rate_critic=0.0003,
|
31 |
+
target_theta=0.005,
|
32 |
+
discount_factor=0.99,
|
33 |
+
learner=dict(hook=dict(log_show_after_iter=1000, ))
|
34 |
+
),
|
35 |
+
collect=dict(n_sample=64, ),
|
36 |
+
other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
|
37 |
+
),
|
38 |
+
wandb_logger=dict(
|
39 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
40 |
+
),
|
41 |
+
)
|
42 |
+
|
43 |
+
cfg = EasyDict(cfg)
|
44 |
+
|
45 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='HalfCheetah-v3-DDPG',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='HalfCheetah-v3',
|
9 |
+
norm_obs=dict(use_norm=False, ),
|
10 |
+
norm_reward=dict(use_norm=False, ),
|
11 |
+
collector_env_num=1,
|
12 |
+
evaluator_env_num=8,
|
13 |
+
n_evaluator_episode=8,
|
14 |
+
stop_value=11000,
|
15 |
+
env_wrapper='mujoco_default',
|
16 |
+
),
|
17 |
+
policy=dict(
|
18 |
+
cuda=True,
|
19 |
+
random_collect_size=25000,
|
20 |
+
model=dict(
|
21 |
+
obs_shape=17,
|
22 |
+
action_shape=6,
|
23 |
+
twin_critic=False,
|
24 |
+
actor_head_hidden_size=256,
|
25 |
+
critic_head_hidden_size=256,
|
26 |
+
action_space='regression',
|
27 |
+
),
|
28 |
+
learn=dict(
|
29 |
+
update_per_collect=1,
|
30 |
+
batch_size=256,
|
31 |
+
learning_rate_actor=1e-3,
|
32 |
+
learning_rate_critic=1e-3,
|
33 |
+
ignore_done=True,
|
34 |
+
target_theta=0.005,
|
35 |
+
discount_factor=0.99,
|
36 |
+
actor_update_freq=1,
|
37 |
+
noise=False,
|
38 |
+
),
|
39 |
+
collect=dict(
|
40 |
+
n_sample=1,
|
41 |
+
unroll_len=1,
|
42 |
+
noise_sigma=0.1,
|
43 |
+
),
|
44 |
+
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
|
45 |
+
),
|
46 |
+
wandb_logger=dict(
|
47 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
cfg = EasyDict(cfg)
|
52 |
+
|
53 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DDPG/gym_hopper_v3.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='Hopper-v3-DDPG',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='Hopper-v3',
|
9 |
+
norm_obs=dict(use_norm=False, ),
|
10 |
+
norm_reward=dict(use_norm=False, ),
|
11 |
+
collector_env_num=1,
|
12 |
+
evaluator_env_num=8,
|
13 |
+
n_evaluator_episode=8,
|
14 |
+
stop_value=6000,
|
15 |
+
env_wrapper='mujoco_default',
|
16 |
+
),
|
17 |
+
policy=dict(
|
18 |
+
cuda=True,
|
19 |
+
random_collect_size=25000,
|
20 |
+
model=dict(
|
21 |
+
obs_shape=11,
|
22 |
+
action_shape=3,
|
23 |
+
twin_critic=False,
|
24 |
+
actor_head_hidden_size=256,
|
25 |
+
critic_head_hidden_size=256,
|
26 |
+
action_space='regression',
|
27 |
+
),
|
28 |
+
learn=dict(
|
29 |
+
update_per_collect=1,
|
30 |
+
batch_size=256,
|
31 |
+
learning_rate_actor=1e-3,
|
32 |
+
learning_rate_critic=1e-3,
|
33 |
+
ignore_done=False,
|
34 |
+
target_theta=0.005,
|
35 |
+
discount_factor=0.99,
|
36 |
+
actor_update_freq=1,
|
37 |
+
noise=False,
|
38 |
+
),
|
39 |
+
collect=dict(
|
40 |
+
n_sample=1,
|
41 |
+
unroll_len=1,
|
42 |
+
noise_sigma=0.1,
|
43 |
+
),
|
44 |
+
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
|
45 |
+
),
|
46 |
+
wandb_logger=dict(
|
47 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
cfg = EasyDict(cfg)
|
52 |
+
|
53 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
from functools import partial
|
3 |
+
import ding.envs.gym_env
|
4 |
+
|
5 |
+
cfg = dict(
|
6 |
+
exp_name='LunarLanderContinuous-V2-DDPG',
|
7 |
+
seed=0,
|
8 |
+
env=dict(
|
9 |
+
env_id='LunarLanderContinuous-v2',
|
10 |
+
collector_env_num=8,
|
11 |
+
evaluator_env_num=8,
|
12 |
+
n_evaluator_episode=8,
|
13 |
+
stop_value=260,
|
14 |
+
act_scale=True,
|
15 |
+
),
|
16 |
+
policy=dict(
|
17 |
+
cuda=True,
|
18 |
+
random_collect_size=0,
|
19 |
+
model=dict(
|
20 |
+
obs_shape=8,
|
21 |
+
action_shape=2,
|
22 |
+
twin_critic=True,
|
23 |
+
action_space='regression',
|
24 |
+
),
|
25 |
+
learn=dict(
|
26 |
+
update_per_collect=2,
|
27 |
+
batch_size=128,
|
28 |
+
learning_rate_actor=0.001,
|
29 |
+
learning_rate_critic=0.001,
|
30 |
+
ignore_done=False, # TODO(pu)
|
31 |
+
# (int) When critic network updates once, how many times will actor network update.
|
32 |
+
# Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
|
33 |
+
# Default 1 for DDPG, 2 for TD3.
|
34 |
+
actor_update_freq=1,
|
35 |
+
# (bool) Whether to add noise on target network's action.
|
36 |
+
# Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
|
37 |
+
# Default True for TD3, False for DDPG.
|
38 |
+
noise=False,
|
39 |
+
noise_sigma=0.1,
|
40 |
+
noise_range=dict(
|
41 |
+
min=-0.5,
|
42 |
+
max=0.5,
|
43 |
+
),
|
44 |
+
),
|
45 |
+
collect=dict(
|
46 |
+
n_sample=48,
|
47 |
+
noise_sigma=0.1,
|
48 |
+
collector=dict(collect_print_freq=1000, ),
|
49 |
+
),
|
50 |
+
eval=dict(evaluator=dict(eval_freq=100, ), ),
|
51 |
+
other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
|
52 |
+
),
|
53 |
+
wandb_logger=dict(
|
54 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
55 |
+
),
|
56 |
+
)
|
57 |
+
|
58 |
+
cfg = EasyDict(cfg)
|
59 |
+
|
60 |
+
env = partial(ding.envs.gym_env.env, continuous=True)
|
DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='Pendulum-v1-DDPG',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='Pendulum-v1',
|
9 |
+
collector_env_num=8,
|
10 |
+
evaluator_env_num=5,
|
11 |
+
n_evaluator_episode=5,
|
12 |
+
stop_value=-250,
|
13 |
+
act_scale=True,
|
14 |
+
),
|
15 |
+
policy=dict(
|
16 |
+
cuda=False,
|
17 |
+
priority=False,
|
18 |
+
random_collect_size=800,
|
19 |
+
model=dict(
|
20 |
+
obs_shape=3,
|
21 |
+
action_shape=1,
|
22 |
+
twin_critic=False,
|
23 |
+
action_space='regression',
|
24 |
+
),
|
25 |
+
learn=dict(
|
26 |
+
update_per_collect=2,
|
27 |
+
batch_size=128,
|
28 |
+
learning_rate_actor=0.001,
|
29 |
+
learning_rate_critic=0.001,
|
30 |
+
ignore_done=True,
|
31 |
+
actor_update_freq=1,
|
32 |
+
noise=False,
|
33 |
+
),
|
34 |
+
collect=dict(
|
35 |
+
n_sample=48,
|
36 |
+
noise_sigma=0.1,
|
37 |
+
collector=dict(collect_print_freq=1000, ),
|
38 |
+
),
|
39 |
+
eval=dict(evaluator=dict(eval_freq=100, )),
|
40 |
+
other=dict(replay_buffer=dict(
|
41 |
+
replay_buffer_size=20000,
|
42 |
+
max_use=16,
|
43 |
+
), ),
|
44 |
+
),
|
45 |
+
wandb_logger=dict(
|
46 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
47 |
+
),
|
48 |
+
)
|
49 |
+
|
50 |
+
cfg = EasyDict(cfg)
|
51 |
+
|
52 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='Walker2d-v3-DDPG',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='Walker2d-v3',
|
9 |
+
norm_obs=dict(use_norm=False, ),
|
10 |
+
norm_reward=dict(use_norm=False, ),
|
11 |
+
collector_env_num=1,
|
12 |
+
evaluator_env_num=8,
|
13 |
+
n_evaluator_episode=8,
|
14 |
+
stop_value=6000,
|
15 |
+
env_wrapper='mujoco_default',
|
16 |
+
),
|
17 |
+
policy=dict(
|
18 |
+
cuda=True,
|
19 |
+
random_collect_size=25000,
|
20 |
+
model=dict(
|
21 |
+
obs_shape=17,
|
22 |
+
action_shape=6,
|
23 |
+
twin_critic=False,
|
24 |
+
actor_head_hidden_size=256,
|
25 |
+
critic_head_hidden_size=256,
|
26 |
+
action_space='regression',
|
27 |
+
),
|
28 |
+
learn=dict(
|
29 |
+
update_per_collect=1,
|
30 |
+
batch_size=256,
|
31 |
+
learning_rate_actor=1e-3,
|
32 |
+
learning_rate_critic=1e-3,
|
33 |
+
ignore_done=False,
|
34 |
+
target_theta=0.005,
|
35 |
+
discount_factor=0.99,
|
36 |
+
actor_update_freq=1,
|
37 |
+
noise=False,
|
38 |
+
),
|
39 |
+
collect=dict(
|
40 |
+
n_sample=1,
|
41 |
+
unroll_len=1,
|
42 |
+
noise_sigma=0.1,
|
43 |
+
),
|
44 |
+
other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
|
45 |
+
),
|
46 |
+
wandb_logger=dict(
|
47 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
cfg = EasyDict(cfg)
|
52 |
+
|
53 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DQN/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
from . import gym_lunarlander_v2
|
3 |
+
from . import gym_pongnoframeskip_v4
|
4 |
+
from . import gym_qbertnoframeskip_v4
|
5 |
+
from . import gym_spaceInvadersnoframeskip_v4
|
6 |
+
|
7 |
+
supported_env_cfg = {
|
8 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
|
9 |
+
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
|
10 |
+
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
|
11 |
+
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
|
12 |
+
}
|
13 |
+
|
14 |
+
supported_env_cfg = EasyDict(supported_env_cfg)
|
15 |
+
|
16 |
+
supported_env = {
|
17 |
+
gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
|
18 |
+
gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
|
19 |
+
gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
|
20 |
+
gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
|
21 |
+
}
|
22 |
+
|
23 |
+
supported_env = EasyDict(supported_env)
|
DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='LunarLander-v2-DQN',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='LunarLander-v2',
|
9 |
+
collector_env_num=8,
|
10 |
+
evaluator_env_num=8,
|
11 |
+
n_evaluator_episode=8,
|
12 |
+
stop_value=260,
|
13 |
+
),
|
14 |
+
policy=dict(
|
15 |
+
cuda=True,
|
16 |
+
random_collect_size=25000,
|
17 |
+
discount_factor=0.99,
|
18 |
+
nstep=3,
|
19 |
+
learn=dict(
|
20 |
+
update_per_collect=10,
|
21 |
+
batch_size=64,
|
22 |
+
learning_rate=0.001,
|
23 |
+
# Frequency of target network update.
|
24 |
+
target_update_freq=100,
|
25 |
+
),
|
26 |
+
model=dict(
|
27 |
+
obs_shape=8,
|
28 |
+
action_shape=4,
|
29 |
+
encoder_hidden_size_list=[512, 64],
|
30 |
+
# Whether to use dueling head.
|
31 |
+
dueling=True,
|
32 |
+
),
|
33 |
+
collect=dict(
|
34 |
+
n_sample=64,
|
35 |
+
unroll_len=1,
|
36 |
+
),
|
37 |
+
other=dict(
|
38 |
+
eps=dict(
|
39 |
+
type='exp',
|
40 |
+
start=0.95,
|
41 |
+
end=0.1,
|
42 |
+
decay=50000,
|
43 |
+
), replay_buffer=dict(replay_buffer_size=100000, )
|
44 |
+
),
|
45 |
+
),
|
46 |
+
wandb_logger=dict(
|
47 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
cfg = EasyDict(cfg)
|
52 |
+
|
53 |
+
env = ding.envs.gym_env.env
|
DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import ding.envs.gym_env
|
3 |
+
|
4 |
+
cfg = dict(
|
5 |
+
exp_name='PongNoFrameskip-v4-DQN',
|
6 |
+
seed=0,
|
7 |
+
env=dict(
|
8 |
+
env_id='PongNoFrameskip-v4',
|
9 |
+
collector_env_num=8,
|
10 |
+
evaluator_env_num=8,
|
11 |
+
n_evaluator_episode=8,
|
12 |
+
stop_value=30,
|
13 |
+
fram_stack=4,
|
14 |
+
env_wrapper='atari_default',
|
15 |
+
),
|
16 |
+
policy=dict(
|
17 |
+
cuda=True,
|
18 |
+
priority=False,
|
19 |
+
discount_factor=0.99,
|
20 |
+
nstep=3,
|
21 |
+
learn=dict(
|
22 |
+
update_per_collect=10,
|
23 |
+
batch_size=32,
|
24 |
+
learning_rate=0.0001,
|
25 |
+
# Frequency of target network update.
|
26 |
+
target_update_freq=500,
|
27 |
+
),
|
28 |
+
model=dict(
|
29 |
+
obs_shape=[4, 84, 84],
|
30 |
+
action_shape=6,
|
31 |
+
encoder_hidden_size_list=[128, 128, 512],
|
32 |
+
),
|
33 |
+
collect=dict(n_sample=96, ),
|
34 |
+
other=dict(
|
35 |
+
eps=dict(
|
36 |
+
type='exp',
|
37 |
+
start=1.,
|
38 |
+
end=0.05,
|
39 |
+
decay=250000,
|
40 |
+
), replay_buffer=dict(replay_buffer_size=100000, )
|
41 |
+
),
|
42 |
+
),
|
43 |
+
wandb_logger=dict(
|
44 |
+
gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
|
45 |
+
),
|
46 |
+
)
|
47 |
+
|
48 |
+
cfg = EasyDict(cfg)
|
49 |
+
|
50 |
+
env = ding.envs.gym_env.env
|