zjowowen committed on
Commit 079c32c
1 Parent(s): a73e77c

init space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. DI-engine +0 -1
  2. DI-engine/.flake8 +4 -0
  3. DI-engine/.gitignore +1431 -0
  4. DI-engine/.style.yapf +11 -0
  5. DI-engine/CHANGELOG +489 -0
  6. DI-engine/CODE_OF_CONDUCT.md +128 -0
  7. DI-engine/CONTRIBUTING.md +7 -0
  8. DI-engine/LICENSE +202 -0
  9. DI-engine/Makefile +71 -0
  10. DI-engine/README.md +475 -0
  11. DI-engine/cloc.sh +69 -0
  12. DI-engine/codecov.yml +8 -0
  13. DI-engine/conda/conda_build_config.yaml +2 -0
  14. DI-engine/conda/meta.yaml +35 -0
  15. DI-engine/ding/__init__.py +12 -0
  16. DI-engine/ding/bonus/__init__.py +132 -0
  17. DI-engine/ding/bonus/a2c.py +460 -0
  18. DI-engine/ding/bonus/c51.py +459 -0
  19. DI-engine/ding/bonus/common.py +22 -0
  20. DI-engine/ding/bonus/config.py +326 -0
  21. DI-engine/ding/bonus/ddpg.py +456 -0
  22. DI-engine/ding/bonus/dqn.py +460 -0
  23. DI-engine/ding/bonus/model.py +245 -0
  24. DI-engine/ding/bonus/pg.py +453 -0
  25. DI-engine/ding/bonus/ppo_offpolicy.py +471 -0
  26. DI-engine/ding/bonus/ppof.py +509 -0
  27. DI-engine/ding/bonus/sac.py +457 -0
  28. DI-engine/ding/bonus/sql.py +461 -0
  29. DI-engine/ding/bonus/td3.py +455 -0
  30. DI-engine/ding/compatibility.py +9 -0
  31. DI-engine/ding/config/__init__.py +4 -0
  32. DI-engine/ding/config/config.py +579 -0
  33. DI-engine/ding/config/example/A2C/__init__.py +17 -0
  34. DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py +43 -0
  35. DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py +38 -0
  36. DI-engine/ding/config/example/C51/__init__.py +23 -0
  37. DI-engine/ding/config/example/C51/gym_lunarlander_v2.py +52 -0
  38. DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py +54 -0
  39. DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py +54 -0
  40. DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py +54 -0
  41. DI-engine/ding/config/example/DDPG/__init__.py +29 -0
  42. DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py +45 -0
  43. DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py +53 -0
  44. DI-engine/ding/config/example/DDPG/gym_hopper_v3.py +53 -0
  45. DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py +60 -0
  46. DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py +52 -0
  47. DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py +53 -0
  48. DI-engine/ding/config/example/DQN/__init__.py +23 -0
  49. DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py +53 -0
  50. DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py +50 -0
DI-engine DELETED
@@ -1 +0,0 @@
- Subproject commit a57bc3024b938c881aaf6511d1fb26296cd98601
 
 
DI-engine/.flake8 ADDED
@@ -0,0 +1,4 @@
+ [flake8]
+ ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731
+ max-line-length=120
+ statistics
DI-engine/.gitignore ADDED
@@ -0,0 +1,1431 @@
+ # Created by .ignore support plugin (hsz.mobi)
+ ### ArchLinuxPackages template
+ *.tar
+ *.tar.*
+ *.jar
+ *.exe
+ *.msi
+ *.zip
+ *.tgz
+ *.log
+ *.log.*
+ *.sig
+ *.mov
+ *.pkl
+
+ pkg/
+ src/
+ impala_log/
+
+ ### CVS template
+ /CVS/*
+ **/CVS/*
+ .cvsignore
+ */.cvsignore
+
+ ### LibreOffice template
+ # LibreOffice locks
+ .~lock.*#
+
+ ### CUDA template
+ *.i
+ *.ii
+ *.gpu
+ *.ptx
+ *.cubin
+ *.fatbin
+
+ ### Eclipse template
+ *.bin
+ .metadata
+ bin/
+ tmp/
+ *.tmp
+ *.bak
+ *.swp
+ *~.nib
+ local.properties
+ .settings/
+ .loadpath
+ .recommenders
+
+ # External tool builders
+ .externalToolBuilders/
+
+ # Locally stored "Eclipse launch configurations"
+ *.launch
+
+ # PyDev specific (Python IDE for Eclipse)
+ *.pydevproject
+
+ # CDT-specific (C/C++ Development Tooling)
+ .cproject
+
+ # CDT- autotools
+ .autotools
+
+ # Java annotation processor (APT)
+ .factorypath
+
+ # PDT-specific (PHP Development Tools)
+ .buildpath
+
+ # sbteclipse plugin
+ .target
+
+ # Tern plugin
+ .tern-project
+
+ # TeXlipse plugin
+ .texlipse
+
+ # STS (Spring Tool Suite)
+ .springBeans
+
+ # Code Recommenders
+ .recommenders/
+
+ # Annotation Processing
+ .apt_generated/
+ .apt_generated_test/
+
+ # Scala IDE specific (Scala & Java development for Eclipse)
+ .cache-main
+ .scala_dependencies
+ .worksheet
+
+ # Uncomment this line if you wish to ignore the project description file.
+ # Typically, this file would be tracked if it contains build/dependency configurations:
+ #.project
+
+ ### SVN template
+ .svn/
+
+ ### Images template
+ # JPEG
+ *.jpg
+ *.jpeg
+ *.jpe
+ *.jif
+ *.jfif
+ *.jfi
+
+ # JPEG 2000
+ *.jp2
+ *.j2k
+ *.jpf
+ *.jpx
+ *.jpm
+ *.mj2
+
+ # JPEG XR
+ *.jxr
+ *.hdp
+ *.wdp
+
+ # Graphics Interchange Format
+ *.gif
+ *.mp4
+ *.mpg
+
+ # RAW
+ *.raw
+
+ # Web P
+ *.webp
+
+ # Portable Network Graphics
+ *.png
+
+ # Animated Portable Network Graphics
+ *.apng
+
+ # Multiple-image Network Graphics
+ *.mng
+
+ # Tagged Image File Format
+ *.tiff
+ *.tif
+
+ # Scalable Vector Graphics
+ *.svg
+ *.svgz
+
+ # Portable Document Format
+ *.pdf
+
+ # X BitMap
+ *.xbm
+
+ # BMP
+ *.bmp
+ *.dib
+
+ # ICO
+ *.ico
+
+ # 3D Images
+ *.3dm
+ *.max
+
+ ### Diff template
+ *.patch
+ *.diff
+
+ ### JetBrains template
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ ### CodeIgniter template
+ */config/development
+ */logs/log-*.php
+ !*/logs/index.html
+ */cache/*
+ !*/cache/index.html
+ !*/cache/.htaccess
+
+ user_guide_src/build/*
+ user_guide_src/cilexer/build/*
+ user_guide_src/cilexer/dist/*
+ user_guide_src/cilexer/pycilexer.egg-info/*
+
+ #codeigniter 3
+ application/logs/*
+ !application/logs/index.html
+ !application/logs/.htaccess
+ /vendor/
+
+ ### Emacs template
+ # -*- mode: gitignore; -*-
+ *~
+ \#*\#
+ /.emacs.desktop
+ /.emacs.desktop.lock
+ *.elc
+ auto-save-list
+ tramp
+ .\#*
+
+ # Org-mode
+ .org-id-locations
+ *_archive
+
+ # flymake-mode
+ *_flymake.*
+
+ # eshell files
+ /eshell/history
+ /eshell/lastdir
+
+ # elpa packages
+ /elpa/
+
+ # reftex files
+ *.rel
+
+ # AUCTeX auto folder
+ /auto/
+
+ # cask packages
+ .cask/
+ dist/
+
+ # Flycheck
+ flycheck_*.el
+
+ # server auth directory
+ /server/
+
+ # projectiles files
+ .projectile
+
+ # directory configuration
+ .dir-locals.el
+
+ # network security
+ /network-security.data
+
+
+ ### Windows template
+ # Windows thumbnail cache files
+ Thumbs.db
+ Thumbs.db:encryptable
+ ehthumbs.db
+ ehthumbs_vista.db
+
+ # Dump file
+ *.stackdump
+
+ # Folder config file
+ [Dd]esktop.ini
+
+ # Recycle Bin used on file shares
+ $RECYCLE.BIN/
+
+ # Windows Installer files
+ *.cab
+ *.msix
+ *.msm
+ *.msp
+
+ # Windows shortcuts
+ *.lnk
+
+ ### VisualStudioCode template
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ *.code-workspace
+
+ # Local History for Visual Studio Code
+ .history/
+
+ ### CMake template
+ CMakeLists.txt.user
+ CMakeCache.txt
+ CMakeFiles
+ CMakeScripts
+ Testing
+ cmake_install.cmake
+ install_manifest.txt
+ compile_commands.json
+ CTestTestfile.cmake
+ _deps
+
+ ### VisualStudio template
+ ## Ignore Visual Studio temporary files, build results, and
+ ## files generated by popular Visual Studio add-ons.
+ ##
+ ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
+
+ # User-specific files
+ *.rsuser
+ *.suo
+ *.user
+ *.userosscache
+ *.sln.docstates
+
+ # User-specific files (MonoDevelop/Xamarin Studio)
+ *.userprefs
+
+ # Mono auto generated files
+ mono_crash.*
+
+ # Build results
+ [Dd]ebug/
+ [Dd]ebugPublic/
+ [Rr]elease/
+ [Rr]eleases/
+ x64/
+ x86/
+ [Ww][Ii][Nn]32/
+ [Aa][Rr][Mm]/
+ [Aa][Rr][Mm]64/
+ bld/
+ [Bb]in/
+ [Oo]bj/
+ [Ll]og/
+ [Ll]ogs/
+
+ # Visual Studio 2015/2017 cache/options directory
+ .vs/
+ # Uncomment if you have tasks that create the project's static files in wwwroot
+ #wwwroot/
+
+ # Visual Studio 2017 auto generated files
+ Generated\ Files/
+
+ # MSTest test Results
+ [Tt]est[Rr]esult*/
+ [Bb]uild[Ll]og.*
+
+ # NUnit
+ *.VisualState.xml
+ TestResult.xml
+ nunit-*.xml
+
+ # Build Results of an ATL Project
+ [Dd]ebugPS/
+ [Rr]eleasePS/
+ dlldata.c
+
+ # Benchmark Results
+ BenchmarkDotNet.Artifacts/
+
+ # .NET Core
+ project.lock.json
+ project.fragment.lock.json
+ artifacts/
+
+ # ASP.NET Scaffolding
+ ScaffoldingReadMe.txt
+
+ # StyleCop
+ StyleCopReport.xml
+
+ # Files built by Visual Studio
+ *_i.c
+ *_p.c
+ *_h.h
+ *.ilk
+ *.meta
+ *.obj
+ *.iobj
+ *.pch
+ *.pdb
+ *.ipdb
+ *.pgc
+ *.pgd
+ *.rsp
+ *.sbr
+ *.tlb
+ *.tli
+ *.tlh
+ *.tmp_proj
+ *_wpftmp.csproj
+ *.vspscc
+ *.vssscc
+ .builds
+ *.pidb
+ *.svclog
+ *.scc
+
+ # Chutzpah Test files
+ _Chutzpah*
+
+ # Visual C++ cache files
+ ipch/
+ *.aps
+ *.ncb
+ *.opendb
+ *.opensdf
+ *.sdf
+ *.cachefile
+ *.VC.db
+ *.VC.VC.opendb
+
+ # Visual Studio profiler
+ *.psess
+ *.vsp
+ *.vspx
+ *.sap
+
+ # Visual Studio Trace Files
+ *.e2e
+
+ # TFS 2012 Local Workspace
+ $tf/
+
+ # Guidance Automation Toolkit
+ *.gpState
+
+ # ReSharper is a .NET coding add-in
+ _ReSharper*/
+ *.[Rr]e[Ss]harper
+ *.DotSettings.user
+
+ # TeamCity is a build add-in
+ _TeamCity*
+
+ # DotCover is a Code Coverage Tool
+ *.dotCover
+
+ # AxoCover is a Code Coverage Tool
+ .axoCover/*
+ !.axoCover/settings.json
+
+ # Coverlet is a free, cross platform Code Coverage Tool
+ coverage*.json
+ coverage*.xml
+ coverage*.info
+
+ # Visual Studio code coverage results
+ *.coverage
+ *.coveragexml
+
+ # NCrunch
+ _NCrunch_*
+ .*crunch*.local.xml
+ nCrunchTemp_*
+
+ # MightyMoose
+ *.mm.*
+ AutoTest.Net/
+
+ # Web workbench (sass)
+ .sass-cache/
+
+ # Installshield output folder
+ [Ee]xpress/
+
+ # DocProject is a documentation generator add-in
+ DocProject/buildhelp/
+ DocProject/Help/*.HxT
+ DocProject/Help/*.HxC
+ DocProject/Help/*.hhc
+ DocProject/Help/*.hhk
+ DocProject/Help/*.hhp
+ DocProject/Help/Html2
+ DocProject/Help/html
+
+ # Click-Once directory
+ publish/
+
+ # Publish Web Output
+ *.[Pp]ublish.xml
+ *.azurePubxml
+ # Note: Comment the next line if you want to checkin your web deploy settings,
+ # but database connection strings (with potential passwords) will be unencrypted
+ *.pubxml
+ *.publishproj
+
+ # Microsoft Azure Web App publish settings. Comment the next line if you want to
+ # checkin your Azure Web App publish settings, but sensitive information contained
+ # in these scripts will be unencrypted
+ PublishScripts/
+
+ # NuGet Packages
+ *.nupkg
+ # NuGet Symbol Packages
+ *.snupkg
+ # The packages folder can be ignored because of Package Restore
+ **/[Pp]ackages/*
+ # except build/, which is used as an MSBuild target.
+ !**/[Pp]ackages/build/
+ # Uncomment if necessary however generally it will be regenerated when needed
+ #!**/[Pp]ackages/repositories.config
+ # NuGet v3's project.json files produces more ignorable files
+ *.nuget.props
+ *.nuget.targets
+
+ # Microsoft Azure Build Output
+ csx/
+ *.build.csdef
+
+ # Microsoft Azure Emulator
+ ecf/
+ rcf/
+
+ # Windows Store app package directories and files
+ AppPackages/
+ BundleArtifacts/
+ Package.StoreAssociation.xml
+ _pkginfo.txt
+ *.appx
+ *.appxbundle
+ *.appxupload
+
+ # Visual Studio cache files
+ # files ending in .cache can be ignored
+ *.[Cc]ache
+ # but keep track of directories ending in .cache
+ !?*.[Cc]ache/
+
+ # Others
+ ClientBin/
+ ~$*
+ *.dbmdl
+ *.dbproj.schemaview
+ *.jfm
+ *.pfx
+ *.publishsettings
+ orleans.codegen.cs
+
+ # Including strong name files can present a security risk
+ # (https://github.com/github/gitignore/pull/2483#issue-259490424)
+ #*.snk
+
+ # Since there are multiple workflows, uncomment next line to ignore bower_components
+ # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+ #bower_components/
+
+ # RIA/Silverlight projects
+ Generated_Code/
+
+ # Backup & report files from converting an old project file
+ # to a newer Visual Studio version. Backup files are not needed,
+ # because we have git ;-)
+ _UpgradeReport_Files/
+ Backup*/
+ UpgradeLog*.XML
+ UpgradeLog*.htm
+ ServiceFabricBackup/
+ *.rptproj.bak
+
+ # SQL Server files
+ *.mdf
+ *.ldf
+ *.ndf
+
+ # Business Intelligence projects
+ *.rdl.data
+ *.bim.layout
+ *.bim_*.settings
+ *.rptproj.rsuser
+ *- [Bb]ackup.rdl
+ *- [Bb]ackup ([0-9]).rdl
+ *- [Bb]ackup ([0-9][0-9]).rdl
+
+ # Microsoft Fakes
+ FakesAssemblies/
+
+ # GhostDoc plugin setting file
+ *.GhostDoc.xml
+
+ # Node.js Tools for Visual Studio
+ .ntvs_analysis.dat
+ node_modules/
+
+ # Visual Studio 6 build log
+ *.plg
+
+ # Visual Studio 6 workspace options file
+ *.opt
+
+ # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+ *.vbw
+
+ # Visual Studio LightSwitch build output
+ **/*.HTMLClient/GeneratedArtifacts
+ **/*.DesktopClient/GeneratedArtifacts
+ **/*.DesktopClient/ModelManifest.xml
+ **/*.Server/GeneratedArtifacts
+ **/*.Server/ModelManifest.xml
+ _Pvt_Extensions
+
+ # Paket dependency manager
+ .paket/paket.exe
+ paket-files/
+
+ # FAKE - F# Make
+ .fake/
+
+ # CodeRush personal settings
+ .cr/personal
+
+ # Python Tools for Visual Studio (PTVS)
+ __pycache__/
+ *.pyc
+
+ # Cake - Uncomment if you are using it
+ # tools/**
+ # !tools/packages.config
+
+ # Tabs Studio
+ *.tss
+
+ # Telerik's JustMock configuration file
+ *.jmconfig
+
+ # BizTalk build output
+ *.btp.cs
+ *.btm.cs
+ *.odx.cs
+ *.xsd.cs
+
+ # OpenCover UI analysis results
+ OpenCover/
+
+ # Azure Stream Analytics local run output
+ ASALocalRun/
+
+ # MSBuild Binary and Structured Log
+ *.binlog
+
+ # NVidia Nsight GPU debugger configuration file
+ *.nvuser
+
+ # MFractors (Xamarin productivity tool) working folder
+ .mfractor/
+
+ # Local History for Visual Studio
+ .localhistory/
+
+ # BeatPulse healthcheck temp database
+ healthchecksdb
+
+ # Backup folder for Package Reference Convert tool in Visual Studio 2017
+ MigrationBackup/
+
+ # Ionide (cross platform F# VS Code tools) working folder
+ .ionide/
+
+ # Fody - auto-generated XML schema
+ FodyWeavers.xsd
+
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ venv/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ ### Backup template
+ *.gho
+ *.ori
+ *.orig
+
+ ### Node template
+ # Logs
+ logs
+ npm-debug.log*
+ yarn-debug.log*
+ yarn-error.log*
+ lerna-debug.log*
+
+ # Diagnostic reports (https://nodejs.org/api/report.html)
+ report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+ # Runtime data
+ pids
+ *.pid
+ *.seed
+ *.pid.lock
+
+ # Directory for instrumented libs generated by jscoverage/JSCover
+ lib-cov
+
+ # Coverage directory used by tools like istanbul
+ coverage
+ *.lcov
+
+ # nyc test coverage
+ .nyc_output
+
+ # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+ .grunt
+
+ # Bower dependency directory (https://bower.io/)
+ bower_components
+
+ # node-waf configuration
+ .lock-wscript
+
+ # Compiled binary addons (https://nodejs.org/api/addons.html)
+ build/Release
+
+ # Dependency directories
+ jspm_packages/
+
+ # Snowpack dependency directory (https://snowpack.dev/)
+ web_modules/
+
+ # TypeScript cache
+ *.tsbuildinfo
+
+ # Optional npm cache directory
+ .npm
+
+ # Optional eslint cache
+ .eslintcache
+
+ # Microbundle cache
+ .rpt2_cache/
+ .rts2_cache_cjs/
+ .rts2_cache_es/
+ .rts2_cache_umd/
+
+ # Optional REPL history
+ .node_repl_history
+
+ # Output of 'npm pack'
+
+ # Yarn Integrity file
+ .yarn-integrity
+
+ # dotenv environment variables file
+ .env.test
+
+ # parcel-bundler cache (https://parceljs.org/)
+ .parcel-cache
+
+ # Next.js build output
+ .next
+ out
+
+ # Nuxt.js build / generate output
+ .nuxt
+ dist
+
+ # Gatsby files
+ .cache/
+ # Comment in the public line in if your project uses Gatsby and not Next.js
+ # https://nextjs.org/blog/next-9-1#public-directory-support
+ # public
+
+ # vuepress build output
+ .vuepress/dist
+
+ # Serverless directories
+ .serverless/
+
+ # FuseBox cache
+ .fusebox/
+
+ # DynamoDB Local files
+ .dynamodb/
+
+ # TernJS port file
+ .tern-port
+
+ # Stores VSCode versions used for testing VSCode extensions
+ .vscode-test
+
+ # yarn v2
+ .yarn/cache
+ .yarn/unplugged
+ .yarn/build-state.yml
+ .yarn/install-state.gz
+ .pnp.*
+
+ ### VirtualEnv template
+ # Virtualenv
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+ [Bb]in
+ [Ii]nclude
+ [Ll]ib
+ [Ll]ib64
+ [Ll]ocal
+ pyvenv.cfg
+ pip-selfcheck.json
+
+ ### macOS template
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ ### Go template
+ # Binaries for programs and plugins
+ *.exe~
+ *.dll
+ *.dylib
+
+ # Test binary, built with `go test -c`
+ *.test
+
+ # Output of the go coverage tool, specifically when used with LiteIDE
+ *.out
+
+ # Dependency directories (remove the comment below to include it)
+ # vendor/
+
+ ### C template
+ # Prerequisites
+ *.d
+
+ # Object files
+ *.o
+ *.ko
+ *.elf
+
+ # Linker output
+ *.map
+ *.exp
+
+ # Precompiled Headers
+ *.gch
+
+ # Libraries
+ *.lib
+ *.a
+ *.la
+ *.lo
+
+ # Shared objects (inc. Windows DLLs)
+ *.so.*
+
+ # Executables
+ *.app
+ *.i*86
+ *.x86_64
+ *.hex
+
+ # Debug files
+ *.dSYM/
+ *.su
+ *.idb
+
+ # Kernel Module Compile Results
+ *.mod*
+ *.cmd
+ .tmp_versions/
+ modules.order
+ Module.symvers
+ Mkfile.old
+ dkms.conf
+
+ ### Example user template template
+ ### Example user template
+
+ # IntelliJ project files
+ .idea
+ *.iml
+ gen
+ ### TextMate template
+ *.tmproj
+ *.tmproject
+ tmtags
+
+ ### Anjuta template
+ # Local configuration folder and symbol database
+ /.anjuta/
+ /.anjuta_sym_db.db
+
+ ### XilinxISE template
+ # intermediate build files
+ *.bgn
+ *.bit
+ *.bld
+ *.cmd_log
+ *.drc
+ *.ll
+ *.lso
+ *.msd
+ *.msk
+ *.ncd
+ *.ngc
+ *.ngd
+ *.ngr
+ *.pad
+ *.par
+ *.pcf
+ *.prj
+ *.ptwx
+ *.rbb
+ *.rbd
+ *.stx
+ *.syr
+ *.twr
+ *.twx
+ *.unroutes
+ *.ut
+ *.xpi
+ *.xst
+ *_bitgen.xwbt
+ *_envsettings.html
+ *_map.map
+ *_map.mrp
+ *_map.ngm
+ *_map.xrpt
+ *_ngdbuild.xrpt
+ *_pad.csv
+ *_pad.txt
+ *_par.xrpt
+ *_summary.html
+ *_summary.xml
+ *_usage.xml
+ *_xst.xrpt
+
+ # iMPACT generated files
+ _impactbatch.log
+ impact.xsl
+ impact_impact.xwbt
+ ise_impact.cmd
+ webtalk_impact.xml
+
+ # Core Generator generated files
+ xaw2verilog.log
+
+ # project-wide generated files
+ *.gise
+ par_usage_statistics.html
+ usage_statistics_webtalk.html
+ webtalk.log
+ webtalk_pn.xml
+
+ # generated folders
+ iseconfig/
+ xlnx_auto_0_xdb/
+ xst/
+ _ngo/
+ _xmsgs/
+
+ ### TortoiseGit template
+ # Project-level settings
+ /.tgitconfig
+
+ ### C++ template
+ # Prerequisites
+
+ # Compiled Object files
+ *.slo
+
+ # Precompiled Headers
+
+ # Compiled Dynamic libraries
+
+ # Fortran module files
+ *.mod
+ *.smod
+
+ # Compiled Static libraries
+ *.lai
+
+ # Executables
+
+ ### SublimeText template
+ # Cache files for Sublime Text
+ *.tmlanguage.cache
+ *.tmPreferences.cache
+ *.stTheme.cache
+
+ # Workspace files are user-specific
+ *.sublime-workspace
+
+ # Project files should be checked into the repository, unless a significant
+ # proportion of contributors will probably not be using Sublime Text
+ # *.sublime-project
+
+ # SFTP configuration file
+ sftp-config.json
+ sftp-config-alt*.json
+
+ # Package control specific files
+ Package Control.last-run
+ Package Control.ca-list
+ Package Control.ca-bundle
+ Package Control.system-ca-bundle
+ Package Control.cache/
+ Package Control.ca-certs/
+ Package Control.merged-ca-bundle
+ Package Control.user-ca-bundle
+ oscrypto-ca-bundle.crt
+ bh_unicode_properties.cache
+
+ # Sublime-github package stores a github token in this file
+ # https://packagecontrol.io/packages/sublime-github
+ GitHub.sublime-settings
+
+ ### Vim template
+ # Swap
+ [._]*.s[a-v][a-z]
+ !*.svg # comment out if you don't need vector files
+ [._]*.sw[a-p]
+ [._]s[a-rt-v][a-z]
+ [._]ss[a-gi-z]
+ [._]sw[a-p]
+
+ # Session
+ Session.vim
+ Sessionx.vim
+
+ # Temporary
+ .netrwhist
+ # Auto-generated tag files
+ tags
+ # Persistent undo
+ [._]*.un~
+
+ ### Autotools template
+ # http://www.gnu.org/software/automake
+
+ Makefile.in
+ /ar-lib
+ /mdate-sh
+ /py-compile
+ /test-driver
+ /ylwrap
+ .deps/
+ .dirstamp
+
+ # http://www.gnu.org/software/autoconf
+
+ autom4te.cache
+ /autoscan.log
+ /autoscan-*.log
+ /aclocal.m4
+ /compile
+ /config.guess
+ /config.h.in
+ /config.log
+ /config.status
+ /config.sub
+ /configure
+ /configure.scan
+ /depcomp
+ /install-sh
+ /missing
+ /stamp-h1
+
+ # https://www.gnu.org/software/libtool/
+
+ /ltmain.sh
+
+ # http://www.gnu.org/software/texinfo
+
+ /texinfo.tex
+
+ # http://www.gnu.org/software/m4/
+
+ m4/libtool.m4
+ m4/ltoptions.m4
+ m4/ltsugar.m4
+ m4/ltversion.m4
+ m4/lt~obsolete.m4
+
+ # Generated Makefile
+ # (meta build system like autotools,
+ # can automatically generate from config.status script
+ # (which is called by configure script))
+
+ ### Lua template
+ # Compiled Lua sources
+ luac.out
+
+ # luarocks build files
+ *.src.rock
+ *.tar.gz
+
+ # Object files
+ *.os
+
+ # Precompiled Headers
+
+ # Libraries
+ *.def
+
+ # Shared objects (inc. Windows DLLs)
+
+ # Executables
+
+
+ ### Vagrant template
+ # General
+ .vagrant/
+
+ # Log files (if you are creating logs in debug mode, uncomment this)
+ # *.log
+
+ ### Xcode template
+ # Xcode
+ #
+ # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
+
+ ## User settings
+ xcuserdata/
+
+ ## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
+ *.xcscmblueprint
+ *.xccheckout
+
+ ## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
+ DerivedData/
+ *.moved-aside
+ *.pbxuser
+ !default.pbxuser
+ *.mode1v3
+ !default.mode1v3
+ *.mode2v3
+ !default.mode2v3
+ *.perspectivev3
+ !default.perspectivev3
+
+ ## Gcc Patch
+ /*.gcno
+
+ ### Linux template
+
+ # temporary files which can be created if a process still has a handle open of a deleted file
+ .fuse_hidden*
+
+ # KDE directory preferences
+ .directory
+
+ # Linux trash folder which might appear on any partition or disk
+ .Trash-*
+
+ # .nfs files are created when an open file is removed but is still being accessed
+ .nfs*
+
+ ### GitBook template
+ # Node rules:
+ ## Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
+
+ ## Dependency directory
+ ## Commenting this out is preferred by some people, see
+ ## https://docs.npmjs.com/misc/faq#should-i-check-my-node_modules-folder-into-git
+ node_modules
+
+ # Book build output
+ _book
+
+ # eBook build output
+ *.epub
+ *.mobi
+
+ ### CodeSniffer template
+ # gitignore for the PHP Codesniffer framework
+ # website: https://github.com/squizlabs/PHP_CodeSniffer
+ #
+ # Recommended template: PHP.gitignore
+
+ /wpcs/*
+
+ ### PuTTY template
+ # Private key
+ *.ppk
+ *_pb2.py
+ *.pth
+ *.pth.tar
+ *.pt
+ *.npy
+ __pycache__
+ *.egg-info
+ experiment_config.yaml
+ api-log/
+ log/
+ htmlcov
+ *.lock
+ .coverage*
+ /test_*
+ .python-version
+ /name.txt
+ /summary_log
+ policy_*
+ /data
+ .vscode
+ formatted_*
+ **/exp
+ **/benchmark
+ **/model_zoo
+ *ckpt*
+ log*
+ *.puml.png
+ *.puml.eps
+ *.puml.svg
+ default*
+ events.*
+
+ # DI-engine special key
+ *default_logger.txt
+ *default_tb_logger
+ *evaluate.txt
+ *total_config.py
+ eval_config.py
+ collect_demo_data_config.py
+ !ding/**/*.py
+ events.*
+
+ evogym/*
DI-engine/.style.yapf ADDED
@@ -0,0 +1,11 @@
+ [style]
+ # For explanation and more information: https://github.com/google/yapf
+ BASED_ON_STYLE=pep8
+ DEDENT_CLOSING_BRACKETS=True
+ SPLIT_BEFORE_FIRST_ARGUMENT=True
+ ALLOW_SPLIT_BEFORE_DICT_VALUE=False
+ JOIN_MULTIPLE_LINES=False
+ COLUMN_LIMIT=120
+ BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True
+ BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2
+ SPACES_AROUND_POWER_OPERATOR=True
DI-engine/CHANGELOG ADDED
@@ -0,0 +1,489 @@
+ 2023.11.06(v0.5.0)
+ - env: add tabmwp env (#667)
+ - env: polish anytrading env issues (#731)
+ - algo: add PromptPG algorithm (#667)
+ - algo: add Plan Diffuser algorithm (#700)
+ - algo: add new pipeline implementation of IMPALA algorithm (#713)
+ - algo: add dropout layers to DQN-style algorithms (#712)
+ - feature: add new pipeline agent for sac/ddpg/a2c/ppo and Hugging Face support (#637) (#730) (#737)
+ - feature: add more unittest cases for model (#728)
+ - feature: add collector logging in new pipeline (#735)
+ - fix: logger middleware problems (#715)
+ - fix: ppo parallel bug (#709)
+ - fix: typo in optimizer_helper.py (#726)
+ - fix: mlp dropout if condition bug
+ - fix: drex collecting data unittest bugs
+ - style: polish env manager/wrapper comments and API doc (#742)
+ - style: polish model comments and API doc (#722) (#729) (#734) (#736) (#741)
+ - style: polish policy comments and API doc (#732)
+ - style: polish rl_utils comments and API doc (#724)
+ - style: polish torch_utils comments and API doc (#738)
+ - style: update README.md and Colab demo (#733)
+ - style: update metaworld docker image
+
+ 2023.08.23(v0.4.9)
+ - env: add cliffwalking env (#677)
+ - env: add lunarlander ppo config and example
+ - algo: add BCQ offline RL algorithm (#640)
+ - algo: add Dreamerv3 model-based RL algorithm (#652)
+ - algo: add tensor stream merge network tools (#673)
+ - algo: add scatter connection model (#680)
+ - algo: refactor Decision Transformer in new pipeline and support img input and discrete output (#693)
+ - algo: add three variants of Bilinear classes and a FiLM class (#703)
+ - feature: polish offpolicy RL multi-gpu DDP training (#679)
+ - feature: add middleware for Ape-X distributed pipeline (#696)
+ - feature: add example for evaluating trained DQN (#706)
+ - fix: to_ndarray fails to assign dtype for scalars (#708)
+ - fix: evaluator return episode_info compatibility bug
+ - fix: cql example entry wrong config bug
+ - fix: enable_save_figure env interface
+ - fix: redundant env info bug in evaluator
+ - fix: to_item unittest bug
+ - style: polish and simplify requirements (#672)
+ - style: add Hugging Face Model Zoo badge (#674)
+ - style: add openxlab Model Zoo badge (#675)
+ - style: fix py37 macos ci bug and update default pytorch from 1.7.1 to 1.12.1 (#678)
+ - style: fix mujoco-py compatibility issue for cython<3 (#711)
+ - style: fix type spell error (#704)
+ - style: fix pypi release actions ubuntu 18.04 bug
+ - style: update contact information (e.g. wechat)
+ - style: polish algorithm doc tables
+
+ 2023.05.25(v0.4.8)
+ - env: fix gym hybrid reward dtype bug (#664)
+ - env: fix atari env id noframeskip bug (#655)
+ - env: fix typo in gym any_trading env (#654)
+ - env: update td3bc d4rl config (#659)
+ - env: polish bipedalwalker config
+ - algo: add EDAC offline RL algorithm (#639)
+ - algo: add LN and GN norm_type support in ResBlock (#660)
+ - algo: add normal value norm baseline for PPOF (#658)
+ - algo: polish last layer init/norm in MLP (#650)
+ - algo: polish TD3 monitor variable
+ - feature: add MAPPO/MASAC task example (#661)
+ - feature: add PPO example for complex env observation (#644)
+ - feature: add barrier middleware (#570)
+ - fix: abnormal collector log and add record_random_collect option (#662)
+ - fix: to_item compatibility bug (#646)
+ - fix: trainer dtype transform compatibility bug
+ - fix: pettingzoo 1.23.0 compatibility bug
+ - fix: ensemble head unittest bug
+ - style: fix incompatible gym version bug in Dockerfile.env (#653)
+ - style: add more algorithm docs
+
+ 2023.04.11(v0.4.7)
+ - env: add dmc2gym env support and baseline (#451)
+ - env: update pettingzoo to the latest version (#597)
+ - env: polish icm/rnd+onppo config bugs and add app_door_to_key env (#564)
+ - env: add lunarlander continuous TD3/SAC config
+ - env: polish lunarlander discrete C51 config
+ - algo: add Procedure Cloning (PC) imitation learning algorithm (#514)
+ - algo: add Munchausen Reinforcement Learning (MDQN) algorithm (#590)
+ - algo: add reward/value norm methods: popart & value rescale & symlog (#605)
+ - algo: polish reward model config and training pipeline (#624)
+ - algo: add PPOF reward space demo support (#608)
+ - algo: add PPOF Atari demo support (#589)
+ - algo: polish dqn default config and env examples (#611)
+ - algo: polish comment and clean code about SAC
+ - feature: add language model (e.g. GPT) training utils (#625)
+ - feature: remove policy cfg sub fields requirements (#620)
+ - feature: add full wandb support (#579)
+ - fix: confusing shallow copy operation about next_obs (#641)
+ - fix: unsqueeze action_args in PDQN when shape is 1 (#599)
+ - fix: evaluator return_info tensor type bug (#592)
+ - fix: deque buffer wrapper PER bug (#586)
+ - fix: reward model save method compatibility bug
+ - fix: logger assertion and unittest bug
+ - fix: bfs test py3.9 compatibility bug
+ - fix: zergling collector unittest bug
+ - style: add DI-engine torch-rpc p2p communication docker (#628)
+ - style: add D4RL docker (#591)
+ - style: correct typo in task (#617)
+ - style: correct typo in time_helper (#602)
+ - style: polish readme and add treetensor example
+ - style: update contributing doc
+
+ 2023.02.16(v0.4.6)
+ - env: add metadrive env and related ppo config (#574)
+ - env: add acrobot env and related dqn config (#577)
+ - env: add carracing in box2d (#575)
+ - env: add new gym hybrid viz (#563)
+ - env: update cartpole IL config (#578)
+ - algo: add BDQ algorithm (#558)
+ - algo: add procedure cloning model (#573)
+ - feature: add simplified PPOF (PPO × Family) interface (#567) (#568) (#581) (#582)
+ - fix: to_device and prev_state bug when using ttorch (#571)
+ - fix: py38 and numpy unittest bugs (#565)
+ - fix: typo in contrastive_loss.py (#572)
+ - fix: dizoo envs pkg installation bugs
+ - fix: multi_trainer middleware unittest bug
+ - style: add evogym docker (#580)
+ - style: fix metaworld docker bug
+ - style: fix setuptools high version incompatibility bug
+ - style: extend treetensor lowest version
+
+ 2022.12.13(v0.4.5)
+ - env: add beergame supply chain optimization env (#512)
+ - env: add env gym_pybullet_drones (#526)
+ - env: rename eval reward to episode return (#536)
+ - algo: add policy gradient algo implementation (#544)
+ - algo: add MADDPG algo implementation (#550)
+ - algo: add IMPALA continuous algo implementation (#551)
+ - algo: add MADQN algo implementation (#540)
+ - feature: add new task IMPALA-type distributed training scheme (#321)
+ - feature: add load and save method for replaybuffer (#542)
+ - feature: add more DingEnvWrapper example (#525)
+ - feature: add evaluator more info viz support (#538)
+ - feature: add traceback log for subprocess env manager (#534)
+ - fix: halfcheetah td3 config file (#537)
+ - fix: mujoco action_clip args compatibility bug (#535)
+ - fix: atari a2c config entry bug
+ - fix: drex unittest compatibility bug
+ - style: add Roadmap issue of DI-engine (#548)
+ - style: update related project link and new env doc
+
+ 2022.10.31(v0.4.4)
+ - env: add modified gym-hybrid including moving, sliding and hardmove (#505) (#519)
+ - env: add evogym support (#495) (#527)
+ - env: add save_replay_gif option (#506)
+ - env: adapt minigrid_env and related config to latest MiniGrid v2.0.0 (#500)
+ - algo: add pcgrad optimizer (#489)
+ - algo: add some features in MLP and ResBlock (#511)
+ - algo: delete mcts related modules (#518)
+ - feature: add wandb middleware and demo (#488) (#523) (#528)
+ - feature: add new properties in Context (#499)
+ - feature: add single env policy wrapper for policy deployment
+ - feature: add custom model demo and doc
+ - fix: build logger args and unittests (#522)
+ - fix: total_loss calculation in PDQN (#504)
+ - fix: save gif function bug
+ - fix: level sample unittest bug
+ - style: update contact email address (#503)
+ - style: polish env log and resblock name
+ - style: add details button in readme
+
+ 2022.09.23(v0.4.3)
+ - env: add rule-based gomoku expert (#465)
+ - algo: fix a2c policy batch size bug (#481)
+ - algo: enable activation option in collaq attention and mixer
+ - algo: minor fix about IBC (#477)
+ - feature: add IGM support (#486)
+ - feature: add tb logger middleware and demo
+ - fix: the type conversion in ding_env_wrapper (#483)
+ - fix: di-orchestrator version bug in unittest (#479)
+ - fix: data collection errors caused by shallow copies (#475)
+ - fix: gym==0.26.0 seed args bug
+ - style: add readme tutorial link (environment & algorithm) (#490) (#493)
+ - style: adjust location of the default_model method in policy (#453)
+
+ 2022.09.08(v0.4.2)
+ - env: add rocket env (#449)
+ - env: updated pettingzoo env and improved related performance (#457)
+ - env: add mario env demo (#443)
+ - env: add MAPPO multi-agent config (#464)
+ - env: add mountain car (discrete action) environment (#452)
+ - env: fix multi-agent mujoco gym compatibility bug
+ - env: fix gfootball env save_replay variable init bug
+ - algo: add IBC (Implicit Behaviour Cloning) algorithm (#401)
+ - algo: add BCO (Behaviour Cloning from Observation) algorithm (#270)
+ - algo: add continuous PPOPG algorithm (#414)
+ - algo: add PER in CollaQ (#472)
+ - algo: add activation option in QMIX and CollaQ
+ - feature: update ctx to dataclass (#467)
+ - fix: base_env FinalMeta bug about gym 0.25.0-0.25.1
+ - fix: config inplace modification bug
+ - fix: ding cli no argument problem
+ - fix: import errors after running setup.py (jinja2, markupsafe)
+ - fix: conda py3.6 and cross platform build bug
+ - style: add project state and datetime in log dir (#455)
+ - style: polish notes for q-learning model (#427)
+ - style: revision to mujoco dockerfile and validation (#474)
+ - style: add dockerfile for cityflow env
+ - style: polish default output log format
+
+ 2022.08.12(v0.4.1)
+ - env: add gym trading env (#424)
+ - env: add board games env (tictactoe, gomoku, chess) (#356)
+ - env: add sokoban env (#397) (#429)
+ - env: add BC and DQN demo for gfootball (#418) (#423)
+ - env: add discrete pendulum env (#395)
+ - algo: add STEVE model-based algorithm (#363)
+ - algo: add PLR algorithm (#408)
+ - algo: plugin ST-DIM in PPO (#379)
+ - feature: add final result saving in training pipeline
+ - fix: random policy randomness bug
+ - fix: action_space seed compatibility bug
+ - fix: discard message sent by self in redis mq (#354)
+ - fix: remove pace controller (#400)
+ - fix: import error in serial_pipeline_trex (#410)
+ - fix: unittest hang and fail bug (#413)
+ - fix: DREX collect data unittest bug
+ - fix: remove unused import cv2
+ - fix: ding CLI env/policy option bug
+ - style: upgrade Python version from 3.6-3.8 to 3.7-3.9
+ - style: upgrade gym version from 0.20.0 to 0.25.0
+ - style: upgrade torch version from 1.10.0 to 1.12.0
+ - style: upgrade mujoco bin from 2.0.0 to 2.1.0
+ - style: add buffer api description (#371)
+ - style: polish VAE comments (#404)
+ - style: unittest for FQF (#412)
+ - style: add metaworld dockerfile (#432)
+ - style: remove opencv requirement in default setting
+ - style: update long description in setup.py
+
+ 2022.06.21(v0.4.0)
+ - env: add MAPPO/MASAC all configs in SMAC (#310) **(SOTA results in SMAC!!!)**
+ - env: add dmc2gym env (#344) (#360)
+ - env: remove DI-star requirements of dizoo/smac, use official pysc2 (#302)
+ - env: add latest GAIL mujoco config (#298)
+ - env: polish procgen env (#311)
+ - env: add MBPO ant and humanoid config for mbpo (#314)
+ - env: fix slime volley env obs space bug when agent_vs_agent
+ - env: fix smac env obs space bug
+ - env: fix import path error in lunarlander (#362)
+ - algo: add Decision Transformer algorithm (#327) (#364)
+ - algo: add on-policy PPG algorithm (#312)
+ - algo: add DDPPO & add model-based SAC with lambda-return algorithm (#332)
+ - algo: add infoNCE loss and ST-DIM algorithm (#326)
+ - algo: add FQF distributional RL algorithm (#274)
+ - algo: add continuous BC algorithm (#318)
+ - algo: add pure policy gradient PPO algorithm (#382)
+ - algo: add SQIL + SAC algorithm (#348)
+ - algo: polish NGU and related modules (#283) (#343) (#353)
+ - algo: add marl distributional td loss (#331)
+ - feature: add new worker middleware (#236)
+ - feature: refactor model-based RL pipeline (ding/world_model) (#332)
+ - feature: refactor logging system in the whole DI-engine (#316)
+ - feature: add env supervisor design (#330)
+ - feature: support async reset for envpool env manager (#250)
+ - feature: add log videos to tensorboard (#320)
+ - feature: refactor impala cnn encoder interface (#378)
+ - fix: env save replay bug
+ - fix: transformer mask inplace operation bug
+ - fix: transition_with_policy_data bug in SAC and PPG
+ - style: add dockerfile for ding:hpc image (#337)
+ - style: fix mpire 2.3.5 which handles default processes more elegantly (#306)
+ - style: use FORMAT_DIR instead of ./ding (#309)
+ - style: update quickstart colab link (#347)
+ - style: polish comments in ding/model/common (#315)
+ - style: update mujoco docker download path (#386)
+ - style: fix protobuf new version compatibility bug
+ - style: fix torch1.8.0 torch.div compatibility bug
+ - style: update doc links in readme
+ - style: add outline in readme and update wechat image
+ - style: update head image and refactor docker dir
+
+ 2022.04.23(v0.3.1)
+ - env: polish and standardize dizoo config (#252) (#255) (#249) (#246) (#262) (#261) (#266) (#273) (#263) (#280) (#259) (#286) (#277) (#290) (#289) (#299)
+ - env: add GRF academic env and config (#281)
+ - env: update env interface of GRF (#258)
+ - env: update D4RL offline RL env and config (#285)
+ - env: polish PomdpAtariEnv (#254)
+ - algo: DREX algorithm (#218)
+ - feature: separate mq and parallel modules, add redis (#247)
+ - feature: rename env variables; fix attach_to parameter (#244)
+ - feature: env implementation check (#275)
+ - feature: adjust and set the max column number of tabulate in log (#296)
+ - feature: add drop_extra option for sample collect
+ - feature: speed up GTrXL forward method + GRU unittest (#253) (#292)
+ - fix: add act_scale in DingEnvWrapper; fix envpool env manager (#245)
+ - fix: auto_reset=False and env_ref bug in env manager (#248)
+ - fix: data type and deepcopy bug in RND (#288)
+ - fix: share_memory bug and multi_mujoco env (#279)
+ - fix: some bugs in GTrXL (#276)
+ - fix: update gym_vector_env_manager and add more unittest (#241)
+ - fix: mdpolicy random collect bug (#293)
+ - fix: gym.wrapper save video replay bug
+ - fix: collect abnormal step format bug and add unittest
+ - test: add buffer benchmark & socket test (#284)
+ - style: upgrade mpire (#251)
+ - style: add GRF (google research football) docker (#256)
+ - style: update policy and gail comment
+
+ 2022.03.24(v0.3.0)
+ - env: add bitflip HER DQN benchmark (#192) (#193) (#197)
+ - env: slime volley league training demo (#229)
+ - algo: Gated Transformer-XL (GTrXL) algorithm (#136)
+ - algo: TD3 + VAE (HyAR) latent action algorithm (#152)
+ - algo: stochastic dueling network (#234)
+ - algo: use log prob instead of using prob in ACER (#186)
+ - feature: support envpool env manager (#228)
+ - feature: add league main and other improvements in new framework (#177) (#214)
+ - feature: add pace controller middleware in new framework (#198)
+ - feature: add auto recover option in new framework (#242)
+ - feature: add k8s parser in new framework (#243)
+ - feature: support async event handler and logger (#213)
+ - feature: add grad norm calculator (#205)
+ - feature: add gym vector env manager (#147)
+ - feature: add train_iter and env_step in serial pipeline (#212)
+ - feature: add rich logger handler (#219) (#223) (#232)
+ - feature: add naive lr_scheduler demo
+ - refactor: new BaseEnv and DingEnvWrapper (#171) (#231) (#240)
+ - polish: MAPPO and MASAC smac config (#209) (#239)
+ - polish: QMIX smac config (#175)
+ - polish: R2D2 atari config (#181)
+ - polish: A2C atari config (#189)
+ - polish: GAIL box2d and mujoco config (#188)
+ - polish: ACER atari config (#180)
+ - polish: SQIL atari config (#230)
+ - polish: TREX atari/mujoco config
+ - polish: IMPALA atari config
+ - polish: MBPO/D4PG mujoco config
+ - fix: random_collect compatible to episode collector (#190)
+ - fix: remove default n_sample/n_episode value in policy config (#185)
+ - fix: PDQN model bug on gpu device (#220)
+ - fix: TREX algorithm CLI bug (#182)
+ - fix: DQfD JE computation bug and move to AdamW optimizer (#191)
+ - fix: pytest problem for parallel middleware (#211)
+ - fix: mujoco numpy compatibility bug
+ - fix: markupsafe 2.1.0 bug
+ - fix: framework parallel module network emit bug
+ - fix: mpire bug and disable algotest in py3.8
+ - fix: lunarlander env import and env_id bug
+ - fix: icm unittest repeat name bug
+ - fix: buffer thruput close bug
+ - test: resnet unittest (#199)
+ - test: SAC/SQN unittest (#207)
+ - test: CQL/R2D3/GAIL unittest (#201)
+ - test: NGU td unittest (#210)
+ - test: model wrapper unittest (#215)
+ - test: MAQAC model unittest (#226)
+ - style: add doc docker (#221)
+
+ 2022.01.01(v0.2.3)
+ - env: add multi-agent mujoco env (#146)
+ - env: add delay reward mujoco env (#145)
+ - env: fix port conflict in gym_soccer (#139)
+ - algo: MASAC algorithm (#112)
+ - algo: TREX algorithm (#119) (#144)
+ - algo: H-PPO hybrid action space algorithm (#140)
+ - algo: residual link in R2D2 (#150)
+ - algo: gumbel softmax (#169)
+ - algo: move actor_head_type to action_space field
+ - feature: new main pipeline and async/parallel framework (#142) (#166) (#168)
+ - feature: refactor buffer, separate algorithm and storage (#129)
+ - feature: cli in new pipeline (ditask) (#160)
+ - feature: add multiprocess tblogger, fix circular reference problem (#156)
+ - feature: add multiple seed cli
+ - feature: polish eps_greedy_multinomial_sample in model_wrapper (#154)
+ - fix: R2D3 abs priority problem (#158) (#161)
+ - fix: multi-discrete action space policies random action bug (#167)
+ - fix: doc generate bug with enum_tools (#155)
+ - style: more comments about R2D2 (#149)
+ - style: add doc about how to migrate a new env
+ - style: add doc about env tutorial in dizoo
+ - style: add conda auto release (#148)
+ - style: update zh doc link
+ - style: update kaggle tutorial link
+
+ 2021.12.03(v0.2.2)
+ - env: apple key to door treasure env (#128)
+ - env: add bsuite memory benchmark (#138)
+ - env: polish atari impala config
+ - algo: Guided Cost IRL algorithm (#57)
+ - algo: ICM exploration algorithm (#41)
+ - algo: MP-DQN hybrid action space algorithm (#131)
+ - algo: add loss statistics and polish r2d3 pong config (#126)
+ - feature: add renew env mechanism in env manager and update timeout mechanism (#127) (#134)
+ - fix: async subprocess env manager reset bug (#137)
+ - fix: keepdims name bug in model wrapper
+ - fix: on-policy ppo value norm bug
+ - fix: GAE and RND unittest bug
+ - fix: hidden state wrapper h tensor compatibility
+ - fix: naive buffer auto config create bug
+ - style: add supporters list
+
+ 2021.11.22(v0.2.1)
+ - env: gym-hybrid env (#86)
+ - env: gym-soccer (HFO) env (#94)
+ - env: Go-Bigger env baseline (#95)
+ - env: add the bipedalwalker config of sac and ppo (#121)
+ - algo: DQfD Imitation Learning algorithm (#48) (#98)
+ - algo: TD3BC offline RL algorithm (#88)
+ - algo: MBPO model-based RL algorithm (#113)
+ - algo: PADDPG hybrid action space algorithm (#109)
+ - algo: PDQN hybrid action space algorithm (#118)
+ - algo: fix R2D2 bugs and produce benchmark, add naive NGU (#40)
+ - algo: self-play training demo in slime_volley env (#23)
+ - algo: add example of GAIL entry + config for mujoco (#114)
+ - feature: enable arbitrary policy num in serial sample collector
+ - feature: add torch DataParallel for single machine multi-GPU
+ - feature: add registry force_overwrite argument
+ - feature: add naive buffer periodic thruput seconds argument
+ - test: add pure docker setting test (#103)
+ - test: add unittest for dataset and evaluator (#107)
+ - test: add unittest for on-policy algorithm (#92)
+ - test: add unittest for ppo and td (MARL case) (#89)
+ - test: polish collector benchmark test
+ - fix: target model wrapper hard reset bug
+ - fix: learn state_dict target model bug
+ - fix: ppo bugs and update atari ppo offpolicy config (#108)
+ - fix: pyyaml version bug (#99)
+ - fix: small fix on bsuite environment (#117)
+ - fix: discrete cql unittest bug
+ - fix: release workflow bug
+ - fix: base policy model state_dict overlap bug
+ - fix: remove on_policy option in dizoo config and entry
+ - fix: remove torch in env
+ - style: gym version > 0.20.0
+ - style: torch version >= 1.1.0, <= 1.10.0
+ - style: ale-py == 0.7.0
+
+ 2021.9.30(v0.2.0)
+ - env: overcooked env (#20)
+ - env: procgen env (#26)
+ - env: modified predator env (#30)
+ - env: d4rl env (#37)
+ - env: imagenet dataset (#27)
+ - env: bsuite env (#58)
+ - env: move atari_py to ale-py
+ - algo: SQIL algorithm (#25) (#44)
+ - algo: CQL algorithm (discrete/continuous) (#37) (#68)
442
+ - algo: MAPPO algorithm (#62)
443
+ - algo: WQMIX algorithm (#24)
444
+ - algo: D4PG algorithm (#76)
445
+ - algo: update multi discrete policy(dqn, ppo, rainbow) (#51) (#72)
446
+ - feature: image classification training pipeline (#27)
447
+ - feature: add force_reproducibility option in subprocess env manager
448
+ - feature: add/delete/restart replicas via cli for k8s
449
+ - feautre: add league metric (trueskill and elo) (#22)
450
+ - feature: add tb in naive buffer and modify tb in advanced buffer (#39)
451
+ - feature: add k8s launcher and di-orchestrator launcher, add related unittest (#45) (#49)
452
+ - feature: add hyper-parameter scheduler module (#38)
453
+ - feautre: add plot function (#59)
454
+ - fix: acer bug and update atari result (#21)
455
+ - fix: mappo nan bug and dict obs cannot unsqueeze bug (#54)
456
+ - fix: r2d2 hidden state and obs arange bug (#36) (#52)
457
+ - fix: ppo bug when use dual_clip and adv > 0
458
+ - fix: qmix double_q hidden state bug
459
+ - fix: spawn context problem in interaction unittest (#69)
460
+ - fix: formatted config no eval bug (#53)
461
+ - fix: the catch statments that will never succeed and system proxy bug (#71) (#79)
462
+ - fix: lunarlander config
463
+ - fix: c51 head dimension mismatch bug
464
+ - fix: mujoco config typo bug
465
+ - fix: ppg atari config bug
466
+ - fix: max use and priority update special branch bug in advanced_buffer
467
+ - style: add docker deploy in github workflow (#70) (#78) (#80)
468
+ - style: support PyTorch 1.9.0
469
+ - style: add algo/env list in README
470
+ - style: rename advanced_buffer register name to advanced
471
+
472
+
473
+ 2021.8.3(v0.1.1)
474
+ - env: selfplay/league demo (#12)
475
+ - env: pybullet env (#16)
476
+ - env: minigrid env (#13)
477
+ - env: atari enduro config (#11)
478
+ - algo: on policy PPO (#9)
479
+ - algo: ACER algorithm (#14)
480
+ - feature: polish experiment directory structure (#10)
481
+ - refactor: split doc to new repo (#4)
482
+ - fix: atari env info action space bug
483
+ - fix: env manager retry wrapper raise exception info bug
484
+ - fix: dist entry disable-flask-log typo
485
+ - style: codestyle optimization by lgtm (#7)
486
+ - style: code/comment statistics badge
487
+ - style: github CI workflow
488
+
489
+ 2021.7.8(v0.1.0)
DI-engine/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
+ # Contributor Covenant Code of Conduct
+ 
+ ## Our Pledge
+ 
+ We as members, contributors, and leaders pledge to make participation in our
+ community a harassment-free experience for everyone, regardless of age, body
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
+ identity and expression, level of experience, education, socio-economic status,
+ nationality, personal appearance, race, religion, or sexual identity
+ and orientation.
+ 
+ We pledge to act and interact in ways that contribute to an open, welcoming,
+ diverse, inclusive, and healthy community.
+ 
+ ## Our Standards
+ 
+ Examples of behavior that contributes to a positive environment for our
+ community include:
+ 
+ * Demonstrating empathy and kindness toward other people
+ * Being respectful of differing opinions, viewpoints, and experiences
+ * Giving and gracefully accepting constructive feedback
+ * Accepting responsibility and apologizing to those affected by our mistakes,
+   and learning from the experience
+ * Focusing on what is best not just for us as individuals, but for the
+   overall community
+ 
+ Examples of unacceptable behavior include:
+ 
+ * The use of sexualized language or imagery, and sexual attention or
+   advances of any kind
+ * Trolling, insulting or derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or email
+   address, without their explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+ 
+ ## Enforcement Responsibilities
+ 
+ Community leaders are responsible for clarifying and enforcing our standards of
+ acceptable behavior and will take appropriate and fair corrective action in
+ response to any behavior that they deem inappropriate, threatening, offensive,
+ or harmful.
+ 
+ Community leaders have the right and responsibility to remove, edit, or reject
+ comments, commits, code, wiki edits, issues, and other contributions that are
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
+ decisions when appropriate.
+ 
+ ## Scope
+ 
+ This Code of Conduct applies within all community spaces, and also applies when
+ an individual is officially representing the community in public spaces.
+ Examples of representing our community include using an official e-mail address,
+ posting via an official social media account, or acting as an appointed
+ representative at an online or offline event.
+ 
+ ## Enforcement
+ 
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported to the community leaders responsible for enforcement at
+ opendilab.contact@gmail.com.
+ All complaints will be reviewed and investigated promptly and fairly.
+ 
+ All community leaders are obligated to respect the privacy and security of the
+ reporter of any incident.
+ 
+ ## Enforcement Guidelines
+ 
+ Community leaders will follow these Community Impact Guidelines in determining
+ the consequences for any action they deem in violation of this Code of Conduct:
+ 
+ ### 1. Correction
+ 
+ **Community Impact**: Use of inappropriate language or other behavior deemed
+ unprofessional or unwelcome in the community.
+ 
+ **Consequence**: A private, written warning from community leaders, providing
+ clarity around the nature of the violation and an explanation of why the
+ behavior was inappropriate. A public apology may be requested.
+ 
+ ### 2. Warning
+ 
+ **Community Impact**: A violation through a single incident or series
+ of actions.
+ 
+ **Consequence**: A warning with consequences for continued behavior. No
+ interaction with the people involved, including unsolicited interaction with
+ those enforcing the Code of Conduct, for a specified period of time. This
+ includes avoiding interactions in community spaces as well as external channels
+ like social media. Violating these terms may lead to a temporary or
+ permanent ban.
+ 
+ ### 3. Temporary Ban
+ 
+ **Community Impact**: A serious violation of community standards, including
+ sustained inappropriate behavior.
+ 
+ **Consequence**: A temporary ban from any sort of interaction or public
+ communication with the community for a specified period of time. No public or
+ private interaction with the people involved, including unsolicited interaction
+ with those enforcing the Code of Conduct, is allowed during this period.
+ Violating these terms may lead to a permanent ban.
+ 
+ ### 4. Permanent Ban
+ 
+ **Community Impact**: Demonstrating a pattern of violation of community
+ standards, including sustained inappropriate behavior, harassment of an
+ individual, or aggression toward or disparagement of classes of individuals.
+ 
+ **Consequence**: A permanent ban from any sort of public interaction within
+ the community.
+ 
+ ## Attribution
+ 
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+ version 2.0, available at
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+ 
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
+ enforcement ladder](https://github.com/mozilla/diversity).
+ 
+ [homepage]: https://www.contributor-covenant.org
+ 
+ For answers to common questions about this code of conduct, see the FAQ at
+ https://www.contributor-covenant.org/faq. Translations are available at
+ https://www.contributor-covenant.org/translations.
DI-engine/CONTRIBUTING.md ADDED
@@ -0,0 +1,7 @@
+ [Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
+ 
+ [GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html)
+ 
+ - [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
+ - [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
+ - [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review)
DI-engine/LICENSE ADDED
@@ -0,0 +1,202 @@
+ 
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+ 
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 
+ 1. Definitions.
+ 
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+ 
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ 
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+ 
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+ 
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+ 
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+ 
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+ 
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+ 
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+ 
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+ 
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+ 
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+ 
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+ 
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+ 
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+ 
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+ 
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+ 
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+ 
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+ 
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+ 
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+ 
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+ 
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+ 
+ END OF TERMS AND CONDITIONS
+ 
+ APPENDIX: How to apply the Apache License to your work.
+ 
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+ 
+ Copyright 2017 Google Inc.
+ 
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ 
+ http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
DI-engine/Makefile ADDED
@@ -0,0 +1,71 @@
+ CI ?=
+ 
+ # Directory variables
+ DING_DIR ?= ./ding
+ DIZOO_DIR ?= ./dizoo
+ RANGE_DIR ?=
+ TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+ COV_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+ FORMAT_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+ PLATFORM_TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR}/entry/tests/test_serial_entry.py ${DING_DIR}/entry/tests/test_serial_entry_onpolicy.py)
+ 
+ # Workers command
+ WORKERS ?= 2
+ WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
+ 
+ # Duration command
+ DURATIONS ?= 10
+ DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},)
+ 
+ docs:
+ 	$(MAKE) -C ${DING_DIR}/docs html
+ 
+ unittest:
+ 	pytest ${TEST_DIR} \
+ 		--cov-report=xml \
+ 		--cov-report term-missing \
+ 		--cov=${COV_DIR} \
+ 		${DURATIONS_COMMAND} \
+ 		${WORKERS_COMMAND} \
+ 		-sv -m unittest
+ 
+ algotest:
+ 	pytest ${TEST_DIR} \
+ 		${DURATIONS_COMMAND} \
+ 		-sv -m algotest
+ 
+ cudatest:
+ 	pytest ${TEST_DIR} \
+ 		-sv -m cudatest
+ 
+ envpooltest:
+ 	pytest ${TEST_DIR} \
+ 		-sv -m envpooltest
+ 
+ dockertest:
+ 	${DING_DIR}/scripts/docker-test-entry.sh
+ 
+ platformtest:
+ 	pytest ${TEST_DIR} \
+ 		--cov-report term-missing \
+ 		--cov=${COV_DIR} \
+ 		${WORKERS_COMMAND} \
+ 		-sv -m platformtest
+ 
+ benchmark:
+ 	pytest ${TEST_DIR} \
+ 		--durations=0 \
+ 		-sv -m benchmark
+ 
+ test: unittest # just for compatibility, can be changed later
+ 
+ cpu_test: unittest algotest benchmark
+ 
+ all_test: unittest algotest cudatest benchmark
+ 
+ format:
+ 	yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR}
+ format_test:
+ 	bash format.sh ${FORMAT_DIR} --test
+ flake_check:
+ 	flake8 ${FORMAT_DIR}
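Since `TEST_DIR`, `COV_DIR` and `FORMAT_DIR` above all fall back to `RANGE_DIR` when it is set, a single variable scopes testing, coverage and formatting to one subtree, and `WORKERS` controls the pytest-xdist worker count. A minimal usage sketch (the subtree path and worker count are illustrative, not prescribed by the Makefile):

```bash
# Run unit tests for one subtree with 4 pytest-xdist workers,
# collecting coverage for that same subtree.
make unittest RANGE_DIR=./ding/utils WORKERS=4

# Re-format the same subtree in place with yapf.
make format RANGE_DIR=./ding/utils
```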
DI-engine/README.md ADDED
@@ -0,0 +1,475 @@
+ <div align="center">
+ <a href="https://di-engine-docs.readthedocs.io/en/latest/"><img width="1000px" height="auto" src="https://github.com/opendilab/DI-engine-docs/blob/main/source/images/head_image.png"></a>
+ </div>
+ 
+ ---
+ 
+ [![Twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Ftwitter.com%2Fopendilab)](https://twitter.com/opendilab)
+ [![PyPI](https://img.shields.io/pypi/v/DI-engine)](https://pypi.org/project/DI-engine/)
+ ![Conda](https://anaconda.org/opendilab/di-engine/badges/version.svg)
+ ![Conda update](https://anaconda.org/opendilab/di-engine/badges/latest_release_date.svg)
+ ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/DI-engine)
+ ![PyTorch Version](https://img.shields.io/badge/dynamic/json?color=blue&label=pytorch&query=%24.pytorchVersion&url=https%3A%2F%2Fgist.githubusercontent.com/PaParaZz1/54c5c44eeb94734e276b2ed5770eba8d/raw/85b94a54933a9369f8843cc2cea3546152a75661/badges.json)
+ 
+ ![Loc](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/loc.json)
+ ![Comments](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/comments.json)
+ 
+ ![Style](https://github.com/opendilab/DI-engine/actions/workflows/style.yml/badge.svg)
+ [![Read en Docs](https://github.com/opendilab/DI-engine/actions/workflows/doc.yml/badge.svg)](https://di-engine-docs.readthedocs.io/en/latest)
+ [![Read zh_CN Docs](https://img.shields.io/readthedocs/di-engine-docs?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://di-engine-docs.readthedocs.io/zh_CN/latest)
+ ![Unittest](https://github.com/opendilab/DI-engine/actions/workflows/unit_test.yml/badge.svg)
+ ![Algotest](https://github.com/opendilab/DI-engine/actions/workflows/algo_test.yml/badge.svg)
+ ![deploy](https://github.com/opendilab/DI-engine/actions/workflows/deploy.yml/badge.svg)
+ [![codecov](https://codecov.io/gh/opendilab/DI-engine/branch/main/graph/badge.svg?token=B0Q15JI301)](https://codecov.io/gh/opendilab/DI-engine)
+ 
+ ![GitHub Org's stars](https://img.shields.io/github/stars/opendilab)
+ [![GitHub stars](https://img.shields.io/github/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
+ [![GitHub forks](https://img.shields.io/github/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network)
+ ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/opendilab/DI-engine)
+ [![GitHub issues](https://img.shields.io/github/issues/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/issues)
+ [![GitHub pulls](https://img.shields.io/github/issues-pr/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/pulls)
+ [![Contributors](https://img.shields.io/github/contributors/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/graphs/contributors)
+ [![GitHub license](https://img.shields.io/github/license/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/blob/master/LICENSE)
+ [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/OpenDILabCommunity)
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models?search=opendilab)
+ 
+ Updated on 2023.12.05 DI-engine-v0.5.0
+ 
+ 
+ ## Introduction to DI-engine
+ [Documentation](https://di-engine-docs.readthedocs.io/en/latest/) | [Documentation (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/) | [Tutorials](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/index.html) | [Feature](#feature) | [Task & Middleware](https://di-engine-docs.readthedocs.io/en/latest/03_system/index.html) | [TreeTensor](#general-data-container-treetensor) | [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
+ 
+ **DI-engine** is a generalized decision intelligence engine for PyTorch and JAX.
+ 
+ It provides **python-first** and **asynchronous-native** task and middleware abstractions, and modularly integrates several of the most important decision-making concepts: Env, Policy and Model. Based on these mechanisms, DI-engine supports **various [deep reinforcement learning](https://di-engine-docs.readthedocs.io/en/latest/10_concepts/index.html) algorithms** with superior performance, high efficiency, well-organized [documentation](https://di-engine-docs.readthedocs.io/en/latest/) and [unit tests](https://github.com/opendilab/DI-engine/actions):
+ 
+ - Most basic DRL algorithms: such as DQN, Rainbow, PPO, TD3, SAC, R2D2, IMPALA
+ - Multi-agent RL algorithms: such as QMIX, WQMIX, MAPPO, HAPPO, ACE
+ - Imitation learning algorithms (BC/IRL/GAIL): such as GAIL, SQIL, Guided Cost Learning, Implicit BC
+ - Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
+ - Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3, MuZero
+ - Exploration algorithms: HER, RND, ICM, NGU
+ - LLM + RL algorithms: PPO-max, DPO, MPDPO
+ - Other algorithms: such as PER, PLR, PCGrad
+ 
+ **DI-engine** aims to **standardize different Decision Intelligence environments and applications**, supporting both academic research and prototype applications. Various training pipelines and customized decision AI applications are also supported:
+ 
+ <details open>
+ <summary>(Click to Collapse)</summary>
+ 
+ - Traditional academic environments
+   - [DI-zoo](https://github.com/opendilab/DI-engine#environment-versatility): various decision intelligence demonstrations and benchmark environments with DI-engine.
+ - Tutorial courses
+   - [PPOxFamily](https://github.com/opendilab/PPOxFamily): PPO x Family DRL Tutorial Course
+ - Real world decision AI applications
+   - [DI-star](https://github.com/opendilab/DI-star): Decision AI in StarCraftII
+   - [DI-drive](https://github.com/opendilab/DI-drive): Auto-driving platform
+   - [DI-sheep](https://github.com/opendilab/DI-sheep): Decision AI in 3 Tiles Game
+   - [DI-smartcross](https://github.com/opendilab/DI-smartcross): Decision AI in Traffic Light Control
+   - [DI-bioseq](https://github.com/opendilab/DI-bioseq): Decision AI in Biological Sequence Prediction and Searching
+   - [DI-1024](https://github.com/opendilab/DI-1024): Deep Reinforcement Learning + 1024 Game
+ - Research paper
+   - [InterFuser](https://github.com/opendilab/InterFuser): [CoRL 2022] Safety-Enhanced Autonomous Driving Using Interpretable Sensor Fusion Transformer
+   - [ACE](https://github.com/opendilab/ACE): [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
+   - [GoBigger](https://github.com/opendilab/GoBigger): [ICLR 2023] Multi-Agent Decision Intelligence Environment
+   - [DOS](https://github.com/opendilab/DOS): [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
+   - [LightZero](https://github.com/opendilab/LightZero): [NeurIPS 2023 Spotlight] A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
+   - [SO2](https://github.com/opendilab/SO2): [AAAI 2024] A Perspective of Q-value Estimation on Offline-to-Online Reinforcement Learning
+   - [LMDrive](https://github.com/opendilab/LMDrive): LMDrive: Closed-Loop End-to-End Driving with Large Language Models
+ - Docs and Tutorials
+   - [DI-engine-docs](https://github.com/opendilab/DI-engine-docs): Tutorials, best practices and the API reference.
+   - [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources
+   - [awesome-exploration-RL](https://github.com/opendilab/awesome-exploration-rl): A curated list of awesome exploration RL resources
+   - [awesome-decision-transformer](https://github.com/opendilab/awesome-decision-transformer): A curated list of Decision Transformer resources
+   - [awesome-RLHF](https://github.com/opendilab/awesome-RLHF): A curated list of reinforcement learning with human feedback resources
+   - [awesome-multi-modal-reinforcement-learning](https://github.com/opendilab/awesome-multi-modal-reinforcement-learning): A curated list of Multi-Modal Reinforcement Learning resources
+   - [awesome-AI-based-protein-design](https://github.com/opendilab/awesome-AI-based-protein-design): a collection of research papers for AI-based protein design
+   - [awesome-diffusion-model-in-rl](https://github.com/opendilab/awesome-diffusion-model-in-rl): A curated list of Diffusion Model in RL resources
+   - [awesome-end-to-end-autonomous-driving](https://github.com/opendilab/awesome-end-to-end-autonomous-driving): A curated list of awesome End-to-End Autonomous Driving resources
+   - [awesome-driving-behavior-prediction](https://github.com/opendilab/awesome-driving-behavior-prediction): A collection of research papers for Driving Behavior Prediction
+ </details>
+ 
+ On the low-level end, DI-engine comes with a set of highly reusable modules, including [RL optimization functions](https://github.com/opendilab/DI-engine/tree/main/ding/rl_utils), [PyTorch utilities](https://github.com/opendilab/DI-engine/tree/main/ding/torch_utils) and [auxiliary tools](https://github.com/opendilab/DI-engine/tree/main/ding/utils).
+ 
+ Besides, **DI-engine** also has some special **system optimization and design** for efficient and robust large-scale RL training:
+ 
+ <details close>
+ <summary>(Click for Details)</summary>
+ 
+ - [treevalue](https://github.com/opendilab/treevalue): Tree-nested data structure
+ - [DI-treetensor](https://github.com/opendilab/DI-treetensor): Tree-nested PyTorch tensor Lib
+ - [DI-toolkit](https://github.com/opendilab/DI-toolkit): A simple toolkit package for decision intelligence
+ - [DI-orchestrator](https://github.com/opendilab/DI-orchestrator): RL Kubernetes Custom Resource and Operator Lib
+ - [DI-hpc](https://github.com/opendilab/DI-hpc): RL HPC OP Lib
+ - [DI-store](https://github.com/opendilab/DI-store): RL Object Store
+ </details>
+ 
+ Have fun with exploration and exploitation.
+ 
+ ## Outline
+ 
+ - [Introduction to DI-engine](#introduction-to-di-engine)
+ - [Outline](#outline)
+ - [Installation](#installation)
+ - [Quick Start](#quick-start)
+ - [Feature](#feature)
+   - [Algorithm Versatility](#algorithm-versatility)
+   - [Environment Versatility](#environment-versatility)
+   - [General Data Container: TreeTensor](#general-data-container-treetensor)
+ - [Feedback and Contribution](#feedback-and-contribution)
+ - [Supporters](#supporters)
+   - [↳ Stargazers](#-stargazers)
+   - [↳ Forkers](#-forkers)
+ - [Citation](#citation)
+ - [License](#license)
+ 
+ ## Installation
+ 
+ You can simply install DI-engine from PyPI with the following command:
+ ```bash
+ pip install DI-engine
+ ```
+ 
+ If you use Anaconda or Miniconda, you can install DI-engine from the `opendilab` conda channel with the following command:
+ ```bash
+ conda install -c opendilab di-engine
+ ```
+ 
+ For more information about installation, you can refer to [installation](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/installation.html).
+ 
+ Our DockerHub repo can be found [here](https://hub.docker.com/repository/docker/opendilab/ding); we provide a `base` image and `env` images with common RL environments.
+ 
+ <details close>
+ <summary>(Click for Details)</summary>
+ 
+ - base: opendilab/ding:nightly
+ - rpc: opendilab/ding:nightly-rpc
+ - atari: opendilab/ding:nightly-atari
+ - mujoco: opendilab/ding:nightly-mujoco
+ - dmc: opendilab/ding:nightly-dmc2gym
+ - metaworld: opendilab/ding:nightly-metaworld
+ - smac: opendilab/ding:nightly-smac
+ - grf: opendilab/ding:nightly-grf
+ - cityflow: opendilab/ding:nightly-cityflow
+ - evogym: opendilab/ding:nightly-evogym
+ - d4rl: opendilab/ding:nightly-d4rl
+ </details>
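Any of these images can be pulled and used as a ready-made training environment; a typical session with the standard Docker CLI (image tag taken from the list above) looks like:

```bash
# Pull the base image and open an interactive shell with DI-engine preinstalled
docker pull opendilab/ding:nightly
docker run -it --rm opendilab/ding:nightly /bin/bash
```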
+ 
+ The detailed documentation is hosted at [doc](https://di-engine-docs.readthedocs.io/en/latest/) | [doc (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/).
+ 
+ ## Quick Start
+ 
+ [3 Minutes Kickoff](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/first_rl_program.html)
+ 
+ [3 Minutes Kickoff (colab)](https://colab.research.google.com/drive/1_7L-QFDfeCvMvLJzRyBRUW5_Q6ESXcZ4)
+ 
+ [DI-engine Huggingface Kickoff (colab)](https://colab.research.google.com/drive/1UH1GQOjcHrmNSaW77hnLGxFJrLSLwCOk)
+ 
+ [How to migrate a new **RL Env**](https://di-engine-docs.readthedocs.io/en/latest/11_dizoo/index.html) | [How to migrate a new **RL Env** (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/11_dizoo/index_zh.html)
+ 
+ [How to customize the neural network model](https://di-engine-docs.readthedocs.io/en/latest/04_best_practice/custom_model.html) | [How to customize the **neural network model** used by a policy (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/custom_model_zh.html)
+ 
+ [Example of evaluating/deploying a trained **RL policy**](https://github.com/opendilab/DI-engine/blob/main/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py)
+ 
+ [Differences between the old and new pipelines (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/diff_in_new_pipeline_zh.html)
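As a taste of the happy path behind these tutorials, the sketch below uses the high-level agent API from `ding.bonus` (see `ding/bonus/dqn.py` in this repo); the environment id, experiment name and step budget are illustrative only:

```python
from ding.bonus import DQNAgent

# Minimal sketch, assuming the ding.bonus happy-path API:
# train a DQN agent, then evaluate it and save a video replay.
agent = DQNAgent(env_id="LunarLander-v2", exp_name="lunarlander_dqn")
agent.train(step=400000)
agent.deploy(enable_save_replay=True)
```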
+ 
+ 
+ ## Feature
+ ### Algorithm Versatility
+ 
+ <details open>
+ <summary>(Click to Collapse)</summary>
+ 
+ ![discrete](https://img.shields.io/badge/-discrete-brightgreen) &nbsp;means discrete action space; the action-space labels are the only labels attached to the basic DRL algorithms (No. 1-23)
+ 
+ ![continuous](https://img.shields.io/badge/-continous-green) &nbsp;means continuous action space (No. 1-23)
+ 
+ ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) &nbsp;means hybrid (discrete + continuous) action space (No. 1-23)
+ 
+ ![dist](https://img.shields.io/badge/-distributed-blue) &nbsp;[Distributed Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/distributed_rl.html) | [Distributed Reinforcement Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/distributed_rl_zh.html)
+ 
+ ![MARL](https://img.shields.io/badge/-MARL-yellow) &nbsp;[Multi-Agent Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/multi_agent_cooperation_rl.html) | [Multi-Agent Reinforcement Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/multi_agent_cooperation_rl_zh.html)
+ 
+ ![exp](https://img.shields.io/badge/-exploration-orange) &nbsp;[Exploration Mechanisms in Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/exploration_rl.html) | [Exploration Mechanisms in Reinforcement Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/exploration_rl_zh.html)
+ 
+ ![IL](https://img.shields.io/badge/-IL-purple) &nbsp;[Imitation Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/imitation_learning.html) | [Imitation Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/imitation_learning_zh.html)
+ 
+ ![offline](https://img.shields.io/badge/-offlineRL-darkblue) &nbsp;[Offline Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/offline_rl.html) | [Offline Reinforcement Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/offline_rl_zh.html)
+ 
+ ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) &nbsp;[Model-Based Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/model_based_rl.html) | [Model-Based Reinforcement Learning (Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/model_based_rl_zh.html)
+ 
+ ![other](https://img.shields.io/badge/-other-lightgrey) &nbsp;means algorithms from other sub-directions, usually used as plug-ins in the whole pipeline
+ 
+ P.S.: The `.py` files in the `Runnable Demo` column can be found in `dizoo`
+ 
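The `Runnable Demo` column in the table below uses two interchangeable launch styles: running the demo script directly with `python3`, or invoking the `ding` CLI with a mode (`-m`), a config file (`-c`) and a seed (`-s`). For example, taking the DQN and PPO rows:

```bash
# DQN on CartPole: plain script launch, or the equivalent ding CLI call
python3 -u cartpole_dqn_main.py
ding -m serial -c cartpole_dqn_config.py -s 0

# On-policy PPO uses the serial_onpolicy mode
ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0
```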
210
+ | No. | Algorithm | Label | Doc and Implementation | Runnable Demo |
211
+ | :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
212
+ | 1 | [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [DQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqn.html)<br>[DQN中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/dqn_zh.html)<br>[policy/dqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u cartpole_dqn_main.py / ding -m serial -c cartpole_dqn_config.py -s 0 |
213
+ | 2 | [C51](https://arxiv.org/pdf/1707.06887.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [C51 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/c51.html)<br>[policy/c51](https://github.com/opendilab/DI-engine/blob/main/ding/policy/c51.py) | ding -m serial -c cartpole_c51_config.py -s 0 |
214
+ | 3 | [QRDQN](https://arxiv.org/pdf/1710.10044.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [QRDQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qrdqn.html)<br>[policy/qrdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qrdqn.py) | ding -m serial -c cartpole_qrdqn_config.py -s 0 |
215
+ | 4 | [IQN](https://arxiv.org/pdf/1806.06923.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/iqn.html)<br>[policy/iqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/iqn.py) | ding -m serial -c cartpole_iqn_config.py -s 0 |
216
+ | 5 | [FQF](https://arxiv.org/pdf/1911.02140.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [FQF doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/fqf.html)<br>[policy/fqf](https://github.com/opendilab/DI-engine/blob/main/ding/policy/fqf.py) | ding -m serial -c cartpole_fqf_config.py -s 0 |
217
+ | 6 | [Rainbow](https://arxiv.org/pdf/1710.02298.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [Rainbow doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rainbow.html)<br>[policy/rainbow](https://github.com/opendilab/DI-engine/blob/main/ding/policy/rainbow.py) | ding -m serial -c cartpole_rainbow_config.py -s 0 |
218
+ | 7 | [SQL](https://arxiv.org/pdf/1702.08165.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [SQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sql.html)<br>[policy/sql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sql.py) | ding -m serial -c cartpole_sql_config.py -s 0 |
219
+ | 8 | [R2D2](https://openreview.net/forum?id=r1lyTjAqYX) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [R2D2 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d2.html)<br>[policy/r2d2](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d2.py) | ding -m serial -c cartpole_r2d2_config.py -s 0 |
220
+ | 9 | [PG](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)<br>[policy/pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pg.py) | ding -m serial -c cartpole_pg_config.py -s 0 |
221
+ | 10 | [PromptPG](https://arxiv.org/abs/2209.14610) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_pg.py) | ding -m serial_onpolicy -c tabmwp_pg_config.py -s 0 |
222
+ | 11 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)<br>[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 |
223
+ | 12 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)<br>[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 |
224
+ | 13 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)<br>[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py |
225
+ | 14 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)<br>[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 |
226
+ | 15 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)<br>[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 |
227
+ | 16 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 |
228
+ | 17 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)<br>[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 |
229
+ | 18 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)<br>[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
230
+ | 19 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)<br>[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
231
+ | 20 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
232
+ | 21 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
233
+ | 22 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 |
234
+ | 23 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
235
+ | 24 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py |
236
+ | 25 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)<br>[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
237
+ | 26 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)<br>[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
238
+ | 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
239
+ | 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)<br>[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
240
+ | 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)<br>[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
241
+ | 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
242
+ | 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)<br>[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
243
+ | 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)<br>[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
244
+ | 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)<br>[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
245
+ | 34 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)<br>[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)<br>[policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py |
246
+ | 35 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)<br>[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
247
+ | 36 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)<br>[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
248
+ | 37 | [Implicit Behavorial Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py) <br> [model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py |
249
+ | 38 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py |
250
+ | 39 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)<br>[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
251
+ | 40 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)<br>[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py |
252
+ | 41 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)<br>[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)<br>[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
253
+ | 42 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)<br>[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
254
+ | 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)<br>[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
255
+ | 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u d4rl_dt_mujoco.py |
256
+ | 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)<br>[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
257
+ | 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py<br>python3 -u pendulum_mbsac_ddppo_config.py |
258
+ | 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
259
+ | 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)<br>[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
260
+ | 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
261
+ | 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
262
+ | 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
263
+ | 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
264
+ | 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
265
+ | 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)<br>[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
266
+ | 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
267
+ </details>
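Each table entry maps to a CLI invocation of the form `ding -m <pipeline> -c <config> -s <seed>`. The same serial pipeline can also be launched from Python; below is a minimal sketch, assuming the config module exposes `main_config` and `create_config` the way the bundled `dizoo` configs do (the cartpole config shown is illustrative):

```python
from ding.entry import serial_pipeline
# Illustrative config import; any bundled dizoo config that defines
# main_config/create_config works the same way.
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config

# Roughly equivalent to: ding -m serial -c cartpole_dqn_config.py -s 0
serial_pipeline([main_config, create_config], seed=0)
```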
268
+
269
+
270
+ ### Environment Versatility
271
+ <details open>
272
+ <summary>(Click to Collapse)</summary>
273
+
274
+ | No | Environment | Label | Visualization | Code and Doc Links |
275
+ | :--: | :--------------------------------------: | :---------------------------------: | :--------------------------------:|:---------------------------------------------------------: |
276
+ | 1 | [Atari](https://github.com/openai/gym/tree/master/gym/envs/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/atari/atari.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/atari/envs) <br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) |
277
+ | 2 | [box2d/bipedalwalker](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/bipedalwalker/original.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/bipedalwalker/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/bipedalwalker.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bipedalwalker_zh.html) |
278
+ | 3 | [box2d/lunarlander](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/box2d/lunarlander/lunarlander.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/lunarlander/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/lunarlander.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/lunarlander_zh.html) |
279
+ | 4 | [classic_control/cartpole](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/cartpole/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/cartpole.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/cartpole_zh.html) |
280
+ | 5 | [classic_control/pendulum](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/classic_control/pendulum/pendulum.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/pendulum/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pendulum.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pendulum_zh.html) |
281
+ | 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/competitive_rl/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/competitive_rl_zh.html) |
282
+ | 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gfootball/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball_zh.html) |
283
+ | 8 | [minigrid](https://github.com/maximecb/gym-minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/minigrid/minigrid.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/minigrid/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid_zh.html) |
284
+ | 9 | [MuJoCo](https://github.com/openai/gym/tree/master/gym/envs/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mujoco/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco_zh.html) |
285
+ | 10 | [PettingZoo](https://github.com/Farama-Foundation/PettingZoo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/petting_zoo/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pettingzoo.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pettingzoo_zh.html) |
286
+ | 11 | [overcooked](https://github.com/HumanCompatibleAI/overcooked-demo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/overcooked/overcooked.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/overcooked/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/overcooked.html) |
287
+ | 12 | [procgen](https://github.com/openai/procgen) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/procgen/coinrun.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/procgen)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/procgen.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/procgen_zh.html) |
288
+ | 13 | [pybullet](https://github.com/benelot/pybullet-gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/pybullet/pybullet.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pybullet/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pybullet_zh.html) |
289
+ | 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/smac.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/smac_zh.html) |
290
+ | 15 | [d4rl](https://github.com/rail-berkeley/d4rl) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | ![ori](dizoo/d4rl/d4rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/d4rl)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/d4rl_zh.html) |
291
+ | 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) |
292
+ | 17 | pomdp atari | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pomdp/envs) |
293
+ | 18 | [bsuite](https://github.com/deepmind/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/bsuite/bsuite.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bsuite/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs//bsuite.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bsuite_zh.html) |
294
+ | 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/image_cls_zh.html) |
295
+ | 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/slime_volleyball.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/slime_volleyball_zh.html) |
296
+ | 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_hybrid.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_hybrid_zh.html) |
297
+ | 22 | [GoBigger](https://github.com/opendilab/GoBigger) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen)![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](./dizoo/gobigger_overview.gif) | [dizoo link](https://github.com/opendilab/GoBigger-Challenge-2021/tree/main/di_baseline)<br>[env tutorial](https://gobigger.readthedocs.io/en/latest/index.html)<br>[环境指南](https://gobigger.readthedocs.io/zh_CN/latest/) |
298
+ | 23 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_soccer_zh.html) |
299
+ | 24 |[multiagent_mujoco](https://github.com/schroederdewitt/multiagent_mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_mujoco/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/mujoco_zh.html) |
300
+ | 25 |bitflip | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/bitflip/bitflip.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bitflip/envs)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bitflip_zh.html) |
301
+ | 26 |[sokoban](https://github.com/mpSchrader/gym-sokoban) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![Game 2](https://github.com/mpSchrader/gym-sokoban/raw/default/docs/Animations/solved_4.gif?raw=true) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/sokoban/envs)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/sokoban.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/sokoban_zh.html) |
302
+ | 27 |[gym_anytrading](https://github.com/AminHP/gym-anytrading) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/gym_anytrading/envs/position.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_anytrading) <br> [env tutorial](https://github.com/opendilab/DI-engine/blob/main/dizoo/gym_anytrading/envs/README.md) |
303
+ | 28 |[mario](https://github.com/Kautenja/gym-super-mario-bros) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/mario/mario.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mario) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_super_mario_bros.html) <br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_super_mario_bros_zh.html) |
304
+ | 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym)<br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) |
305
+ | 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/Evogym_zh.html) |
306
+ | 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym_pybullet_drones/gym_pybullet_drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>环境指南 |
307
+ | 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>环境指南 |
308
+ | 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) |
309
+ | 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) <br> ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)<br>环境指南 |
310
+ | 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
311
+ | 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)<br> env tutorial <br> 环境指南 |
312
+ | 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南|
313
+
314
+ ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
315
+
316
+ ![continuous](https://img.shields.io/badge/-continous-green) means continuous action space
317
+
318
+ ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) means hybrid (discrete + continuous) action space
319
+
320
+ ![MARL](https://img.shields.io/badge/-MARL-yellow) means multi-agent RL environment
321
+
322
+ ![sparse](https://img.shields.io/badge/-sparse%20reward-orange) means environment which is related to exploration and sparse reward
323
+
324
+ ![offline](https://img.shields.io/badge/-offlineRL-darkblue) means offline RL environment
325
+
326
+ ![IL](https://img.shields.io/badge/-IL/SL-purple) means Imitation Learning or Supervised Learning Dataset
327
+
328
+ ![selfplay](https://img.shields.io/badge/-selfplay-blue) means environment that allows agent VS agent battle
329
+
330
+ P.S. Some environments in Atari, such as **MontezumaRevenge**, are also of the sparse reward type.
331
+ </details>
332
+
333
+
334
+ ### General Data Container: TreeTensor
335
+
336
+ DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as the basic data container in various components, which is easy to use and consistent across different code modules, such as environment definition, data processing and DRL optimization. Here are some concrete code examples:
337
+
338
+ - TreeTensor can easily extend all the operations of `torch.Tensor` to nested data:
339
+ <details close>
340
+ <summary>(Click for Details)</summary>
341
+
342
+ ```python
343
+ import treetensor.torch as ttorch
344
+
345
+
346
+ # create random tensor
347
+ data = ttorch.randn({'a': (3, 2), 'b': {'c': (3, )}})
348
+ # clone+detach tensor
349
+ data_clone = data.clone().detach()
350
+ # access tree structure like attribute
351
+ a = data.a
352
+ c = data.b.c
353
+ # stack/cat/split
354
+ stacked_data = ttorch.stack([data, data_clone], 0)
355
+ cat_data = ttorch.cat([data, data_clone], 0)
356
+ data, data_clone = ttorch.split(stacked_data, 1)
357
+ # reshape
358
+ data = data.unsqueeze(-1)
359
+ data = data.squeeze(-1)
360
+ flatten_data = data.view(-1)
361
+ # indexing
362
+ data_0 = data[0]
363
+ data_1to2 = data[1:2]
364
+ # execute math calculations
365
+ data = data.sin()
366
+ data.b.c.cos_().clamp_(-1, 1)
367
+ data += data ** 2
368
+ # backward
369
+ data.requires_grad_(True)
370
+ loss = data.arctan().mean()
371
+ loss.backward()
372
+ # print shape
373
+ print(data.shape)
374
+ # result
375
+ # <Size 0x7fbd3346ddc0>
376
+ # ├── 'a' --> torch.Size([1, 3, 2])
377
+ # └── 'b' --> <Size 0x7fbd3346dd00>
378
+ # └── 'c' --> torch.Size([1, 3])
379
+ ```
380
+
381
+ </details>
382
+
383
+ - TreeTensor makes it simple yet effective to implement a classic deep reinforcement learning pipeline:
384
+ <details close>
385
+ <summary>(Click for Details)</summary>
386
+
387
+ ```diff
388
+ import torch
389
+ import treetensor.torch as ttorch
390
+
391
+ B = 4
392
+
393
+
394
+ def get_item():
395
+ return {
396
+ 'obs': {
397
+ 'scalar': torch.randn(12),
398
+ 'image': torch.randn(3, 32, 32),
399
+ },
400
+ 'action': torch.randint(0, 10, size=(1,)),
401
+ 'reward': torch.rand(1),
402
+ 'done': False,
403
+ }
404
+
405
+
406
+ data = [get_item() for _ in range(B)]
407
+
408
+
409
+ # execute `stack` op
410
+ - def stack(data, dim):
411
+ - elem = data[0]
412
+ - if isinstance(elem, torch.Tensor):
413
+ - return torch.stack(data, dim)
414
+ - elif isinstance(elem, dict):
415
+ - return {k: stack([item[k] for item in data], dim) for k in elem.keys()}
416
+ - elif isinstance(elem, bool):
417
+ - return torch.BoolTensor(data)
418
+ - else:
419
+ - raise TypeError("not support elem type: {}".format(type(elem)))
420
+ - stacked_data = stack(data, dim=0)
421
+ + data = [ttorch.tensor(d) for d in data]
422
+ + stacked_data = ttorch.stack(data, dim=0)
423
+
424
+ # validate
425
+ - assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
426
+ - assert stacked_data['action'].shape == (B, 1)
427
+ - assert stacked_data['reward'].shape == (B, 1)
428
+ - assert stacked_data['done'].shape == (B,)
429
+ - assert stacked_data['done'].dtype == torch.bool
430
+ + assert stacked_data.obs.image.shape == (B, 3, 32, 32)
431
+ + assert stacked_data.action.shape == (B, 1)
432
+ + assert stacked_data.reward.shape == (B, 1)
433
+ + assert stacked_data.done.shape == (B,)
434
+ + assert stacked_data.done.dtype == torch.bool
435
+ ```
436
+
437
+ </details>
438
+
439
+ ## Feedback and Contribution
440
+
441
+ - [File an issue](https://github.com/opendilab/DI-engine/issues/new/choose) on Github
442
+ - Open or participate in our [forum](https://github.com/opendilab/DI-engine/discussions)
443
+ - Discuss on DI-engine [slack communication channel](https://join.slack.com/t/opendilab/shared_invite/zt-v9tmv4fp-nUBAQEH1_Kuyu_q4plBssQ)
444
+ - Discuss on DI-engine's WeChat group (i.e. add us on WeChat: ding314assist)
445
+
446
+ <img src=https://github.com/opendilab/DI-engine/blob/main/assets/wechat.jpeg width=35% />
447
+ - Contact our email (opendilab@pjlab.org.cn)
448
+ - Contribute to our future plan: [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
449
+
450
+ We appreciate all feedback and contributions that help improve DI-engine, in both algorithms and system design. `CONTRIBUTING.md` offers the necessary information.
451
+
452
+ ## Supporters
453
+
454
+ ### &#8627; Stargazers
455
+
456
+ [![Stargazers repo roster for @opendilab/DI-engine](https://reporoster.com/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
457
+
458
+ ### &#8627; Forkers
459
+
460
+ [![Forkers repo roster for @opendilab/DI-engine](https://reporoster.com/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network/members)
461
+
462
+
463
+ ## Citation
464
+ ```latex
465
+ @misc{ding,
466
+ title={DI-engine: OpenDILab Decision Intelligence Engine},
467
+ author={OpenDILab Contributors},
468
+ publisher={GitHub},
469
+ howpublished={\url{https://github.com/opendilab/DI-engine}},
470
+ year={2021},
471
+ }
472
+ ```
473
+
474
+ ## License
475
+ DI-engine is released under the Apache 2.0 license.
DI-engine/cloc.sh ADDED
@@ -0,0 +1,69 @@
1
+ #!/bin/bash
2
+
3
+ # This script counts the lines of code and comments in all source files
4
+ # and prints the results to the command line. It uses the commandline tool
5
+ # "cloc". You can either pass --loc, --comments or --percentage to show the
6
+ # respective values only.
7
+ # Some parts below need to be adapted to your project!
8
+
9
+ # Get the location of this script.
10
+ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
11
+
12
+ # Run cloc - this counts code lines, blank lines and comment lines
13
+ # for the specified languages. You will need to change this accordingly.
14
+ # For C++, you could use "C++,C/C++ Header" for example.
15
+ # We are only interested in the summary, therefore the tail -1
16
+ SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -1)"
17
+
18
+ # The $SUMMARY is one line of a markdown table and looks like this:
19
+ # SUM:|101|3123|2238|10783
20
+ # We use the following command to split it into an array.
21
+ IFS='|' read -r -a TOKENS <<< "$SUMMARY"
22
+
23
+ # Store the individual tokens for better readability.
24
+ NUMBER_OF_FILES=${TOKENS[1]}
25
+ COMMENT_LINES=${TOKENS[3]}
26
+ LINES_OF_CODE=${TOKENS[4]}
27
+
28
+ # To make the estimate of commented lines more accurate, we have to
29
+ # subtract any copyright header which is included in each file.
30
+ # In Fly-Pie, the project this script originates from, that header was five lines long.
31
+ # All dumb comments like those /////////// or those // ------------
32
+ # are also subtracted. As cloc does not count inline comments,
33
+ # the overall estimate should be rather conservative.
34
+ # Change the lines below according to your project.
35
+ DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
36
+ COMMENT_LINES=$(($COMMENT_LINES - 5 * $NUMBER_OF_FILES - $DUMB_COMMENTS))
37
+
38
+ # Print all results if no arguments are given.
39
+ if [[ $# -eq 0 ]] ; then
40
+ awk -v a=$LINES_OF_CODE \
41
+ 'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
42
+ awk -v a=$COMMENT_LINES \
43
+ 'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
44
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
45
+ 'BEGIN {printf "Comment Percentage: %6.1f%\n", 100*a/b}'
46
+ exit 0
47
+ fi
48
+
49
+ # Show lines of code if --loc is given.
50
+ if [[ $* == *--loc* ]]
51
+ then
52
+ awk -v a=$LINES_OF_CODE \
53
+ 'BEGIN {printf "%.1fk\n", a/1000}'
54
+ fi
55
+
56
+ # Show lines of comments if --comments is given.
57
+ if [[ $* == *--comments* ]]
58
+ then
59
+ awk -v a=$COMMENT_LINES \
60
+ 'BEGIN {printf "%.1fk\n", a/1000}'
61
+ fi
62
+
63
+ # Show percentage of comments if --percentage is given.
64
+ if [[ $* == *--percentage* ]]
65
+ then
66
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
67
+ 'BEGIN {printf "%.1f\n", 100*a/b}'
68
+ fi
69
+
DI-engine/codecov.yml ADDED
@@ -0,0 +1,8 @@
1
+ coverage:
2
+ status:
3
+ project:
4
+ default:
5
+ # basic
6
+ target: auto
7
+ threshold: 0.5%
8
+ if_ci_failed: success #success, failure, error, ignore
DI-engine/conda/conda_build_config.yaml ADDED
@@ -0,0 +1,2 @@
1
+ python:
2
+ - 3.7
DI-engine/conda/meta.yaml ADDED
@@ -0,0 +1,35 @@
1
+ {% set data = load_setup_py_data() %}
2
+ package:
3
+ name: di-engine
4
+ version: v0.5.0
5
+
6
+ source:
7
+ path: ..
8
+
9
+ build:
10
+ number: 0
11
+ script: python -m pip install . -vv
12
+ entry_points:
13
+ - ding = ding.entry.cli:cli
14
+
15
+ requirements:
16
+ build:
17
+ - python
18
+ - setuptools
19
+ run:
20
+ - python
21
+
22
+ test:
23
+ imports:
24
+ - ding
25
+ - dizoo
26
+
27
+ about:
28
+ home: https://github.com/opendilab/DI-engine
29
+ license: Apache-2.0
30
+ license_file: LICENSE
31
+ summary: DI-engine is a generalized Decision Intelligence engine (https://github.com/opendilab/DI-engine).
32
+ description: Please refer to https://di-engine-docs.readthedocs.io/en/latest/00_intro/index.html#what-is-di-engine
33
+ dev_url: https://github.com/opendilab/DI-engine
34
+ doc_url: https://di-engine-docs.readthedocs.io/en/latest/index.html
35
+ doc_source_url: https://github.com/opendilab/DI-engine-docs
DI-engine/ding/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+
3
+ __TITLE__ = 'DI-engine'
4
+ __VERSION__ = 'v0.5.0'
5
+ __DESCRIPTION__ = 'Decision AI Engine'
6
+ __AUTHOR__ = "OpenDILab Contributors"
7
+ __AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn"
8
+ __version__ = __VERSION__
9
+
10
+ enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true'
11
+ enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
12
+ enable_numba = True
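These flags are plain import-time toggles: each environment variable is read once, when `ding` is first imported. A minimal sketch of opting in (the variable names come straight from the module above; setting them like this is an assumption of typical usage, not a documented API):

```python
import os

# Must be set before `ding` is imported, since the flags are evaluated at import time.
os.environ['ENABLE_DI_HPC'] = 'true'     # turns on ding.enable_hpc_rl
os.environ['ENABLE_LINKLINK'] = 'true'   # turns on ding.enable_linklink

import ding
assert ding.enable_hpc_rl and ding.enable_linklink
```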
DI-engine/ding/bonus/__init__.py ADDED
@@ -0,0 +1,132 @@
1
+ import ding.config
2
+ from .a2c import A2CAgent
3
+ from .c51 import C51Agent
4
+ from .ddpg import DDPGAgent
5
+ from .dqn import DQNAgent
6
+ from .pg import PGAgent
7
+ from .ppof import PPOF
8
+ from .ppo_offpolicy import PPOOffPolicyAgent
9
+ from .sac import SACAgent
10
+ from .sql import SQLAgent
11
+ from .td3 import TD3Agent
12
+
13
+ supported_algo = dict(
14
+ A2C=A2CAgent,
15
+ C51=C51Agent,
16
+ DDPG=DDPGAgent,
17
+ DQN=DQNAgent,
18
+ PG=PGAgent,
19
+ PPOF=PPOF,
20
+ PPOOffPolicy=PPOOffPolicyAgent,
21
+ SAC=SACAgent,
22
+ SQL=SQLAgent,
23
+ TD3=TD3Agent,
24
+ )
25
+
26
+ supported_algo_list = list(supported_algo.keys())
27
+
28
+
29
+ def env_supported(algo: str = None) -> list:
30
+ """
31
+ Return the list of envs supported by DI-engine; if ``algo`` is given, only envs supported by that algorithm are returned.
32
+ """
33
+
34
+ if algo is not None:
35
+ if algo.upper() == "A2C":
36
+ return list(ding.config.example.A2C.supported_env.keys())
37
+ elif algo.upper() == "C51":
38
+ return list(ding.config.example.C51.supported_env.keys())
39
+ elif algo.upper() == "DDPG":
40
+ return list(ding.config.example.DDPG.supported_env.keys())
41
+ elif algo.upper() == "DQN":
42
+ return list(ding.config.example.DQN.supported_env.keys())
43
+ elif algo.upper() == "PG":
44
+ return list(ding.config.example.PG.supported_env.keys())
45
+ elif algo.upper() == "PPOF":
46
+ return list(ding.config.example.PPOF.supported_env.keys())
47
+ elif algo.upper() == "PPOOFFPOLICY":
48
+ return list(ding.config.example.PPOOffPolicy.supported_env.keys())
49
+ elif algo.upper() == "SAC":
50
+ return list(ding.config.example.SAC.supported_env.keys())
51
+ elif algo.upper() == "SQL":
52
+ return list(ding.config.example.SQL.supported_env.keys())
53
+ elif algo.upper() == "TD3":
54
+ return list(ding.config.example.TD3.supported_env.keys())
55
+ else:
56
+ raise ValueError("The algo {} is not supported by di-engine.".format(algo))
57
+ else:
58
+ supported_env = set()
59
+ supported_env.update(ding.config.example.A2C.supported_env.keys())
60
+ supported_env.update(ding.config.example.C51.supported_env.keys())
61
+ supported_env.update(ding.config.example.DDPG.supported_env.keys())
62
+ supported_env.update(ding.config.example.DQN.supported_env.keys())
63
+ supported_env.update(ding.config.example.PG.supported_env.keys())
64
+ supported_env.update(ding.config.example.PPOF.supported_env.keys())
65
+ supported_env.update(ding.config.example.PPOOffPolicy.supported_env.keys())
66
+ supported_env.update(ding.config.example.SAC.supported_env.keys())
67
+ supported_env.update(ding.config.example.SQL.supported_env.keys())
68
+ supported_env.update(ding.config.example.TD3.supported_env.keys())
69
+ # return the list of the envs
70
+ return list(supported_env)
71
+
72
+
73
+ supported_env = env_supported()
74
+
75
+
76
+ def algo_supported(env_id: str = None) -> list:
77
+ """
78
+ Return the list of algorithms supported by DI-engine; if ``env_id`` is given, only algorithms that support that env are returned.
79
+ """
80
+ if env_id is not None:
81
+ algo = []
82
+ if env_id.upper() in [item.upper() for item in ding.config.example.A2C.supported_env.keys()]:
83
+ algo.append("A2C")
84
+ if env_id.upper() in [item.upper() for item in ding.config.example.C51.supported_env.keys()]:
85
+ algo.append("C51")
86
+ if env_id.upper() in [item.upper() for item in ding.config.example.DDPG.supported_env.keys()]:
87
+ algo.append("DDPG")
88
+ if env_id.upper() in [item.upper() for item in ding.config.example.DQN.supported_env.keys()]:
89
+ algo.append("DQN")
90
+ if env_id.upper() in [item.upper() for item in ding.config.example.PG.supported_env.keys()]:
91
+ algo.append("PG")
92
+ if env_id.upper() in [item.upper() for item in ding.config.example.PPOF.supported_env.keys()]:
93
+ algo.append("PPOF")
94
+ if env_id.upper() in [item.upper() for item in ding.config.example.PPOOffPolicy.supported_env.keys()]:
95
+ algo.append("PPOOffPolicy")
96
+ if env_id.upper() in [item.upper() for item in ding.config.example.SAC.supported_env.keys()]:
97
+ algo.append("SAC")
98
+ if env_id.upper() in [item.upper() for item in ding.config.example.SQL.supported_env.keys()]:
99
+ algo.append("SQL")
100
+ if env_id.upper() in [item.upper() for item in ding.config.example.TD3.supported_env.keys()]:
101
+ algo.append("TD3")
102
+
103
+ if len(algo) == 0:
104
+ raise ValueError("The env {} is not supported by di-engine.".format(env_id))
105
+ return algo
106
+ else:
107
+ return supported_algo_list
108
+
109
+
110
+ def is_supported(env_id: str = None, algo: str = None) -> bool:
111
+ """
112
+ Check if the env-algo pair is supported by di-engine.
113
+ """
114
+ # ``supported_env`` is a list, so membership is checked on it directly (the original ``.keys()`` call would raise AttributeError)
+ if env_id is not None and env_id.upper() in [item.upper() for item in supported_env]:
115
+ if algo is not None and algo.upper() in supported_algo_list:
116
+ if env_id.upper() in env_supported(algo):
117
+ return True
118
+ else:
119
+ return False
120
+ elif algo is None:
121
+ return True
122
+ else:
123
+ return False
124
+ elif env_id is None:
125
+ if algo is not None and algo.upper() in supported_algo_list:
126
+ return True
127
+ elif algo is None:
128
+ raise ValueError("Please specify the env or algo.")
129
+ else:
130
+ return False
131
+ else:
132
+ return False
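Together, these helpers let a user query the algorithm-environment support matrix before constructing an agent. A minimal sketch (the env id `LunarLander-v2` is an illustrative assumption; inspect `supported_env` for the actual list):

```python
from ding.bonus import supported_algo_list, env_supported, algo_supported, is_supported

print(supported_algo_list)                       # all bundled algorithms, e.g. ['A2C', 'C51', ...]
print(env_supported(algo='DQN'))                 # envs that ship a DQN config
print(algo_supported(env_id='LunarLander-v2'))   # algorithms that ship a config for this env
print(is_supported(env_id='LunarLander-v2', algo='DQN'))  # True only for a bundled pair
```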
DI-engine/ding/bonus/a2c.py ADDED
@@ -0,0 +1,460 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import treetensor.torch as ttorch
8
+ from ding.framework import task, OnlineRLContext
9
+ from ding.framework.middleware import CkptSaver, trainer, \
10
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
11
+ gae_estimator, final_ctx_saver
12
+ from ding.envs import BaseEnv
13
+ from ding.envs import setup_ding_env_manager
14
+ from ding.policy import A2CPolicy
15
+ from ding.utils import set_pkg_seed
16
+ from ding.utils import get_env_fps, render
17
+ from ding.config import save_config_py, compile_config
18
+ from ding.model import VAC
19
+ from ding.model import model_wrap
20
+ from ding.bonus.common import TrainingReturn, EvalReturn
21
+ from ding.config.example.A2C import supported_env_cfg
22
+ from ding.config.example.A2C import supported_env
23
+
24
+
25
+ class A2CAgent:
26
+ """
27
+ Overview:
28
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
29
+ Advantage Actor Critic(A2C).
30
+ For more information about the system design of RL agent, please refer to \
31
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
32
+ Interface:
33
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34
+ """
35
+ supported_env_list = list(supported_env_cfg.keys())
36
+ """
37
+ Overview:
38
+ List of supported envs.
39
+ Examples:
40
+ >>> from ding.bonus.a2c import A2CAgent
41
+ >>> print(A2CAgent.supported_env_list)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ env_id: str = None,
47
+ env: BaseEnv = None,
48
+ seed: int = 0,
49
+ exp_name: str = None,
50
+ model: Optional[torch.nn.Module] = None,
51
+ cfg: Optional[Union[EasyDict, dict]] = None,
52
+ policy_state_dict: str = None,
53
+ ) -> None:
54
+ """
55
+ Overview:
56
+ Initialize agent for A2C algorithm.
57
+ Arguments:
58
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
67
+ Default to 0.
68
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70
+ - model (:obj:`torch.nn.Module`): The model of A2C algorithm, which should be an instance of class \
71
+ :class:`ding.model.VAC`. \
72
+ If not specified, a default model will be generated according to the configuration.
73
+ - cfg (:obj:Union[EasyDict, dict]): The configuration of A2C algorithm, which is a dict. \
74
+ Default to None. If not specified, the default configuration will be used. \
75
+ The default configuration can be found in ``ding/config/example/A2C/gym_lunarlander_v2.py``.
76
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77
+ If specified, the policy will be loaded from this file. Default to None.
78
+
79
+ .. note::
80
+ An RL Agent Instance can be initialized in two basic ways. \
81
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
82
+ and we want to train an agent with A2C algorithm with default configuration. \
83
+ Then we can initialize the agent in the following ways:
84
+ >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
85
+ or, if we want, we can specify the env_id in the configuration:
86
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
87
+ >>> agent = A2CAgent(cfg=cfg)
88
+ There are also other arguments to specify the agent when initializing.
89
+ For example, if we want to specify the environment instance:
90
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
91
+ >>> agent = A2CAgent(cfg=cfg, env=env)
92
+ or, if we want to specify the model:
93
+ >>> model = VAC(**cfg.policy.model)
94
+ >>> agent = A2CAgent(cfg=cfg, model=model)
95
+ or, if we want to reload the policy from a saved policy state dict:
96
+ >>> agent = A2CAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
97
+ Make sure that the configuration is consistent with the saved policy state dict.
98
+ """
99
+
100
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101
+
102
+ if cfg is not None and not isinstance(cfg, EasyDict):
103
+ cfg = EasyDict(cfg)
104
+
105
+ if env_id is not None:
106
+ assert env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
107
+ A2CAgent.supported_env_list
108
+ )
109
+ if cfg is None:
110
+ cfg = supported_env_cfg[env_id]
111
+ else:
112
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113
+ else:
114
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115
+ assert cfg.env.env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
116
+ A2CAgent.supported_env_list
117
+ )
118
+ default_policy_config = EasyDict({"policy": A2CPolicy.default_config()})
119
+ default_policy_config.update(cfg)
120
+ cfg = default_policy_config
121
+
122
+ if exp_name is not None:
123
+ cfg.exp_name = exp_name
124
+ self.cfg = compile_config(cfg, policy=A2CPolicy)
125
+ self.exp_name = self.cfg.exp_name
126
+ if env is None:
127
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128
+ else:
129
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130
+ self.env = env
131
+
132
+ logging.getLogger().setLevel(logging.INFO)
133
+ self.seed = seed
134
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135
+ if not os.path.exists(self.exp_name):
136
+ os.makedirs(self.exp_name)
137
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
138
+ if model is None:
139
+ model = VAC(**self.cfg.policy.model)
140
+ self.policy = A2CPolicy(self.cfg.policy, model=model)
141
+ if policy_state_dict is not None:
142
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
143
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
144
+
145
+ def train(
146
+ self,
147
+ step: int = int(1e7),
148
+ collector_env_num: int = 4,
149
+ evaluator_env_num: int = 4,
150
+ n_iter_log_show: int = 500,
151
+ n_iter_save_ckpt: int = 1000,
152
+ context: Optional[str] = None,
153
+ debug: bool = False,
154
+ wandb_sweep: bool = False,
155
+ ) -> TrainingReturn:
156
+ """
157
+ Overview:
158
+ Train the agent with A2C algorithm for ``step`` iterations with ``collector_env_num`` collector \
159
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160
+ recorded and saved by wandb.
161
+ Arguments:
162
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163
+ - collector_env_num (:obj:`int`): The collector environment number. Default to 4. \
164
+ If not specified, this default value will be used.
165
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to 4. \
166
+ If not specified, this default value will be used.
167
+ - n_iter_save_ckpt (:obj:`int`): How often, in training iterations, a checkpoint is saved. \
168
+ Default to 1000.
169
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
173
+ subprocess environment manager will be used.
174
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175
+ which is a hyper-parameter optimization process for seeking the best configurations. \
176
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
177
+ Returns:
178
+ - (:obj:`TrainingReturn`): The training result, of which the attributions are:
179
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
180
+ """
181
+
182
+ if debug:
183
+ logging.getLogger().setLevel(logging.DEBUG)
184
+ logging.debug(self.policy._model)
185
+ # define env and policy
186
+ collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
187
+ evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
188
+
189
+ with task.start(ctx=OnlineRLContext()):
190
+ task.use(
191
+ interaction_evaluator(
192
+ self.cfg,
193
+ self.policy.eval_mode,
194
+ evaluator_env,
195
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
196
+ )
197
+ )
198
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
199
+ task.use(
200
+ StepCollector(
201
+ self.cfg,
202
+ self.policy.collect_mode,
203
+ collector_env,
204
+ random_collect_size=self.cfg.policy.random_collect_size
205
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
206
+ )
207
+ )
208
+ task.use(gae_estimator(self.cfg, self.policy.collect_mode))
209
+ task.use(trainer(self.cfg, self.policy.learn_mode))
210
+ task.use(
211
+ wandb_online_logger(
212
+ metric_list=self.policy._monitor_vars_learn(),
213
+ model=self.policy._model,
214
+ anonymous=True,
215
+ project_name=self.exp_name,
216
+ wandb_sweep=wandb_sweep,
217
+ )
218
+ )
219
+ task.use(termination_checker(max_env_step=step))
220
+ task.use(final_ctx_saver(name=self.exp_name))
221
+ task.run()
222
+
223
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
224
+
225
+ def deploy(
226
+ self,
227
+ enable_save_replay: bool = False,
228
+ concatenate_all_replay: bool = False,
229
+ replay_save_path: str = None,
230
+ seed: Optional[Union[int, List]] = None,
231
+ debug: bool = False
232
+ ) -> EvalReturn:
233
+ """
234
+ Overview:
235
+ Deploy the agent with A2C algorithm by interacting with the environment, during which the replay video \
236
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
237
+ Arguments:
238
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
239
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
240
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
241
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
242
+ the replay video of each episode will be saved separately.
243
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
244
+ If not specified, the video will be saved in ``exp_name/videos``.
245
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
246
+ Default to None. If not specified, ``self.seed`` will be used. \
247
+ If ``seed`` is an integer, the agent will be deployed once. \
248
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
249
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
250
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
251
+ subprocess environment manager will be used.
252
+ Returns:
253
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
254
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
255
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
256
+ """
257
+
258
+ if debug:
259
+ logging.getLogger().setLevel(logging.DEBUG)
260
+ # define env and policy
261
+ env = self.env.clone(caller='evaluator')
262
+
263
+ if seed is not None and isinstance(seed, int):
264
+ seeds = [seed]
265
+ elif seed is not None and isinstance(seed, list):
266
+ seeds = seed
267
+ else:
268
+ seeds = [self.seed]
269
+
270
+ returns = []
271
+ images = []
272
+ if enable_save_replay:
273
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
274
+ env.enable_save_replay(replay_path=replay_save_path)
275
+ else:
276
+ logging.warning('No replay video will be generated during the deploy.')
277
+ if concatenate_all_replay:
278
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
279
+ concatenate_all_replay = False
280
+
281
+ def single_env_forward_wrapper(forward_fn, cuda=True):
282
+
283
+ if self.cfg.policy.action_space == 'continuous':
284
+ forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
285
+ elif self.cfg.policy.action_space == 'discrete':
286
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
287
+ else:
288
+ raise NotImplementedError
289
+
290
+ def _forward(obs):
291
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
292
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
293
+ if cuda and torch.cuda.is_available():
294
+ obs = obs.cuda()
295
+ action = forward_fn(obs, mode='compute_actor')["action"]
296
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
297
+ action = action.squeeze(0).detach().cpu().numpy()
298
+ return action
299
+
300
+ return _forward
301
+
302
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
303
+
304
+ # reset first to make sure the env is in the initial state
305
+ # env will be reset again in the main loop
306
+ env.reset()
307
+
308
+ for seed in seeds:
309
+ env.seed(seed, dynamic_seed=False)
310
+ return_ = 0.
311
+ step = 0
312
+ obs = env.reset()
313
+ images.append(render(env)[None]) if concatenate_all_replay else None
314
+ while True:
315
+ action = forward_fn(obs)
316
+ obs, rew, done, info = env.step(action)
317
+ images.append(render(env)[None]) if concatenate_all_replay else None
318
+ return_ += rew
319
+ step += 1
320
+ if done:
321
+ break
322
+ logging.info(f'A2C deploy is finished, final episode return with {step} steps is: {return_}')
323
+ returns.append(return_)
324
+
325
+ env.close()
326
+
327
+ if concatenate_all_replay:
328
+ images = np.concatenate(images, axis=0)
329
+ import imageio
330
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
331
+
332
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
333
+
334
+ def collect_data(
335
+ self,
336
+ env_num: int = 8,
337
+ save_data_path: Optional[str] = None,
338
+ n_sample: Optional[int] = None,
339
+ n_episode: Optional[int] = None,
340
+ context: Optional[str] = None,
341
+ debug: bool = False
342
+ ) -> None:
343
+ """
344
+ Overview:
345
+ Collect data with A2C algorithm for ``n_sample`` samples with ``env_num`` collector environments. \
346
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
347
+ ``exp_name/demo_data``.
348
+ Arguments:
349
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
350
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
351
+ If not specified, the data will be saved in ``exp_name/demo_data``.
352
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
353
+ If not specified, ``n_episode`` must be specified.
354
+ - n_episode (:obj:`int`): The number of episodes to collect (not yet implemented). Default to None. \
355
+ If not specified, ``n_sample`` must be specified.
356
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
357
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
358
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
359
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
360
+ subprocess environment manager will be used.
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         if n_episode is not None:
+             raise NotImplementedError
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.collector_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+ 
+         if save_data_path is None:
+             save_data_path = os.path.join(self.exp_name, 'demo_data')
+ 
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 StepCollector(
+                     self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+                 )
+             )
+             task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+             task.run(max_step=1)
+         logging.info(
+             f'A2C collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+         )
+ 
+     def batch_evaluate(
+             self,
+             env_num: int = 4,
+             n_evaluator_episode: int = 4,
+             context: Optional[str] = None,
+             debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Evaluate the agent with the A2C algorithm for ``n_evaluator_episode`` episodes with ``env_num`` \
+             evaluator environments. The evaluation result will be returned.
+             The difference between the methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` \
+             creates multiple evaluator environments to evaluate the agent and get an average performance, while \
+             ``deploy`` creates only one evaluator environment to evaluate the agent and save the replay video.
+         Arguments:
+             - env_num (:obj:`int`): The number of evaluator environments. Defaults to 4.
+             - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Defaults to 4.
+             - context (:obj:`str`): The multi-process context of the environment manager. Defaults to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
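+         Examples:
+             A minimal usage sketch (the ``env_id`` below is an illustrative assumption):
+             >>> agent = A2CAgent(env_id='LunarLander-v2')
+             >>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
+             >>> print(result.eval_value, result.eval_value_std)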
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+ 
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.launch()
+         env.reset()
+ 
+         evaluate_cfg = self.cfg
+         evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+ 
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+             task.run(max_step=1)
+ 
+         return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+ 
+     @property
+     def best(self) -> 'A2CAgent':
+         """
+         Overview:
+             Load the best model from the checkpoint directory, \
+             which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+             The return value is the agent with the best model.
+         Returns:
+             - (:obj:`A2CAgent`): The agent with the best model.
+         Examples:
+             >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+             >>> agent.train()
+             >>> agent = agent.best
+ 
+         .. note::
+             The best model is the model with the highest evaluation return. If this method is called, the current \
+             model will be replaced by the best model.
+         """
+ 
+         best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+         # Load the best model if it exists
+         if os.path.exists(best_model_file_path):
+             policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         return self
DI-engine/ding/bonus/c51.py ADDED
@@ -0,0 +1,459 @@
+ from typing import Optional, Union, List
+ from ditk import logging
+ from easydict import EasyDict
+ import os
+ import numpy as np
+ import torch
+ import treetensor.torch as ttorch
+ from ding.framework import task, OnlineRLContext
+ from ding.framework.middleware import CkptSaver, \
+     wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+     OffPolicyLearner, final_ctx_saver, eps_greedy_handler, nstep_reward_enhancer
+ from ding.envs import BaseEnv
+ from ding.envs import setup_ding_env_manager
+ from ding.policy import C51Policy
+ from ding.utils import set_pkg_seed
+ from ding.utils import get_env_fps, render
+ from ding.config import save_config_py, compile_config
+ from ding.model import C51DQN
+ from ding.model import model_wrap
+ from ding.data import DequeBuffer
+ from ding.bonus.common import TrainingReturn, EvalReturn
+ from ding.config.example.C51 import supported_env_cfg
+ from ding.config.example.C51 import supported_env
+ 
+ 
+ class C51Agent:
+     """
+     Overview:
+         Agent class for training, evaluation and deployment of the reinforcement learning algorithm C51.
+         For more information about the system design of RL agents, please refer to \
+         <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
+     Interface:
+         ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+     """
+     supported_env_list = list(supported_env_cfg.keys())
+     """
+     Overview:
+         List of supported envs.
+     Examples:
+         >>> from ding.bonus.c51 import C51Agent
+         >>> print(C51Agent.supported_env_list)
+     """
+ 
+     def __init__(
+             self,
+             env_id: str = None,
+             env: BaseEnv = None,
+             seed: int = 0,
+             exp_name: str = None,
+             model: Optional[torch.nn.Module] = None,
+             cfg: Optional[Union[EasyDict, dict]] = None,
+             policy_state_dict: str = None,
+     ) -> None:
+         """
+         Overview:
+             Initialize the agent for the C51 algorithm.
+         Arguments:
+             - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+                 If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+                 If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+                 ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+             - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                 If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified, \
+                 and will be used to create the environment instance. \
+                 If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+             - seed (:obj:`int`): The random seed, which is set before running the program. \
+                 Defaults to 0.
+             - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+                 log data. Defaults to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+             - model (:obj:`torch.nn.Module`): The model of the C51 algorithm, which should be an instance of class \
+                 :class:`ding.model.C51DQN`. \
+                 If not specified, a default model will be generated according to the configuration.
+             - cfg (:obj:`Union[EasyDict, dict]`): The configuration of the C51 algorithm, which is a dict. \
+                 Defaults to None. If not specified, the default configuration will be used. \
+                 The default configuration can be found in ``ding/config/example/C51/gym_lunarlander_v2.py``.
+             - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch to a local file. \
+                 If specified, the policy will be loaded from this file. Defaults to None.
+ 
+         .. note::
+             An RL agent instance can be initialized in two basic ways. \
+             For example, suppose we have an environment with id ``LunarLander-v2`` registered in gym, \
+             and we want to train an agent with the C51 algorithm with the default configuration. \
+             Then we can initialize the agent in the following ways:
+             >>> agent = C51Agent(env_id='LunarLander-v2')
+             or, if we want, we can specify the ``env_id`` in the configuration:
+             >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+             >>> agent = C51Agent(cfg=cfg)
+             There are also other arguments to specify the agent when initializing.
+             For example, if we want to specify the environment instance:
+             >>> env = CustomizedEnv('LunarLander-v2')
+             >>> agent = C51Agent(cfg=cfg, env=env)
+             or, if we want to specify the model:
+             >>> model = C51DQN(**cfg.policy.model)
+             >>> agent = C51Agent(cfg=cfg, model=model)
+             or, if we want to reload the policy from a saved policy state dict:
+             >>> agent = C51Agent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+             Make sure that the configuration is consistent with the saved policy state dict.
+         """
+ 
+         assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+ 
+         if cfg is not None and not isinstance(cfg, EasyDict):
+             cfg = EasyDict(cfg)
+ 
+         if env_id is not None:
+             assert env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
+                 C51Agent.supported_env_list
+             )
+             if cfg is None:
+                 cfg = supported_env_cfg[env_id]
+             else:
+                 assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+         else:
+             assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+             assert cfg.env.env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
+                 C51Agent.supported_env_list
+             )
+         default_policy_config = EasyDict({"policy": C51Policy.default_config()})
+         default_policy_config.update(cfg)
+         cfg = default_policy_config
+ 
+         if exp_name is not None:
+             cfg.exp_name = exp_name
+         self.cfg = compile_config(cfg, policy=C51Policy)
+         self.exp_name = self.cfg.exp_name
+         if env is None:
+             self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+         else:
+             assert isinstance(env, BaseEnv), "Please use BaseEnv as the env data type."
+             self.env = env
+ 
+         logging.getLogger().setLevel(logging.INFO)
+         self.seed = seed
+         set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+         if not os.path.exists(self.exp_name):
+             os.makedirs(self.exp_name)
+         save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+         if model is None:
+             model = C51DQN(**self.cfg.policy.model)
+         self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+         self.policy = C51Policy(self.cfg.policy, model=model)
+         if policy_state_dict is not None:
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+ 
+     def train(
+             self,
+             step: int = int(1e7),
+             collector_env_num: int = None,
+             evaluator_env_num: int = None,
+             n_iter_save_ckpt: int = 1000,
+             context: Optional[str] = None,
+             debug: bool = False,
+             wandb_sweep: bool = False,
+     ) -> TrainingReturn:
+         """
+         Overview:
+             Train the agent with the C51 algorithm for ``step`` environment steps with ``collector_env_num`` \
+             collector environments and ``evaluator_env_num`` evaluator environments. Information during training \
+             will be recorded and saved by wandb.
+         Arguments:
+             - step (:obj:`int`): The total training environment steps of all collector environments. Defaults to 1e7.
+             - collector_env_num (:obj:`int`): The number of collector environments. Defaults to None. \
+                 If not specified, it will be set according to the configuration.
+             - evaluator_env_num (:obj:`int`): The number of evaluator environments. Defaults to None. \
+                 If not specified, it will be set according to the configuration.
+             - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
+                 Defaults to 1000.
+             - context (:obj:`str`): The multi-process context of the environment manager. Defaults to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+             - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+                 which is a hyper-parameter optimization process for seeking the best configuration. \
+                 Defaults to False. If True, the wandb sweep id will be used as the experiment name.
+         Returns:
+             - (:obj:`TrainingReturn`): The training result, whose attributes are:
+                 - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
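+         Examples:
+             A minimal usage sketch (the ``env_id``, step budget and env count are illustrative assumptions):
+             >>> agent = C51Agent(env_id='LunarLander-v2')
+             >>> result = agent.train(step=int(1e6), collector_env_num=4)
+             >>> print(result.wandb_url)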
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+             logging.debug(self.policy._model)
+         # define env and policy
+         collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+         evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+         collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+         evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+ 
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 interaction_evaluator(
+                     self.cfg,
+                     self.policy.eval_mode,
+                     evaluator_env,
+                     render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+                 )
+             )
+             task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+             task.use(eps_greedy_handler(self.cfg))
+             task.use(
+                 StepCollector(
+                     self.cfg,
+                     self.policy.collect_mode,
+                     collector_env,
+                     random_collect_size=self.cfg.policy.random_collect_size
+                     if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+                 )
+             )
+             task.use(nstep_reward_enhancer(self.cfg))
+             task.use(data_pusher(self.cfg, self.buffer_))
+             task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+             task.use(
+                 wandb_online_logger(
+                     metric_list=self.policy._monitor_vars_learn(),
+                     model=self.policy._model,
+                     anonymous=True,
+                     project_name=self.exp_name,
+                     wandb_sweep=wandb_sweep,
+                 )
+             )
+             task.use(termination_checker(max_env_step=step))
+             task.use(final_ctx_saver(name=self.exp_name))
+             task.run()
+ 
+         return TrainingReturn(wandb_url=task.ctx.wandb_url)
+ 
+     def deploy(
+             self,
+             enable_save_replay: bool = False,
+             concatenate_all_replay: bool = False,
+             replay_save_path: str = None,
+             seed: Optional[Union[int, List]] = None,
+             debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Deploy the agent with the C51 algorithm by interacting with the environment, during which the replay \
+             video can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+         Arguments:
+             - enable_save_replay (:obj:`bool`): Whether to save the replay video. Defaults to False.
+             - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+                 Defaults to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+                 If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+                 the replay video of each episode will be saved separately.
+             - replay_save_path (:obj:`str`): The path to save the replay video. Defaults to None. \
+                 If not specified, the video will be saved in ``exp_name/videos``.
+             - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+                 Defaults to None. If not specified, ``self.seed`` will be used. \
+                 If ``seed`` is an integer, the agent will be deployed once. \
+                 If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
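+         Examples:
+             A minimal usage sketch (the ``env_id`` and seed list are illustrative assumptions):
+             >>> agent = C51Agent(env_id='LunarLander-v2')
+             >>> result = agent.deploy(enable_save_replay=True, seed=[0, 1, 2])
+             >>> print(result.eval_value)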
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env = self.env.clone(caller='evaluator')
+ 
+         if seed is not None and isinstance(seed, int):
+             seeds = [seed]
+         elif seed is not None and isinstance(seed, list):
+             seeds = seed
+         else:
+             seeds = [self.seed]
+ 
+         returns = []
+         images = []
+         if enable_save_replay:
+             replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+             env.enable_save_replay(replay_path=replay_save_path)
+         else:
+             logging.warning('No video will be generated during the deploy.')
+             if concatenate_all_replay:
+                 logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+                 concatenate_all_replay = False
+ 
+         def single_env_forward_wrapper(forward_fn, cuda=True):
+ 
+             forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+ 
+             def _forward(obs):
+                 # unsqueeze means adding the batch dim, i.e. (O, ) -> (1, O)
+                 obs = ttorch.as_tensor(obs).unsqueeze(0)
+                 if cuda and torch.cuda.is_available():
+                     obs = obs.cuda()
+                 action = forward_fn(obs)["action"]
+                 # squeeze means deleting the batch dim, i.e. (1, A) -> (A, )
+                 action = action.squeeze(0).detach().cpu().numpy()
+                 return action
+ 
+             return _forward
+ 
+         forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+ 
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.reset()
+ 
+         for seed in seeds:
+             env.seed(seed, dynamic_seed=False)
+             return_ = 0.
+             step = 0
+             obs = env.reset()
+             if concatenate_all_replay:
+                 images.append(render(env)[None])
+             while True:
+                 action = forward_fn(obs)
+                 obs, rew, done, info = env.step(action)
+                 if concatenate_all_replay:
+                     images.append(render(env)[None])
+                 return_ += rew
+                 step += 1
+                 if done:
+                     break
+             logging.info(f'C51 deploy is finished, final episode return over {step} steps is: {return_}')
+             returns.append(return_)
+ 
+         env.close()
+ 
+         if concatenate_all_replay:
+             images = np.concatenate(images, axis=0)
+             import imageio
+             imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+ 
+         return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+ 
+     def collect_data(
+             self,
+             env_num: int = 8,
+             save_data_path: Optional[str] = None,
+             n_sample: Optional[int] = None,
+             n_episode: Optional[int] = None,
+             context: Optional[str] = None,
+             debug: bool = False
+     ) -> None:
+         """
+         Overview:
+             Collect data with the C51 algorithm using ``env_num`` collector environments. \
+             The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+             ``exp_name/demo_data``.
+         Arguments:
+             - env_num (:obj:`int`): The number of collector environments. Defaults to 8.
+             - save_data_path (:obj:`str`): The path to save the collected data. Defaults to None. \
+                 If not specified, the data will be saved in ``exp_name/demo_data``.
+             - n_sample (:obj:`int`): The number of samples to collect. Defaults to None. \
+                 If not specified, ``n_episode`` must be specified.
+             - n_episode (:obj:`int`): The number of episodes to collect. Defaults to None. \
+                 If not specified, ``n_sample`` must be specified. Episode-based collection is not implemented yet.
+             - context (:obj:`str`): The multi-process context of the environment manager. Defaults to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         if n_episode is not None:
+             raise NotImplementedError
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.collector_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+ 
+         if save_data_path is None:
+             save_data_path = os.path.join(self.exp_name, 'demo_data')
+ 
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 StepCollector(
+                     self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+                 )
+             )
+             task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+             task.run(max_step=1)
+         logging.info(
+             f'C51 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+         )
+ 
+     def batch_evaluate(
+             self,
+             env_num: int = 4,
+             n_evaluator_episode: int = 4,
+             context: Optional[str] = None,
+             debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Evaluate the agent with the C51 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` \
+             evaluator environments. The evaluation result will be returned.
+             The difference between the methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` \
+             creates multiple evaluator environments to evaluate the agent and get an average performance, while \
+             ``deploy`` creates only one evaluator environment to evaluate the agent and save the replay video.
+         Arguments:
+             - env_num (:obj:`int`): The number of evaluator environments. Defaults to 4.
+             - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Defaults to 4.
+             - context (:obj:`str`): The multi-process context of the environment manager. Defaults to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+ 
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.launch()
+         env.reset()
+ 
+         evaluate_cfg = self.cfg
+         evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+ 
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+             task.run(max_step=1)
+ 
+         return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+ 
+     @property
+     def best(self) -> 'C51Agent':
+         """
+         Overview:
+             Load the best model from the checkpoint directory, \
+             which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+             The return value is the agent with the best model.
+         Returns:
+             - (:obj:`C51Agent`): The agent with the best model.
+         Examples:
+             >>> agent = C51Agent(env_id='LunarLander-v2')
+             >>> agent.train()
+             >>> agent = agent.best
+ 
+         .. note::
+             The best model is the model with the highest evaluation return. If this method is called, the current \
+             model will be replaced by the best model.
+         """
+ 
+         best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+         # Load the best model if it exists
+         if os.path.exists(best_model_file_path):
+             policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         return self
DI-engine/ding/bonus/common.py ADDED
@@ -0,0 +1,22 @@
+ from dataclasses import dataclass
+ import numpy as np
+ 
+ 
+ @dataclass
+ class TrainingReturn:
+     '''
+     Attributes:
+         wandb_url: The Weights & Biases (wandb) project URL of the training experiment.
+     '''
+     wandb_url: str
+ 
+ 
+ @dataclass
+ class EvalReturn:
+     '''
+     Attributes:
+         eval_value: The mean of evaluation return.
+         eval_value_std: The standard deviation of evaluation return.
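+     Examples:
+         A minimal construction sketch (the values below are illustrative assumptions):
+         >>> ret = EvalReturn(eval_value=np.float32(200.0), eval_value_std=np.float32(10.0))
+         >>> print(ret.eval_value)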
+     '''
+     eval_value: np.float32
+     eval_value_std: np.float32
DI-engine/ding/bonus/config.py ADDED
@@ -0,0 +1,326 @@
+ from easydict import EasyDict
+ import os
+ import gym
+ from ding.envs import BaseEnv, DingEnvWrapper
+ from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
+     EvalEpisodeReturnWrapper, TransposeWrapper, TimeLimitWrapper, FlatObsWrapper, GymToGymnasiumWrapper
+ from ding.policy import PPOFPolicy
+ 
+ 
+ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
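+     """
+     Overview:
+         Return the default policy configuration for the given ``env_id`` and ``algorithm``. \
+         Only ``algorithm='PPOF'`` is handled here; unsupported ids raise ``KeyError``.
+     Examples:
+         A minimal usage sketch (the ``env_id`` below is an illustrative assumption):
+         >>> cfg = get_instance_config('LunarLander-v2', algorithm='PPOF')
+         >>> print(cfg.n_sample)
+     """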
+     if algorithm == 'PPOF':
+         cfg = PPOFPolicy.default_config()
+         if env_id == 'LunarLander-v2':
+             cfg.n_sample = 512
+             cfg.value_norm = 'popart'
+             cfg.entropy_weight = 1e-3
+         elif env_id == 'LunarLanderContinuous-v2':
+             cfg.action_space = 'continuous'
+             cfg.n_sample = 400
+         elif env_id == 'BipedalWalker-v3':
+             cfg.learning_rate = 1e-3
+             cfg.action_space = 'continuous'
+             cfg.n_sample = 1024
+         elif env_id == 'Pendulum-v1':
+             cfg.action_space = 'continuous'
+             cfg.n_sample = 400
+         elif env_id == 'acrobot':
+             cfg.learning_rate = 1e-4
+             cfg.n_sample = 400
+         elif env_id == 'rocket_landing':
+             cfg.n_sample = 2048
+             cfg.adv_norm = False
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+             )
+         elif env_id == 'drone_fly':
+             cfg.action_space = 'continuous'
+             cfg.adv_norm = False
+             cfg.epoch_per_collect = 5
+             cfg.learning_rate = 5e-5
+             cfg.n_sample = 640
+         elif env_id == 'hybrid_moving':
+             cfg.action_space = 'hybrid'
+             cfg.n_sample = 3200
+             cfg.entropy_weight = 0.03
+             cfg.batch_size = 320
+             cfg.adv_norm = False
+             cfg.model = dict(
+                 encoder_hidden_size_list=[256, 128, 64, 64],
+                 sigma_type='fixed',
+                 fixed_sigma_value=0.3,
+                 bound_type='tanh',
+             )
+         elif env_id == 'evogym_carrier':
+             cfg.action_space = 'continuous'
+             cfg.n_sample = 2048
+             cfg.batch_size = 256
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-3
+         elif env_id == 'mario':
+             cfg.n_sample = 256
+             cfg.batch_size = 64
+             cfg.epoch_per_collect = 2
+             cfg.learning_rate = 1e-3
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 64, 128],
+                 critic_head_hidden_size=128,
+                 actor_head_hidden_size=128,
+             )
+         elif env_id == 'di_sheep':
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-4
+             cfg.adv_norm = False
+             cfg.entropy_weight = 0.001
+         elif env_id == 'procgen_bigfish':
+             cfg.n_sample = 16384
+             cfg.batch_size = 16384
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 5e-4
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 128, 256],
+                 critic_head_hidden_size=256,
+                 actor_head_hidden_size=256,
+             )
+         elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
+             cfg.n_sample = 1024
+             cfg.batch_size = 128
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 0.0001
+             cfg.model = dict(
+                 encoder_hidden_size_list=[32, 64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+                 critic_head_layer_num=2,
+             )
+         elif env_id == 'PongNoFrameskip-v4':
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-4
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+             )
+         elif env_id == 'SpaceInvadersNoFrameskip-v4':
+             cfg.n_sample = 320
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 1
+             cfg.learning_rate = 1e-3
+             cfg.entropy_weight = 0.01
+             cfg.lr_scheduler = (2000, 0.1)
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+             )
+         elif env_id == 'QbertNoFrameskip-v4':
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 5e-4
+             cfg.lr_scheduler = (1000, 0.1)
+             cfg.model = dict(
+                 encoder_hidden_size_list=[64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+             )
+         elif env_id == 'minigrid_fourroom':
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.learning_rate = 3e-4
+             cfg.epoch_per_collect = 10
+             cfg.entropy_weight = 0.001
+         elif env_id == 'metadrive':
+             cfg.action_space = 'continuous'
+             cfg.entropy_weight = 0.001
+             cfg.n_sample = 3000
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 0.0001
+             cfg.model = dict(
+                 encoder_hidden_size_list=[32, 64, 64, 128],
+                 actor_head_hidden_size=128,
+                 critic_head_hidden_size=128,
+                 critic_head_layer_num=2,
+             )
+         elif env_id == 'Hopper-v3':
+             cfg.action_space = "continuous"
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-4
+         elif env_id == 'HalfCheetah-v3':
+             cfg.action_space = "continuous"
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-4
+         elif env_id == 'Walker2d-v3':
+             cfg.action_space = "continuous"
+             cfg.n_sample = 3200
+             cfg.batch_size = 320
+             cfg.epoch_per_collect = 10
+             cfg.learning_rate = 3e-4
+         else:
+             raise KeyError("not supported env type: {}".format(env_id))
+     else:
+         raise KeyError("not supported algorithm type: {}".format(algorithm))
+ 
+     return cfg
+ 
+ 
+ def get_instance_env(env_id: str) -> BaseEnv:
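+     """
+     Overview:
+         Instantiate and return a DI-engine environment for the given ``env_id``; unsupported ids \
+         raise ``KeyError``.
+     Examples:
+         A minimal usage sketch (the ``env_id`` below is an illustrative assumption):
+         >>> env = get_instance_env('LunarLander-v2')
+         >>> obs = env.reset()
+     """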
+     if env_id == 'LunarLander-v2':
+         return DingEnvWrapper(gym.make('LunarLander-v2'))
+     elif env_id == 'LunarLanderContinuous-v2':
+         return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
+     elif env_id == 'BipedalWalker-v3':
+         return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
+     elif env_id == 'Pendulum-v1':
+         return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
+     elif env_id == 'acrobot':
+         return DingEnvWrapper(gym.make('Acrobot-v1'))
+     elif env_id == 'rocket_landing':
+         from dizoo.rocket.envs import RocketEnv
+         cfg = EasyDict({
+             'task': 'landing',
+             'max_steps': 800,
+         })
+         return RocketEnv(cfg)
+     elif env_id == 'drone_fly':
+         from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
+         cfg = EasyDict({
+             'env_id': 'flythrugate-aviary-v0',
+             'action_type': 'VEL',
+         })
+         return GymPybulletDronesEnv(cfg)
+     elif env_id == 'hybrid_moving':
+         import gym_hybrid
+         return DingEnvWrapper(gym.make('Moving-v0'))
+     elif env_id == 'evogym_carrier':
+         import evogym.envs
+         from evogym import sample_robot, WorldObject
+         path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
+         robot_object = WorldObject.from_json(path)
+         body = robot_object.get_structure()
+         return DingEnvWrapper(
+             gym.make('Carrier-v0', body=body),
+             cfg={
+                 'env_wrapper': [
+                     lambda env: TimeLimitWrapper(env, max_limit=300),
+                     lambda env: EvalEpisodeReturnWrapper(env),
+                 ]
+             }
+         )
+     elif env_id == 'mario':
+         import gym_super_mario_bros
+         from nes_py.wrappers import JoypadSpace
+         return DingEnvWrapper(
+             JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v1"), [["right"], ["right", "A"]]),
+             cfg={
+                 'env_wrapper': [
+                     lambda env: MaxAndSkipWrapper(env, skip=4),
+                     lambda env: WarpFrameWrapper(env, size=84),
+                     lambda env: ScaledFloatFrameWrapper(env),
+                     lambda env: FrameStackWrapper(env, n_frames=4),
+                     lambda env: TimeLimitWrapper(env, max_limit=200),
+                     lambda env: EvalEpisodeReturnWrapper(env),
+                 ]
+             }
+         )
+     elif env_id == 'di_sheep':
+         from sheep_env import SheepEnv
+         return DingEnvWrapper(SheepEnv(level=9))
+     elif env_id == 'procgen_bigfish':
+         return DingEnvWrapper(
+             gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
+             cfg={
+                 'env_wrapper': [
+                     lambda env: TransposeWrapper(env),
+                     lambda env: ScaledFloatFrameWrapper(env),
+                     lambda env: EvalEpisodeReturnWrapper(env),
+                 ]
+             },
+             seed_api=False,
+         )
+     elif env_id == 'Hopper-v3':
+         cfg = EasyDict(
+             env_id='Hopper-v3',
+             env_wrapper='mujoco_default',
+             act_scale=True,
+             rew_clip=True,
+         )
+         return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
+     elif env_id == 'HalfCheetah-v3':
+         cfg = EasyDict(
+             env_id='HalfCheetah-v3',
+             env_wrapper='mujoco_default',
+             act_scale=True,
+             rew_clip=True,
+         )
+         return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
+     elif env_id == 'Walker2d-v3':
+         cfg = EasyDict(
+             env_id='Walker2d-v3',
+             env_wrapper='mujoco_default',
+             act_scale=True,
+             rew_clip=True,
+         )
+         return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
+     elif env_id in [
+             'BowlingNoFrameskip-v4',
+             'BreakoutNoFrameskip-v4',
+             'GopherNoFrameskip-v4',
+             'KangarooNoFrameskip-v4',
+             'PongNoFrameskip-v4',
+             'QbertNoFrameskip-v4',
+             'SpaceInvadersNoFrameskip-v4',
+     ]:
+         cfg = EasyDict({
+             'env_id': env_id,
+             'env_wrapper': 'atari_default',
+         })
+         ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
+         return ding_env_atari
+     elif env_id == 'minigrid_fourroom':
+         import gymnasium
+         return DingEnvWrapper(
+             gymnasium.make('MiniGrid-FourRooms-v0'),
+             cfg={
+                 'env_wrapper': [
+                     lambda env: GymToGymnasiumWrapper(env),
+                     lambda env: FlatObsWrapper(env),
+                     lambda env: TimeLimitWrapper(env, max_limit=300),
+                     lambda env: EvalEpisodeReturnWrapper(env),
+                 ]
+             }
+         )
+     elif env_id == 'metadrive':
+         from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+         from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+         cfg = dict(
+             map='XSOS',
+             horizon=4000,
+             out_of_road_penalty=40.0,
+             crash_vehicle_penalty=40.0,
+             out_of_route_done=True,
+         )
+         cfg = EasyDict(cfg)
+         return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
+     else:
+         raise KeyError("not supported env type: {}".format(env_id))
+ 
+ 
+ def get_hybrid_shape(action_space) -> EasyDict:
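+     """
+     Overview:
+         Convert a hybrid gym action space, given as a pair of ``(Discrete, Box)`` spaces, into the \
+         shape dict expected by DI-engine models.
+     Examples:
+         A minimal usage sketch (the spaces below are illustrative assumptions):
+         >>> from gym import spaces
+         >>> hybrid = (spaces.Discrete(3), spaces.Box(low=-1., high=1., shape=(2, )))
+         >>> print(get_hybrid_shape(hybrid))
+     """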
+     return EasyDict({
+         'action_type_shape': action_space[0].n,
+         'action_args_shape': action_space[1].shape,
+     })
DI-engine/ding/bonus/ddpg.py ADDED
@@ -0,0 +1,456 @@
+ from typing import Optional, Union, List
+ from ditk import logging
+ from easydict import EasyDict
+ import os
+ import numpy as np
+ import torch
+ import treetensor.torch as ttorch
+ from ding.framework import task, OnlineRLContext
+ from ding.framework.middleware import CkptSaver, \
+     wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+     OffPolicyLearner, final_ctx_saver
+ from ding.envs import BaseEnv
+ from ding.envs import setup_ding_env_manager
+ from ding.policy import DDPGPolicy
+ from ding.utils import set_pkg_seed
+ from ding.utils import get_env_fps, render
+ from ding.config import save_config_py, compile_config
+ from ding.model import ContinuousQAC
+ from ding.data import DequeBuffer
+ from ding.bonus.common import TrainingReturn, EvalReturn
+ from ding.config.example.DDPG import supported_env_cfg
+ from ding.config.example.DDPG import supported_env
+ 
+ 
+ class DDPGAgent:
+     """
+     Overview:
+         Agent class for training, evaluation and deployment of the reinforcement learning algorithm \
+         Deep Deterministic Policy Gradient (DDPG).
+         For more information about the system design of RL agents, please refer to \
+         <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
+     Interface:
+         ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+     """
+     supported_env_list = list(supported_env_cfg.keys())
+     """
+     Overview:
+         List of supported envs.
+     Examples:
+         >>> from ding.bonus.ddpg import DDPGAgent
+         >>> print(DDPGAgent.supported_env_list)
+     """
+ 
+     def __init__(
+             self,
+             env_id: str = None,
+             env: BaseEnv = None,
+             seed: int = 0,
+             exp_name: str = None,
+             model: Optional[torch.nn.Module] = None,
+             cfg: Optional[Union[EasyDict, dict]] = None,
+             policy_state_dict: str = None,
+     ) -> None:
+         """
+         Overview:
+             Initialize the agent for the DDPG algorithm.
+         Arguments:
+             - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+                 If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+                 If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+                 ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+             - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                 If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified, \
+                 and will be used to create the environment instance. \
+                 If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+             - seed (:obj:`int`): The random seed, which is set before running the program. \
+                 Defaults to 0.
+             - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+                 log data. Defaults to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+             - model (:obj:`torch.nn.Module`): The model of the DDPG algorithm, which should be an instance of class \
+                 :class:`ding.model.ContinuousQAC`. \
+                 If not specified, a default model will be generated according to the configuration.
+             - cfg (:obj:`Union[EasyDict, dict]`): The configuration of the DDPG algorithm, which is a dict. \
+                 Defaults to None. If not specified, the default configuration will be used. \
+                 The default configuration can be found in \
+                 ``ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py``.
+             - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch to a local file. \
+                 If specified, the policy will be loaded from this file. Defaults to None.
+ 
+         .. note::
+             An RL agent instance can be initialized in two basic ways. \
+             For example, suppose we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+             and we want to train an agent with the DDPG algorithm with the default configuration. \
+             Then we can initialize the agent in the following ways:
+             >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
+             or, if we want, we can specify the ``env_id`` in the configuration:
+             >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+             >>> agent = DDPGAgent(cfg=cfg)
+             There are also other arguments to specify the agent when initializing.
+             For example, if we want to specify the environment instance:
+             >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+             >>> agent = DDPGAgent(cfg=cfg, env=env)
+             or, if we want to specify the model:
+             >>> model = ContinuousQAC(**cfg.policy.model)
+             >>> agent = DDPGAgent(cfg=cfg, model=model)
+             or, if we want to reload the policy from a saved policy state dict:
+             >>> agent = DDPGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+             Make sure that the configuration is consistent with the saved policy state dict.
+         """
+ 
+         assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+ 
+         if cfg is not None and not isinstance(cfg, EasyDict):
+             cfg = EasyDict(cfg)
+ 
+         if env_id is not None:
+             assert env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
+                 DDPGAgent.supported_env_list
+             )
+             if cfg is None:
+                 cfg = supported_env_cfg[env_id]
+             else:
+                 assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+         else:
+             assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+             assert cfg.env.env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
+                 DDPGAgent.supported_env_list
+             )
+         default_policy_config = EasyDict({"policy": DDPGPolicy.default_config()})
+         default_policy_config.update(cfg)
+         cfg = default_policy_config
+ 
+         if exp_name is not None:
+             cfg.exp_name = exp_name
+         self.cfg = compile_config(cfg, policy=DDPGPolicy)
+         self.exp_name = self.cfg.exp_name
+         if env is None:
+             self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+         else:
+             assert isinstance(env, BaseEnv), "Please use BaseEnv as the env data type."
+             self.env = env
+ 
+         logging.getLogger().setLevel(logging.INFO)
+         self.seed = seed
+         set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+         if not os.path.exists(self.exp_name):
+             os.makedirs(self.exp_name)
+         save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+         if model is None:
+             model = ContinuousQAC(**self.cfg.policy.model)
+         self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+         self.policy = DDPGPolicy(self.cfg.policy, model=model)
+         if policy_state_dict is not None:
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+ 
+     def train(
+             self,
+             step: int = int(1e7),
+             collector_env_num: int = None,
+             evaluator_env_num: int = None,
+             n_iter_log_show: int = 500,
+             n_iter_save_ckpt: int = 1000,
+             context: Optional[str] = None,
+             debug: bool = False,
+             wandb_sweep: bool = False,
+     ) -> TrainingReturn:
+         """
+         Overview:
+             Train the agent with the DDPG algorithm for ``step`` environment steps with ``collector_env_num`` \
+             collector environments and ``evaluator_env_num`` evaluator environments. Information during training \
+             will be recorded and saved by wandb.
+         Arguments:
+             - step (:obj:`int`): The total training environment steps of all collector environments. Defaults to 1e7.
+             - collector_env_num (:obj:`int`): The number of collector environments. Defaults to None. \
+                 If not specified, it will be set according to the configuration.
+             - evaluator_env_num (:obj:`int`): The number of evaluator environments. Defaults to None. \
+                 If not specified, it will be set according to the configuration.
+             - n_iter_log_show (:obj:`int`): The logging frequency, in training iterations. Defaults to 500. \
+                 Currently reserved: it is not used by the training pipeline below.
+             - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
+                 Defaults to 1000.
+             - context (:obj:`str`): The multi-process context of the environment manager. Defaults to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+             - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+                 which is a hyper-parameter optimization process for seeking the best configuration. \
+                 Defaults to False. If True, the wandb sweep id will be used as the experiment name.
+         Returns:
+             - (:obj:`TrainingReturn`): The training result, whose attributes are:
+                 - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
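+         Examples:
+             A minimal usage sketch (the ``env_id`` and step budget are illustrative assumptions):
+             >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
+             >>> result = agent.train(step=int(1e6))
+             >>> print(result.wandb_url)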
+         """
+ 
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+             logging.debug(self.policy._model)
+         # define env and policy
+         collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+         evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+         collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+         evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+ 
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 interaction_evaluator(
+                     self.cfg,
+                     self.policy.eval_mode,
+                     evaluator_env,
+                     render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+                 )
+             )
+             task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+             task.use(
+                 StepCollector(
+                     self.cfg,
+                     self.policy.collect_mode,
+                     collector_env,
+                     random_collect_size=self.cfg.policy.random_collect_size
+                     if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+                 )
+             )
+             task.use(data_pusher(self.cfg, self.buffer_))
+             task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+             task.use(
+                 wandb_online_logger(
+                     metric_list=self.policy._monitor_vars_learn(),
+                     model=self.policy._model,
+                     anonymous=True,
+                     project_name=self.exp_name,
+                     wandb_sweep=wandb_sweep,
+                 )
+             )
+             task.use(termination_checker(max_env_step=step))
+             task.use(final_ctx_saver(name=self.exp_name))
+             task.run()
+ 
+         return TrainingReturn(wandb_url=task.ctx.wandb_url)
+ 
+     def deploy(
+             self,
+             enable_save_replay: bool = False,
+             concatenate_all_replay: bool = False,
+             replay_save_path: str = None,
+             seed: Optional[Union[int, List]] = None,
+             debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Deploy the agent with the DDPG algorithm by interacting with the environment, during which the replay \
+             video can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+         Arguments:
+             - enable_save_replay (:obj:`bool`): Whether to save the replay video. Defaults to False.
+             - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+                 Defaults to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+                 If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+                 the replay video of each episode will be saved separately.
+             - replay_save_path (:obj:`str`): The path to save the replay video. Defaults to None. \
+                 If not specified, the video will be saved in ``exp_name/videos``.
+             - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+                 Defaults to None. If not specified, ``self.seed`` will be used. \
+                 If ``seed`` is an integer, the agent will be deployed once. \
+                 If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Defaults to False. \
+                 If set True, the base environment manager will be used for easy debugging. Otherwise, \
+                 the subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
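+         Examples:
+             A minimal usage sketch (the ``env_id``, path and seed are illustrative assumptions):
+             >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
+             >>> agent.deploy(enable_save_replay=True, replay_save_path='./videos', seed=0)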
259
+ """
260
+
261
+ if debug:
262
+ logging.getLogger().setLevel(logging.DEBUG)
263
+ # define env and policy
264
+ env = self.env.clone(caller='evaluator')
265
+
266
+ if seed is not None and isinstance(seed, int):
267
+ seeds = [seed]
268
+ elif seed is not None and isinstance(seed, list):
269
+ seeds = seed
270
+ else:
271
+ seeds = [self.seed]
272
+
273
+ returns = []
274
+ images = []
275
+ if enable_save_replay:
276
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
277
+ env.enable_save_replay(replay_path=replay_save_path)
278
+ else:
279
+ logging.warning('No video would be generated during the deploy.')
280
+ if concatenate_all_replay:
281
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
282
+ concatenate_all_replay = False
283
+
284
+ def single_env_forward_wrapper(forward_fn, cuda=True):
285
+
286
+ def _forward(obs):
287
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
288
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
289
+ if cuda and torch.cuda.is_available():
290
+ obs = obs.cuda()
291
+ action = forward_fn(obs, mode='compute_actor')["action"]
292
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
293
+ action = action.squeeze(0).detach().cpu().numpy()
294
+ return action
295
+
296
+ return _forward
297
+
298
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
299
+
300
+ # reset first to make sure the env is in the initial state
301
+ # env will be reset again in the main loop
302
+ env.reset()
303
+
304
+ for seed in seeds:
305
+ env.seed(seed, dynamic_seed=False)
306
+ return_ = 0.
307
+ step = 0
308
+ obs = env.reset()
309
+ images.append(render(env)[None]) if concatenate_all_replay else None
310
+ while True:
311
+ action = forward_fn(obs)
312
+ obs, rew, done, info = env.step(action)
313
+ images.append(render(env)[None]) if concatenate_all_replay else None
314
+ return_ += rew
315
+ step += 1
316
+ if done:
317
+ break
318
+ logging.info(f'DDPG deploy is finished, final episode return with {step} steps is: {return_}')
319
+ returns.append(return_)
320
+
321
+ env.close()
322
+
323
+ if concatenate_all_replay:
324
+ images = np.concatenate(images, axis=0)
325
+ import imageio
326
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
327
+
328
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
329
+
330
+ def collect_data(
331
+ self,
332
+ env_num: int = 8,
333
+ save_data_path: Optional[str] = None,
334
+ n_sample: Optional[int] = None,
335
+ n_episode: Optional[int] = None,
336
+ context: Optional[str] = None,
337
+ debug: bool = False
338
+ ) -> None:
339
+ """
340
+ Overview:
341
+ Collect data with DDPG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
342
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
343
+ ``exp_name/demo_data``.
344
+ Arguments:
345
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
346
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
347
+ If not specified, the data will be saved in ``exp_name/demo_data``.
348
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
349
+ If not specified, ``n_episode`` must be specified.
350
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
351
+ If not specified, ``n_sample`` must be specified.
352
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
353
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
354
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
355
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
356
+ subprocess environment manager will be used.
357
+ """
358
+
359
+ if debug:
360
+ logging.getLogger().setLevel(logging.DEBUG)
361
+ if n_episode is not None:
362
+ raise NotImplementedError
363
+ # define env and policy
364
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
365
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
366
+
367
+ if save_data_path is None:
368
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
369
+
370
+ # main execution task
371
+ with task.start(ctx=OnlineRLContext()):
372
+ task.use(
373
+ StepCollector(
374
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
375
+ )
376
+ )
377
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
378
+ task.run(max_step=1)
379
+ logging.info(
380
+ f'DDPG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
381
+ )
382
+
383
+ def batch_evaluate(
384
+ self,
385
+ env_num: int = 4,
386
+ n_evaluator_episode: int = 4,
387
+ context: Optional[str] = None,
388
+ debug: bool = False
389
+ ) -> EvalReturn:
390
+ """
391
+ Overview:
392
+ Evaluate the agent with DDPG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
393
+ environments. The evaluation result will be returned.
394
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
395
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
396
+ will only create one evaluator environment to evaluate the agent and save the replay video.
397
+ Arguments:
398
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
399
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
400
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
401
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
402
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
403
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
404
+ subprocess environment manager will be used.
405
+ Returns:
406
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
407
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
408
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
409
+ """
410
+
411
+ if debug:
412
+ logging.getLogger().setLevel(logging.DEBUG)
413
+ # define env and policy
414
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
415
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
416
+
417
+ # reset first to make sure the env is in the initial state
418
+ # env will be reset again in the main loop
419
+ env.launch()
420
+ env.reset()
421
+
422
+ evaluate_cfg = self.cfg
423
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
424
+
425
+ # main execution task
426
+ with task.start(ctx=OnlineRLContext()):
427
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
428
+ task.run(max_step=1)
429
+
430
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
431
+
432
+ @property
433
+ def best(self) -> 'DDPGAgent':
434
+ """
435
+ Overview:
436
+ Load the best model from the checkpoint directory, \
437
+ which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
438
+ The return value is the agent with the best model.
439
+ Returns:
440
+ - (:obj:`DDPGAgent`): The agent with the best model.
441
+ Examples:
442
+ >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
443
+ >>> agent.train()
444
+ >>> agent = agent.best
445
+
446
+ .. note::
447
+ The best model is the model with the highest evaluation return. If this method is called, the current \
448
+ model will be replaced by the best model.
449
+ """
450
+
451
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
452
+ # Load best model if it exists
453
+ if os.path.exists(best_model_file_path):
454
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
455
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
456
+ return self
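A minimal end-to-end sketch of the agent API defined above, assuming ``DDPGAgent.train`` mirrors the ``train(step=...)`` signature of the other agents in this commit; the experiment name and step counts below are illustrative, not defaults from the commit:

from ding.bonus.ddpg import DDPGAgent

agent = DDPGAgent(env_id='LunarLanderContinuous-v2', exp_name='ddpg_demo')  # 'ddpg_demo' is a hypothetical name
agent.train(step=int(1e5))                    # checkpoints are written under ddpg_demo/ckpt
agent.collect_data(env_num=8, n_sample=1024)  # transitions saved as hdf5 under ddpg_demo/demo_data
result = agent.best.batch_evaluate(env_num=4, n_evaluator_episode=4)
print(result.eval_value, result.eval_value_std)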
DI-engine/ding/bonus/dqn.py ADDED
@@ -0,0 +1,460 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import treetensor.torch as ttorch
8
+ from ding.framework import task, OnlineRLContext
9
+ from ding.framework.middleware import CkptSaver, \
10
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11
+ OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
12
+ from ding.envs import BaseEnv
13
+ from ding.envs import setup_ding_env_manager
14
+ from ding.policy import DQNPolicy
15
+ from ding.utils import set_pkg_seed
16
+ from ding.utils import get_env_fps, render
17
+ from ding.config import save_config_py, compile_config
18
+ from ding.model import DQN
19
+ from ding.model import model_wrap
20
+ from ding.data import DequeBuffer
21
+ from ding.bonus.common import TrainingReturn, EvalReturn
22
+ from ding.config.example.DQN import supported_env_cfg
23
+ from ding.config.example.DQN import supported_env
24
+
25
+
26
+ class DQNAgent:
27
+ """
28
+ Overview:
29
+ Class of agent for training, evaluation and deployment of the Reinforcement Learning algorithm Deep Q-Network (DQN).
30
+ For more information about the system design of RL agent, please refer to \
31
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
32
+ Interface:
33
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34
+ """
35
+ supported_env_list = list(supported_env_cfg.keys())
36
+ """
37
+ Overview:
38
+ List of supported envs.
39
+ Examples:
40
+ >>> from ding.bonus.dqn import DQNAgent
41
+ >>> print(DQNAgent.supported_env_list)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ env_id: str = None,
47
+ env: BaseEnv = None,
48
+ seed: int = 0,
49
+ exp_name: str = None,
50
+ model: Optional[torch.nn.Module] = None,
51
+ cfg: Optional[Union[EasyDict, dict]] = None,
52
+ policy_state_dict: str = None,
53
+ ) -> None:
54
+ """
55
+ Overview:
56
+ Initialize agent for DQN algorithm.
57
+ Arguments:
58
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
67
+ Default to 0.
68
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70
+ - model (:obj:`torch.nn.Module`): The model of DQN algorithm, which should be an instance of class \
71
+ :class:`ding.model.DQN`. \
72
+ If not specified, a default model will be generated according to the configuration.
73
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of DQN algorithm, which is a dict. \
74
+ Default to None. If not specified, the default configuration will be used. \
75
+ The default configuration can be found in ``ding/config/example/DQN/gym_lunarlander_v2.py``.
76
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77
+ If specified, the policy will be loaded from this file. Default to None.
78
+
79
+ .. note::
80
+ An RL Agent Instance can be initialized in two basic ways. \
81
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
82
+ and we want to train an agent with DQN algorithm with default configuration. \
83
+ Then we can initialize the agent in the following ways:
84
+ >>> agent = DQNAgent(env_id='LunarLander-v2')
85
+ or, if we want to specify the env_id in the configuration:
86
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
87
+ >>> agent = DQNAgent(cfg=cfg)
88
+ There are also other arguments to specify the agent when initializing.
89
+ For example, if we want to specify the environment instance:
90
+ >>> env = CustomizedEnv('LunarLander-v2')
91
+ >>> agent = DQNAgent(cfg=cfg, env=env)
92
+ or, if we want to specify the model:
93
+ >>> model = DQN(**cfg.policy.model)
94
+ >>> agent = DQNAgent(cfg=cfg, model=model)
95
+ or, if we want to reload the policy from a saved policy state dict:
96
+ >>> agent = DQNAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
97
+ Make sure that the configuration is consistent with the saved policy state dict.
98
+ """
99
+
100
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101
+
102
+ if cfg is not None and not isinstance(cfg, EasyDict):
103
+ cfg = EasyDict(cfg)
104
+
105
+ if env_id is not None:
106
+ assert env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
107
+ DQNAgent.supported_env_list
108
+ )
109
+ if cfg is None:
110
+ cfg = supported_env_cfg[env_id]
111
+ else:
112
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113
+ else:
114
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115
+ assert cfg.env.env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
116
+ DQNAgent.supported_env_list
117
+ )
118
+ default_policy_config = EasyDict({"policy": DQNPolicy.default_config()})
119
+ default_policy_config.update(cfg)
120
+ cfg = default_policy_config
121
+
122
+ if exp_name is not None:
123
+ cfg.exp_name = exp_name
124
+ self.cfg = compile_config(cfg, policy=DQNPolicy)
125
+ self.exp_name = self.cfg.exp_name
126
+ if env is None:
127
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128
+ else:
129
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130
+ self.env = env
131
+
132
+ logging.getLogger().setLevel(logging.INFO)
133
+ self.seed = seed
134
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135
+ if not os.path.exists(self.exp_name):
136
+ os.makedirs(self.exp_name)
137
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
138
+ if model is None:
139
+ model = DQN(**self.cfg.policy.model)
140
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141
+ self.policy = DQNPolicy(self.cfg.policy, model=model)
142
+ if policy_state_dict is not None:
143
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
144
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145
+
146
+ def train(
147
+ self,
148
+ step: int = int(1e7),
149
+ collector_env_num: int = None,
150
+ evaluator_env_num: int = None,
151
+ n_iter_save_ckpt: int = 1000,
152
+ context: Optional[str] = None,
153
+ debug: bool = False,
154
+ wandb_sweep: bool = False,
155
+ ) -> TrainingReturn:
156
+ """
157
+ Overview:
158
+ Train the agent with DQN algorithm for ``step`` environment steps with ``collector_env_num`` collector \
159
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
160
+ recorded and saved by wandb.
161
+ Arguments:
162
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164
+ If not specified, it will be set according to the configuration.
165
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166
+ If not specified, it will be set according to the configuration.
167
+ - n_iter_save_ckpt (:obj:`int`): The frequency, in training iterations, at which checkpoints are saved. \
168
+ Default to 1000.
169
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
173
+ subprocess environment manager will be used.
174
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175
+ which is a hyper-parameter optimization process for seeking the best configurations. \
176
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
177
+ Returns:
178
+ - (:obj:`TrainingReturn`): The training result, of which the attributes are:
179
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
180
+ """
181
+
182
+ if debug:
183
+ logging.getLogger().setLevel(logging.DEBUG)
184
+ logging.debug(self.policy._model)
185
+ # define env and policy
186
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190
+
191
+ with task.start(ctx=OnlineRLContext()):
192
+ task.use(
193
+ interaction_evaluator(
194
+ self.cfg,
195
+ self.policy.eval_mode,
196
+ evaluator_env,
197
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198
+ )
199
+ )
200
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201
+ task.use(eps_greedy_handler(self.cfg))
202
+ task.use(
203
+ StepCollector(
204
+ self.cfg,
205
+ self.policy.collect_mode,
206
+ collector_env,
207
+ random_collect_size=self.cfg.policy.random_collect_size
208
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209
+ )
210
+ )
211
+ if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
212
+ task.use(nstep_reward_enhancer(self.cfg))
213
+ task.use(data_pusher(self.cfg, self.buffer_))
214
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
215
+ task.use(
216
+ wandb_online_logger(
217
+ metric_list=self.policy._monitor_vars_learn(),
218
+ model=self.policy._model,
219
+ anonymous=True,
220
+ project_name=self.exp_name,
221
+ wandb_sweep=wandb_sweep,
222
+ )
223
+ )
224
+ task.use(termination_checker(max_env_step=step))
225
+ task.use(final_ctx_saver(name=self.exp_name))
226
+ task.run()
227
+
228
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
229
+
230
+ def deploy(
231
+ self,
232
+ enable_save_replay: bool = False,
233
+ concatenate_all_replay: bool = False,
234
+ replay_save_path: str = None,
235
+ seed: Optional[Union[int, List]] = None,
236
+ debug: bool = False
237
+ ) -> EvalReturn:
238
+ """
239
+ Overview:
240
+ Deploy the agent with DQN algorithm by interacting with the environment, during which the replay video \
241
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
242
+ Arguments:
243
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
244
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
245
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
246
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
247
+ the replay video of each episode will be saved separately.
248
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
249
+ If not specified, the video will be saved in ``exp_name/videos``.
250
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
251
+ Default to None. If not specified, ``self.seed`` will be used. \
252
+ If ``seed`` is an integer, the agent will be deployed once. \
253
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
254
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
255
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
256
+ subprocess environment manager will be used.
257
+ Returns:
258
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
259
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
260
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
261
+ """
262
+
263
+ if debug:
264
+ logging.getLogger().setLevel(logging.DEBUG)
265
+ # define env and policy
266
+ env = self.env.clone(caller='evaluator')
267
+
268
+ if seed is not None and isinstance(seed, int):
269
+ seeds = [seed]
270
+ elif seed is not None and isinstance(seed, list):
271
+ seeds = seed
272
+ else:
273
+ seeds = [self.seed]
274
+
275
+ returns = []
276
+ images = []
277
+ if enable_save_replay:
278
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
279
+ env.enable_save_replay(replay_path=replay_save_path)
280
+ else:
281
+ logging.warning('No video would be generated during the deploy.')
282
+ if concatenate_all_replay:
283
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
284
+ concatenate_all_replay = False
285
+
286
+ def single_env_forward_wrapper(forward_fn, cuda=True):
287
+
288
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
289
+
290
+ def _forward(obs):
291
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
292
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
293
+ if cuda and torch.cuda.is_available():
294
+ obs = obs.cuda()
295
+ action = forward_fn(obs)["action"]
296
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
297
+ action = action.squeeze(0).detach().cpu().numpy()
298
+ return action
299
+
300
+ return _forward
301
+
302
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
303
+
304
+ # reset first to make sure the env is in the initial state
305
+ # env will be reset again in the main loop
306
+ env.reset()
307
+
308
+ for seed in seeds:
309
+ env.seed(seed, dynamic_seed=False)
310
+ return_ = 0.
311
+ step = 0
312
+ obs = env.reset()
313
+ images.append(render(env)[None]) if concatenate_all_replay else None
314
+ while True:
315
+ action = forward_fn(obs)
316
+ obs, rew, done, info = env.step(action)
317
+ images.append(render(env)[None]) if concatenate_all_replay else None
318
+ return_ += rew
319
+ step += 1
320
+ if done:
321
+ break
322
+ logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
323
+ returns.append(return_)
324
+
325
+ env.close()
326
+
327
+ if concatenate_all_replay:
328
+ images = np.concatenate(images, axis=0)
329
+ import imageio
330
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
331
+
332
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
333
+
334
+ def collect_data(
335
+ self,
336
+ env_num: int = 8,
337
+ save_data_path: Optional[str] = None,
338
+ n_sample: Optional[int] = None,
339
+ n_episode: Optional[int] = None,
340
+ context: Optional[str] = None,
341
+ debug: bool = False
342
+ ) -> None:
343
+ """
344
+ Overview:
345
+ Collect ``n_sample`` samples with the DQN algorithm using ``env_num`` collector environments. \
346
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
347
+ ``exp_name/demo_data``.
348
+ Arguments:
349
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
350
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
351
+ If not specified, the data will be saved in ``exp_name/demo_data``.
352
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
353
+ If not specified, ``n_episode`` must be specified.
354
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
355
+ Episode-based collection is not implemented yet; ``n_sample`` must be specified.
356
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
357
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
358
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
359
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
360
+ subprocess environment manager will be used.
361
+ """
362
+
363
+ if debug:
364
+ logging.getLogger().setLevel(logging.DEBUG)
365
+ if n_episode is not None:
366
+ raise NotImplementedError
367
+ # define env and policy
368
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
369
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
370
+
371
+ if save_data_path is None:
372
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
373
+
374
+ # main execution task
375
+ with task.start(ctx=OnlineRLContext()):
376
+ task.use(
377
+ StepCollector(
378
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
379
+ )
380
+ )
381
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
382
+ task.run(max_step=1)
383
+ logging.info(
384
+ f'DQN collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
385
+ )
386
+
387
+ def batch_evaluate(
388
+ self,
389
+ env_num: int = 4,
390
+ n_evaluator_episode: int = 4,
391
+ context: Optional[str] = None,
392
+ debug: bool = False
393
+ ) -> EvalReturn:
394
+ """
395
+ Overview:
396
+ Evaluate the agent with DQN algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
397
+ environments. The evaluation result will be returned.
398
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
399
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
400
+ will only create one evaluator environment to evaluate the agent and save the replay video.
401
+ Arguments:
402
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
403
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
404
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
405
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
406
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
407
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
408
+ subprocess environment manager will be used.
409
+ Returns:
410
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
411
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
412
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
413
+ """
414
+
415
+ if debug:
416
+ logging.getLogger().setLevel(logging.DEBUG)
417
+ # define env and policy
418
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
419
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
420
+
421
+ # reset first to make sure the env is in the initial state
422
+ # env will be reset again in the main loop
423
+ env.launch()
424
+ env.reset()
425
+
426
+ evaluate_cfg = self.cfg
427
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
428
+
429
+ # main execution task
430
+ with task.start(ctx=OnlineRLContext()):
431
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
432
+ task.run(max_step=1)
433
+
434
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
435
+
436
+ @property
437
+ def best(self) -> 'DQNAgent':
438
+ """
439
+ Overview:
440
+ Load the best model from the checkpoint directory, \
441
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
442
+ The return value is the agent with the best model.
443
+ Returns:
444
+ - (:obj:`DQNAgent`): The agent with the best model.
445
+ Examples:
446
+ >>> agent = DQNAgent(env_id='LunarLander-v2')
447
+ >>> agent.train()
448
+ >>> agent = agent.best
449
+
450
+ .. note::
451
+ The best model is the model with the highest evaluation return. If this method is called, the current \
452
+ model will be replaced by the best model.
453
+ """
454
+
455
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
456
+ # Load best model if it exists
457
+ if os.path.exists(best_model_file_path):
458
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
459
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
460
+ return self
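The deploy method above turns the raw Q-network into a single-env greedy policy via ``model_wrap``. A self-contained sketch of that wrapper pattern, using an ad-hoc (untrained) model with LunarLander-sized shapes rather than a loaded checkpoint:

import torch
import treetensor.torch as ttorch
from ding.model import DQN, model_wrap

# Wrap the Q-network so forward() returns greedy (argmax) actions.
model = DQN(obs_shape=8, action_shape=4)
greedy_fn = model_wrap(model, wrapper_name='argmax_sample').forward

obs = ttorch.as_tensor(torch.zeros(8)).unsqueeze(0)                  # (O,) -> (1, O): add batch dim
action = greedy_fn(obs)['action'].squeeze(0).detach().cpu().numpy()  # (1,) -> (): drop batch dim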
DI-engine/ding/bonus/model.py ADDED
@@ -0,0 +1,245 @@
1
+ from typing import Union, Optional
2
+ from easydict import EasyDict
3
+ import torch
4
+ import torch.nn as nn
5
+ import treetensor.torch as ttorch
6
+ from copy import deepcopy
7
+ from ding.utils import SequenceType, squeeze
8
+ from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
9
+ FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
10
+ from ding.torch_utils import MLP, fc_block
11
+
12
+
13
+ class DiscretePolicyHead(nn.Module):
14
+
15
+ def __init__(
16
+ self,
17
+ hidden_size: int,
18
+ output_size: int,
19
+ layer_num: int = 1,
20
+ activation: Optional[nn.Module] = nn.ReLU(),
21
+ norm_type: Optional[str] = None,
22
+ ) -> None:
23
+ super(DiscretePolicyHead, self).__init__()
24
+ self.main = nn.Sequential(
25
+ MLP(
26
+ hidden_size,
27
+ hidden_size,
28
+ hidden_size,
29
+ layer_num,
30
+ layer_fn=nn.Linear,
31
+ activation=activation,
32
+ norm_type=norm_type
33
+ ), fc_block(hidden_size, output_size)
34
+ )
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ return self.main(x)
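A quick shape check for the head defined above (the sizes are illustrative):

import torch

head = DiscretePolicyHead(hidden_size=64, output_size=4, layer_num=2)
logits = head(torch.randn(8, 64))
assert logits.shape == (8, 4)  # one logit per action for each batch element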
38
+
39
+
40
+ class PPOFModel(nn.Module):
41
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
42
+
43
+ def __init__(
44
+ self,
45
+ obs_shape: Union[int, SequenceType],
46
+ action_shape: Union[int, SequenceType, EasyDict],
47
+ action_space: str = 'discrete',
48
+ share_encoder: bool = True,
49
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
50
+ actor_head_hidden_size: int = 64,
51
+ actor_head_layer_num: int = 1,
52
+ critic_head_hidden_size: int = 64,
53
+ critic_head_layer_num: int = 1,
54
+ activation: Optional[nn.Module] = nn.ReLU(),
55
+ norm_type: Optional[str] = None,
56
+ sigma_type: Optional[str] = 'independent',
57
+ fixed_sigma_value: Optional[int] = 0.3,
58
+ bound_type: Optional[str] = None,
59
+ encoder: Optional[torch.nn.Module] = None,
60
+ popart_head=False,
61
+ ) -> None:
62
+ super(PPOFModel, self).__init__()
63
+ obs_shape = squeeze(obs_shape)
64
+ action_shape = squeeze(action_shape)
65
+ self.obs_shape, self.action_shape = obs_shape, action_shape
66
+ self.share_encoder = share_encoder
67
+
68
+ # Encoder Type
69
+ def new_encoder(outsize):
70
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
71
+ return FCEncoder(
72
+ obs_shape=obs_shape,
73
+ hidden_size_list=encoder_hidden_size_list,
74
+ activation=activation,
75
+ norm_type=norm_type
76
+ )
77
+ elif len(obs_shape) == 3:
78
+ return ConvEncoder(
79
+ obs_shape=obs_shape,
80
+ hidden_size_list=encoder_hidden_size_list,
81
+ activation=activation,
82
+ norm_type=norm_type
83
+ )
84
+ else:
85
+ raise RuntimeError(
86
+ "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
87
+ format(obs_shape)
88
+ )
89
+
90
+ if self.share_encoder:
91
+ assert actor_head_hidden_size == critic_head_hidden_size, \
92
+ "actor and critic network head should have same size."
93
+ if encoder:
94
+ if isinstance(encoder, torch.nn.Module):
95
+ self.encoder = encoder
96
+ else:
97
+ raise ValueError("illegal encoder instance.")
98
+ else:
99
+ self.encoder = new_encoder(actor_head_hidden_size)
100
+ else:
101
+ if encoder:
102
+ if isinstance(encoder, torch.nn.Module):
103
+ self.actor_encoder = encoder
104
+ self.critic_encoder = deepcopy(encoder)
105
+ else:
106
+ raise ValueError("illegal encoder instance.")
107
+ else:
108
+ self.actor_encoder = new_encoder(actor_head_hidden_size)
109
+ self.critic_encoder = new_encoder(critic_head_hidden_size)
110
+
111
+ # Head Type
112
+ if not popart_head:
113
+ self.critic_head = RegressionHead(
114
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
115
+ )
116
+ else:
117
+ self.critic_head = PopArtVHead(
118
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
119
+ )
120
+
121
+ self.action_space = action_space
122
+ assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
123
+ if self.action_space == 'continuous':
124
+ self.multi_head = False
125
+ self.actor_head = ReparameterizationHead(
126
+ actor_head_hidden_size,
127
+ action_shape,
128
+ actor_head_layer_num,
129
+ sigma_type=sigma_type,
130
+ activation=activation,
131
+ norm_type=norm_type,
132
+ bound_type=bound_type
133
+ )
134
+ elif self.action_space == 'discrete':
135
+ actor_head_cls = DiscretePolicyHead
136
+ multi_head = not isinstance(action_shape, int)
137
+ self.multi_head = multi_head
138
+ if multi_head:
139
+ self.actor_head = MultiHead(
140
+ actor_head_cls,
141
+ actor_head_hidden_size,
142
+ action_shape,
143
+ layer_num=actor_head_layer_num,
144
+ activation=activation,
145
+ norm_type=norm_type
146
+ )
147
+ else:
148
+ self.actor_head = actor_head_cls(
149
+ actor_head_hidden_size,
150
+ action_shape,
151
+ actor_head_layer_num,
152
+ activation=activation,
153
+ norm_type=norm_type
154
+ )
155
+ elif self.action_space == 'hybrid': # HPPO
156
+ # hybrid action space: action_type(discrete) + action_args(continuous),
157
+ # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
158
+ action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
159
+ action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
160
+ actor_action_args = ReparameterizationHead(
161
+ actor_head_hidden_size,
162
+ action_shape.action_args_shape,
163
+ actor_head_layer_num,
164
+ sigma_type=sigma_type,
165
+ fixed_sigma_value=fixed_sigma_value,
166
+ activation=activation,
167
+ norm_type=norm_type,
168
+ bound_type=bound_type,
169
+ )
170
+ actor_action_type = DiscretePolicyHead(
171
+ actor_head_hidden_size,
172
+ action_shape.action_type_shape,
173
+ actor_head_layer_num,
174
+ activation=activation,
175
+ norm_type=norm_type,
176
+ )
177
+ self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
178
+
179
+ # must use list, not nn.ModuleList
180
+ if self.share_encoder:
181
+ self.actor = [self.encoder, self.actor_head]
182
+ self.critic = [self.encoder, self.critic_head]
183
+ else:
184
+ self.actor = [self.actor_encoder, self.actor_head]
185
+ self.critic = [self.critic_encoder, self.critic_head]
186
+ # Convenient for calling some apis (e.g. self.critic.parameters()),
187
+ # but may cause misunderstanding when `print(self)`
188
+ self.actor = nn.ModuleList(self.actor)
189
+ self.critic = nn.ModuleList(self.critic)
190
+
191
+ def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
192
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
193
+ return getattr(self, mode)(inputs)
194
+
195
+ def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
196
+ if self.share_encoder:
197
+ x = self.encoder(x)
198
+ else:
199
+ x = self.actor_encoder(x)
200
+
201
+ if self.action_space == 'discrete':
202
+ return self.actor_head(x)
203
+ elif self.action_space == 'continuous':
204
+ x = self.actor_head(x) # mu, sigma
205
+ return ttorch.as_tensor(x)
206
+ elif self.action_space == 'hybrid':
207
+ action_type = self.actor_head[0](x)
208
+ action_args = self.actor_head[1](x)
209
+ return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})
210
+
211
+ def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
212
+ if self.share_encoder:
213
+ x = self.encoder(x)
214
+ else:
215
+ x = self.critic_encoder(x)
216
+ x = self.critic_head(x)
217
+ return x
218
+
219
+ def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
220
+ if self.share_encoder:
221
+ actor_embedding = critic_embedding = self.encoder(x)
222
+ else:
223
+ actor_embedding = self.actor_encoder(x)
224
+ critic_embedding = self.critic_encoder(x)
225
+
226
+ value = self.critic_head(critic_embedding)
227
+
228
+ if self.action_space == 'discrete':
229
+ logit = self.actor_head(actor_embedding)
230
+ return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
231
+ elif self.action_space == 'continuous':
232
+ x = self.actor_head(actor_embedding)
233
+ return ttorch.as_tensor({'logit': x, 'value': value['pred']})
234
+ elif self.action_space == 'hybrid':
235
+ action_type = self.actor_head[0](actor_embedding)
236
+ action_args = self.actor_head[1](actor_embedding)
237
+ return ttorch.as_tensor(
238
+ {
239
+ 'logit': {
240
+ 'action_type': action_type,
241
+ 'action_args': action_args
242
+ },
243
+ 'value': value['pred']
244
+ }
245
+ )
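A shape-level sketch of the forward modes above for a discrete action space (sizes are illustrative):

import torch
from ding.bonus.model import PPOFModel

model = PPOFModel(obs_shape=8, action_shape=4, action_space='discrete')
x = torch.randn(2, 8)
logit = model(x, mode='compute_actor')       # per-action logits, shape (2, 4)
out = model(x, mode='compute_actor_critic')  # treetensor with .logit and .value (the critic head's 'pred')
print(out.logit.shape, out.value.shape)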
DI-engine/ding/bonus/pg.py ADDED
@@ -0,0 +1,453 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import treetensor.torch as ttorch
8
+ from ding.framework import task, OnlineRLContext
9
+ from ding.framework.middleware import CkptSaver, trainer, \
10
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
11
+ montecarlo_return_estimator, final_ctx_saver, EpisodeCollector
12
+ from ding.envs import BaseEnv
13
+ from ding.envs import setup_ding_env_manager
14
+ from ding.policy import PGPolicy
15
+ from ding.utils import set_pkg_seed
16
+ from ding.utils import get_env_fps, render
17
+ from ding.config import save_config_py, compile_config
18
+ from ding.model import PG
19
+ from ding.bonus.common import TrainingReturn, EvalReturn
20
+ from ding.config.example.PG import supported_env_cfg
21
+ from ding.config.example.PG import supported_env
22
+
23
+
24
+ class PGAgent:
25
+ """
26
+ Overview:
27
+ Class of agent for training, evaluation and deployment of the Reinforcement Learning algorithm Policy Gradient (PG).
28
+ For more information about the system design of RL agent, please refer to \
29
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
30
+ Interface:
31
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
32
+ """
33
+ supported_env_list = list(supported_env_cfg.keys())
34
+ """
35
+ Overview:
36
+ List of supported envs.
37
+ Examples:
38
+ >>> from ding.bonus.pg import PGAgent
39
+ >>> print(PGAgent.supported_env_list)
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ env_id: str = None,
45
+ env: BaseEnv = None,
46
+ seed: int = 0,
47
+ exp_name: str = None,
48
+ model: Optional[torch.nn.Module] = None,
49
+ cfg: Optional[Union[EasyDict, dict]] = None,
50
+ policy_state_dict: str = None,
51
+ ) -> None:
52
+ """
53
+ Overview:
54
+ Initialize agent for PG algorithm.
55
+ Arguments:
56
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
57
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
58
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
59
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
60
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
61
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
62
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
63
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
64
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
65
+ Default to 0.
66
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
67
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
68
+ - model (:obj:`torch.nn.Module`): The model of PG algorithm, which should be an instance of class \
69
+ :class:`ding.model.PG`. \
70
+ If not specified, a default model will be generated according to the configuration.
71
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PG algorithm, which is a dict. \
72
+ Default to None. If not specified, the default configuration will be used. \
73
+ The default configuration can be found in ``ding/config/example/PG/gym_lunarlander_v2.py``.
74
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
75
+ If specified, the policy will be loaded from this file. Default to None.
76
+
77
+ .. note::
78
+ An RL Agent Instance can be initialized in two basic ways. \
79
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
80
+ and we want to train an agent with PG algorithm with default configuration. \
81
+ Then we can initialize the agent in the following ways:
82
+ >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
83
+ or, if we want to specify the env_id in the configuration:
84
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
85
+ >>> agent = PGAgent(cfg=cfg)
86
+ There are also other arguments to specify the agent when initializing.
87
+ For example, if we want to specify the environment instance:
88
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
89
+ >>> agent = PGAgent(cfg=cfg, env=env)
90
+ or, if we want to specify the model:
91
+ >>> model = PG(**cfg.policy.model)
92
+ >>> agent = PGAgent(cfg=cfg, model=model)
93
+ or, if we want to reload the policy from a saved policy state dict:
94
+ >>> agent = PGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
95
+ Make sure that the configuration is consistent with the saved policy state dict.
96
+ """
97
+
98
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
99
+
100
+ if cfg is not None and not isinstance(cfg, EasyDict):
101
+ cfg = EasyDict(cfg)
102
+
103
+ if env_id is not None:
104
+ assert env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
105
+ PGAgent.supported_env_list
106
+ )
107
+ if cfg is None:
108
+ cfg = supported_env_cfg[env_id]
109
+ else:
110
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
111
+ else:
112
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
113
+ assert cfg.env.env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
114
+ PGAgent.supported_env_list
115
+ )
116
+ default_policy_config = EasyDict({"policy": PGPolicy.default_config()})
117
+ default_policy_config.update(cfg)
118
+ cfg = default_policy_config
119
+
120
+ if exp_name is not None:
121
+ cfg.exp_name = exp_name
122
+ self.cfg = compile_config(cfg, policy=PGPolicy)
123
+ self.exp_name = self.cfg.exp_name
124
+ if env is None:
125
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
126
+ else:
127
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
128
+ self.env = env
129
+
130
+ logging.getLogger().setLevel(logging.INFO)
131
+ self.seed = seed
132
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
133
+ if not os.path.exists(self.exp_name):
134
+ os.makedirs(self.exp_name)
135
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
136
+ if model is None:
137
+ model = PG(**self.cfg.policy.model)
138
+ self.policy = PGPolicy(self.cfg.policy, model=model)
139
+ if policy_state_dict is not None:
140
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
141
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
142
+
143
+ def train(
144
+ self,
145
+ step: int = int(1e7),
146
+ collector_env_num: int = None,
147
+ evaluator_env_num: int = None,
148
+ n_iter_save_ckpt: int = 1000,
149
+ context: Optional[str] = None,
150
+ debug: bool = False,
151
+ wandb_sweep: bool = False,
152
+ ) -> TrainingReturn:
153
+ """
154
+ Overview:
155
+ Train the agent with PG algorithm for ``step`` environment steps with ``collector_env_num`` collector \
156
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
157
+ recorded and saved by wandb.
158
+ Arguments:
159
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
160
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
161
+ If not specified, it will be set according to the configuration.
162
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
163
+ If not specified, it will be set according to the configuration.
164
+ - n_iter_save_ckpt (:obj:`int`): The frequency, in training iterations, at which checkpoints are saved. \
165
+ Default to 1000.
166
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
167
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
168
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
169
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
170
+ subprocess environment manager will be used.
171
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
172
+ which is a hyper-parameter optimization process for seeking the best configurations. \
173
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
174
+ Returns:
175
+ - (:obj:`TrainingReturn`): The training result, of which the attributes are:
176
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
177
+ """
178
+
179
+ if debug:
180
+ logging.getLogger().setLevel(logging.DEBUG)
181
+ logging.debug(self.policy._model)
182
+ # define env and policy
183
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
184
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
185
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
186
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
187
+
188
+ with task.start(ctx=OnlineRLContext()):
189
+ task.use(
190
+ interaction_evaluator(
191
+ self.cfg,
192
+ self.policy.eval_mode,
193
+ evaluator_env,
194
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
195
+ )
196
+ )
197
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
198
+ task.use(EpisodeCollector(self.cfg, self.policy.collect_mode, collector_env))
199
+ task.use(montecarlo_return_estimator(self.policy))
200
+ task.use(trainer(self.cfg, self.policy.learn_mode))
201
+ task.use(
202
+ wandb_online_logger(
203
+ metric_list=self.policy._monitor_vars_learn(),
204
+ model=self.policy._model,
205
+ anonymous=True,
206
+ project_name=self.exp_name,
207
+ wandb_sweep=wandb_sweep,
208
+ )
209
+ )
210
+ task.use(termination_checker(max_env_step=step))
211
+ task.use(final_ctx_saver(name=self.exp_name))
212
+ task.run()
213
+
214
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
215
+
216
+ def deploy(
217
+ self,
218
+ enable_save_replay: bool = False,
219
+ concatenate_all_replay: bool = False,
220
+ replay_save_path: str = None,
221
+ seed: Optional[Union[int, List]] = None,
222
+ debug: bool = False
223
+ ) -> EvalReturn:
224
+ """
225
+ Overview:
226
+ Deploy the agent with PG algorithm by interacting with the environment, during which the replay video \
227
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
228
+ Arguments:
229
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
230
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
231
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
232
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
233
+ the replay video of each episode will be saved separately.
234
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
235
+ If not specified, the video will be saved in ``exp_name/videos``.
236
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
237
+ Default to None. If not specified, ``self.seed`` will be used. \
238
+ If ``seed`` is an integer, the agent will be deployed once. \
239
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
240
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
241
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
242
+ subprocess environment manager will be used.
243
+ Returns:
244
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
245
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
246
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
247
+ """
248
+
249
+ if debug:
250
+ logging.getLogger().setLevel(logging.DEBUG)
251
+ # define env and policy
252
+ env = self.env.clone(caller='evaluator')
253
+
254
+ if seed is not None and isinstance(seed, int):
255
+ seeds = [seed]
256
+ elif seed is not None and isinstance(seed, list):
257
+ seeds = seed
258
+ else:
259
+ seeds = [self.seed]
260
+
261
+ returns = []
262
+ images = []
263
+ if enable_save_replay:
264
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
265
+ env.enable_save_replay(replay_path=replay_save_path)
266
+ else:
267
+ logging.warning('No video would be generated during the deploy.')
268
+ if concatenate_all_replay:
269
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
270
+ concatenate_all_replay = False
271
+
272
+ def single_env_forward_wrapper(forward_fn, cuda=True):
273
+
274
+ def _forward(obs):
275
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
276
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
277
+ if cuda and torch.cuda.is_available():
278
+ obs = obs.cuda()
279
+ output = forward_fn(obs)
280
+ if self.policy._cfg.deterministic_eval:
281
+ if self.policy._cfg.action_space == 'discrete':
282
+ output['action'] = output['logit'].argmax(dim=-1)
283
+ elif self.policy._cfg.action_space == 'continuous':
284
+ output['action'] = output['logit']['mu']
285
+ else:
286
+ raise KeyError("invalid action_space: {}".format(self.policy._cfg.action_space))
287
+ else:
288
+ output['action'] = output['dist'].sample()
289
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
290
+ action = output['action'].squeeze(0).detach().cpu().numpy()
291
+ return action
292
+
293
+ return _forward
294
+
295
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
296
+
297
+ # reset first to make sure the env is in the initial state
298
+ # env will be reset again in the main loop
299
+ env.reset()
300
+
301
+ for seed in seeds:
302
+ env.seed(seed, dynamic_seed=False)
303
+ return_ = 0.
304
+ step = 0
305
+ obs = env.reset()
306
+ images.append(render(env)[None]) if concatenate_all_replay else None
307
+ while True:
308
+ action = forward_fn(obs)
309
+ obs, rew, done, info = env.step(action)
310
+ images.append(render(env)[None]) if concatenate_all_replay else None
311
+ return_ += rew
312
+ step += 1
313
+ if done:
314
+ break
315
+ logging.info(f'PG deploy is finished, final episode return with {step} steps is: {return_}')
316
+ returns.append(return_)
317
+
318
+ env.close()
319
+
320
+ if concatenate_all_replay:
321
+ images = np.concatenate(images, axis=0)
322
+ import imageio
323
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
324
+
325
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
326
+
327
+ def collect_data(
328
+ self,
329
+ env_num: int = 8,
330
+ save_data_path: Optional[str] = None,
331
+ n_sample: Optional[int] = None,
332
+ n_episode: Optional[int] = None,
333
+ context: Optional[str] = None,
334
+ debug: bool = False
335
+ ) -> None:
336
+ """
337
+ Overview:
338
+ Collect ``n_sample`` samples with the PG algorithm using ``env_num`` collector environments. \
339
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
340
+ ``exp_name/demo_data``.
341
+ Arguments:
342
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
343
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
344
+ If not specified, the data will be saved in ``exp_name/demo_data``.
345
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
346
+ If not specified, ``n_episode`` must be specified.
347
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
348
+ Episode-based collection is not implemented yet; ``n_sample`` must be specified.
349
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
350
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
351
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
352
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
353
+ subprocess environment manager will be used.
354
+ """
355
+
356
+ if debug:
357
+ logging.getLogger().setLevel(logging.DEBUG)
358
+ if n_episode is not None:
359
+ raise NotImplementedError
360
+ # define env and policy
361
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
362
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
363
+
364
+ if save_data_path is None:
365
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
366
+
367
+ # main execution task
368
+ with task.start(ctx=OnlineRLContext()):
369
+ task.use(
370
+ StepCollector(
371
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
372
+ )
373
+ )
374
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
375
+ task.run(max_step=1)
376
+ logging.info(
377
+ f'PG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
378
+ )
379
+
380
+ def batch_evaluate(
381
+ self,
382
+ env_num: int = 4,
383
+ n_evaluator_episode: int = 4,
384
+ context: Optional[str] = None,
385
+ debug: bool = False
386
+ ) -> EvalReturn:
387
+ """
388
+ Overview:
389
+ Evaluate the agent with PG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
390
+ environments. The evaluation result will be returned.
391
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
392
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
393
+ will only create one evaluator environment to evaluate the agent and save the replay video.
394
+ Arguments:
395
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
396
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
397
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
398
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
399
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
400
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
401
+ subprocess environment manager will be used.
402
+ Returns:
403
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
404
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
405
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
406
+ """
407
+
408
+ if debug:
409
+ logging.getLogger().setLevel(logging.DEBUG)
410
+ # define env and policy
411
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
412
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
413
+
414
+ # reset first to make sure the env is in the initial state
415
+ # env will be reset again in the main loop
416
+ env.launch()
417
+ env.reset()
418
+
419
+ evaluate_cfg = self.cfg
420
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
421
+
422
+ # main execution task
423
+ with task.start(ctx=OnlineRLContext()):
424
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
425
+ task.run(max_step=1)
426
+
427
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
428
+
429
+ @property
430
+ def best(self) -> 'PGAgent':
431
+ """
432
+ Overview:
433
+ Load the best model from the checkpoint directory, \
434
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
435
+ The return value is the agent with the best model.
436
+ Returns:
437
+ - (:obj:`PGAgent`): The agent with the best model.
438
+ Examples:
439
+ >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
440
+ >>> agent.train()
441
+ >>> agent = agent.best
442
+
443
+ .. note::
444
+ The best model is the model with the highest evaluation return. If this method is called, the current \
445
+ model will be replaced by the best model.
446
+ """
447
+
448
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
449
+ # Load best model if it exists
450
+ if os.path.exists(best_model_file_path):
451
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
452
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
453
+ return self
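Unlike the replay-buffer pipelines in the other agents, the PG training task above wires whole-episode collection straight into a Monte Carlo return estimate and an on-policy trainer. The middleware chain in isolation, assuming ``cfg``, ``policy`` and ``collector_env`` are already constructed exactly as in ``PGAgent.train()``:

with task.start(ctx=OnlineRLContext()):
    task.use(EpisodeCollector(cfg, policy.collect_mode, collector_env))  # roll out complete episodes
    task.use(montecarlo_return_estimator(policy))                        # label each step with its return-to-go
    task.use(trainer(cfg, policy.learn_mode))                            # one on-policy update per cycle
    task.run(max_step=1)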
DI-engine/ding/bonus/ppo_offpolicy.py ADDED
@@ -0,0 +1,471 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import treetensor.torch as ttorch
8
+ from ding.framework import task, OnlineRLContext
9
+ from ding.framework.middleware import CkptSaver, final_ctx_saver, OffPolicyLearner, StepCollector, \
10
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, gae_estimator
11
+ from ding.envs import BaseEnv
12
+ from ding.envs import setup_ding_env_manager
13
+ from ding.policy import PPOOffPolicy
14
+ from ding.utils import set_pkg_seed
15
+ from ding.utils import get_env_fps, render
16
+ from ding.config import save_config_py, compile_config
17
+ from ding.model import VAC
18
+ from ding.model import model_wrap
19
+ from ding.data import DequeBuffer
20
+ from ding.bonus.common import TrainingReturn, EvalReturn
21
+ from ding.config.example.PPOOffPolicy import supported_env_cfg
22
+ from ding.config.example.PPOOffPolicy import supported_env
23
+
24
+
25
+ class PPOOffPolicyAgent:
26
+ """
27
+ Overview:
28
+ Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
29
+ Proximal Policy Optimization (PPO) in an off-policy style.
30
+ For more information about the system design of RL agent, please refer to \
31
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
32
+ Interface:
33
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
34
+ """
35
+ supported_env_list = list(supported_env_cfg.keys())
36
+ """
37
+ Overview:
38
+ List of supported envs.
39
+ Examples:
40
+ >>> from ding.bonus.ppo_offpolicy import PPOOffPolicyAgent
41
+ >>> print(PPOOffPolicyAgent.supported_env_list)
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ env_id: str = None,
47
+ env: BaseEnv = None,
48
+ seed: int = 0,
49
+ exp_name: str = None,
50
+ model: Optional[torch.nn.Module] = None,
51
+ cfg: Optional[Union[EasyDict, dict]] = None,
52
+ policy_state_dict: str = None,
53
+ ) -> None:
54
+ """
55
+ Overview:
56
+ Initialize agent for PPO (offpolicy) algorithm.
57
+ Arguments:
58
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
59
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
60
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
61
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
62
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
63
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
64
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
65
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
66
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
67
+ Default to 0.
68
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
69
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
70
+ - model (:obj:`torch.nn.Module`): The model of PPO (offpolicy) algorithm, \
71
+ which should be an instance of class :class:`ding.model.VAC`. \
72
+ If not specified, a default model will be generated according to the configuration.
73
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO (offpolicy) algorithm, which is a dict. \
74
+ Default to None. If not specified, the default configuration will be used. \
75
+ The default configuration can be found in ``ding/config/example/PPOOffPolicy/gym_lunarlander_v2.py``.
76
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
77
+ If specified, the policy will be loaded from this file. Default to None.
78
+
79
+ .. note::
80
+ An RL Agent Instance can be initialized in two basic ways. \
81
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
82
+ and we want to train an agent with PPO (offpolicy) algorithm with default configuration. \
83
+ Then we can initialize the agent in the following ways:
84
+ >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
85
+ or, if we want to specify the env_id in the configuration:
86
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
87
+ >>> agent = PPOOffPolicyAgent(cfg=cfg)
88
+ There are also other arguments to specify the agent when initializing.
89
+ For example, if we want to specify the environment instance:
90
+ >>> env = CustomizedEnv('LunarLander-v2')
91
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, env=env)
92
+ or, if we want to specify the model:
93
+ >>> model = VAC(**cfg.policy.model)
94
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, model=model)
95
+ or, if we want to reload the policy from a saved policy state dict:
96
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
97
+ Make sure that the configuration is consistent with the saved policy state dict.
98
+ """
99
+
100
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
101
+
102
+ if cfg is not None and not isinstance(cfg, EasyDict):
103
+ cfg = EasyDict(cfg)
104
+
105
+ if env_id is not None:
106
+ assert env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
107
+ PPOOffPolicyAgent.supported_env_list
108
+ )
109
+ if cfg is None:
110
+ cfg = supported_env_cfg[env_id]
111
+ else:
112
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
113
+ else:
114
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
115
+ assert cfg.env.env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
116
+ PPOOffPolicyAgent.supported_env_list
117
+ )
118
+ default_policy_config = EasyDict({"policy": PPOOffPolicy.default_config()})
119
+ default_policy_config.update(cfg)
120
+ cfg = default_policy_config
121
+
122
+ if exp_name is not None:
123
+ cfg.exp_name = exp_name
124
+ self.cfg = compile_config(cfg, policy=PPOOffPolicy)
125
+ self.exp_name = self.cfg.exp_name
126
+ if env is None:
127
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
128
+ else:
129
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
130
+ self.env = env
131
+
132
+ logging.getLogger().setLevel(logging.INFO)
133
+ self.seed = seed
134
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
135
+ if not os.path.exists(self.exp_name):
136
+ os.makedirs(self.exp_name)
137
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
138
+ if model is None:
139
+ model = VAC(**self.cfg.policy.model)
140
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
141
+ self.policy = PPOOffPolicy(self.cfg.policy, model=model)
142
+ if policy_state_dict is not None:
143
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
144
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
145
+
146
+ def train(
147
+ self,
148
+ step: int = int(1e7),
149
+ collector_env_num: int = None,
150
+ evaluator_env_num: int = None,
151
+ n_iter_save_ckpt: int = 1000,
152
+ context: Optional[str] = None,
153
+ debug: bool = False,
154
+ wandb_sweep: bool = False,
155
+ ) -> TrainingReturn:
156
+ """
157
+ Overview:
158
+ Train the agent with PPO (offpolicy) algorithm for ``step`` iterations with ``collector_env_num`` \
159
+ collector environments and ``evaluator_env_num`` evaluator environments. \
160
+ Information during training will be recorded and saved by wandb.
161
+ Arguments:
162
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
163
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
164
+ If not specified, it will be set according to the configuration.
165
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
166
+ If not specified, it will be set according to the configuration.
167
+ - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
168
+ Default to 1000.
169
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
170
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
171
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
172
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
173
+ subprocess environment manager will be used.
174
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
175
+ which is a hyper-parameter optimization process for seeking the best configurations. \
176
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
177
+ Returns:
178
+ - (:obj:`TrainingReturn`): The training result, of which the attributes are:
179
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
180
+ """
181
+
182
+ if debug:
183
+ logging.getLogger().setLevel(logging.DEBUG)
184
+ logging.debug(self.policy._model)
185
+ # define env and policy
186
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
187
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
188
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
189
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
190
+
191
+ with task.start(ctx=OnlineRLContext()):
192
+ task.use(
193
+ interaction_evaluator(
194
+ self.cfg,
195
+ self.policy.eval_mode,
196
+ evaluator_env,
197
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
198
+ )
199
+ )
200
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
201
+ task.use(
202
+ StepCollector(
203
+ self.cfg,
204
+ self.policy.collect_mode,
205
+ collector_env,
206
+ random_collect_size=self.cfg.policy.random_collect_size
207
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
208
+ )
209
+ )
210
+ task.use(gae_estimator(self.cfg, self.policy.collect_mode, self.buffer_))
211
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
212
+ task.use(
213
+ wandb_online_logger(
214
+ cfg=self.cfg.wandb_logger,
215
+ exp_config=self.cfg,
216
+ metric_list=self.policy._monitor_vars_learn(),
217
+ model=self.policy._model,
218
+ anonymous=True,
219
+ project_name=self.exp_name,
220
+ wandb_sweep=wandb_sweep,
221
+ )
222
+ )
223
+ task.use(termination_checker(max_env_step=step))
224
+ task.use(final_ctx_saver(name=self.exp_name))
225
+ task.run()
226
+
227
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
228
+
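+ # Usage sketch (illustrative only): train with explicit env counts; the returned
+ # TrainingReturn carries the wandb run url.
+ #     agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+ #     result = agent.train(step=int(1e6), collector_env_num=4, evaluator_env_num=4)
+ #     print(result.wandb_url)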
229
+ def deploy(
230
+ self,
231
+ enable_save_replay: bool = False,
232
+ concatenate_all_replay: bool = False,
233
+ replay_save_path: str = None,
234
+ seed: Optional[Union[int, List]] = None,
235
+ debug: bool = False
236
+ ) -> EvalReturn:
237
+ """
238
+ Overview:
239
+ Deploy the agent with PPO (offpolicy) algorithm by interacting with the environment, \
240
+ during which the replay video can be saved if ``enable_save_replay`` is True. \
241
+ The evaluation result will be returned.
242
+ Arguments:
243
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
244
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
245
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
246
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
247
+ the replay video of each episode will be saved separately.
248
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
249
+ If not specified, the video will be saved in ``exp_name/videos``.
250
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
251
+ Default to None. If not specified, ``self.seed`` will be used. \
252
+ If ``seed`` is an integer, the agent will be deployed once. \
253
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
254
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
255
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
256
+ subprocess environment manager will be used.
257
+ Returns:
258
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
259
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
260
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
261
+ """
262
+
263
+ if debug:
264
+ logging.getLogger().setLevel(logging.DEBUG)
265
+ # define env and policy
266
+ env = self.env.clone(caller='evaluator')
267
+
268
+ if seed is not None and isinstance(seed, int):
269
+ seeds = [seed]
270
+ elif seed is not None and isinstance(seed, list):
271
+ seeds = seed
272
+ else:
273
+ seeds = [self.seed]
274
+
275
+ returns = []
276
+ images = []
277
+ if enable_save_replay:
278
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
279
+ env.enable_save_replay(replay_path=replay_save_path)
280
+ else:
281
+ logging.warning('No video would be generated during the deploy.')
282
+ if concatenate_all_replay:
283
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
284
+ concatenate_all_replay = False
285
+
286
+ def single_env_forward_wrapper(forward_fn, cuda=True):
287
+
288
+ if self.cfg.policy.action_space == 'discrete':
289
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
290
+ elif self.cfg.policy.action_space == 'continuous':
291
+ forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
292
+ elif self.cfg.policy.action_space == 'hybrid':
293
+ forward_fn = model_wrap(forward_fn, wrapper_name='hybrid_deterministic_argmax_sample').forward
294
+ elif self.cfg.policy.action_space == 'general':
295
+ forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
296
+ else:
297
+ raise NotImplementedError
298
+
299
+ def _forward(obs):
300
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
301
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
302
+ if cuda and torch.cuda.is_available():
303
+ obs = obs.cuda()
304
+ action = forward_fn(obs, mode='compute_actor')["action"]
305
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
306
+ action = action.squeeze(0).detach().cpu().numpy()
307
+ return action
308
+
309
+ return _forward
310
+
311
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
312
+
313
+ # reset first to make sure the env is in the initial state
314
+ # env will be reset again in the main loop
315
+ env.reset()
316
+
317
+ for seed in seeds:
318
+ env.seed(seed, dynamic_seed=False)
319
+ return_ = 0.
320
+ step = 0
321
+ obs = env.reset()
322
+ if concatenate_all_replay:
+     images.append(render(env)[None])
323
+ while True:
324
+ action = forward_fn(obs)
325
+ obs, rew, done, info = env.step(action)
326
+ if concatenate_all_replay:
+     images.append(render(env)[None])
327
+ return_ += rew
328
+ step += 1
329
+ if done:
330
+ break
331
+ logging.info(f'PPO (offpolicy) deploy is finished, final episode return with {step} steps is: {return_}')
332
+ returns.append(return_)
333
+
334
+ env.close()
335
+
336
+ if concatenate_all_replay:
337
+ images = np.concatenate(images, axis=0)
338
+ import imageio
339
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
340
+
341
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
342
+
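+ # Usage sketch (illustrative only): passing a list of seeds deploys one episode per
+ # seed, and the returned EvalReturn aggregates the episode returns.
+ #     result = agent.deploy(seed=[0, 1, 2], enable_save_replay=True)
+ #     print(result.eval_value, result.eval_value_std)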
343
+ def collect_data(
344
+ self,
345
+ env_num: int = 8,
346
+ save_data_path: Optional[str] = None,
347
+ n_sample: Optional[int] = None,
348
+ n_episode: Optional[int] = None,
349
+ context: Optional[str] = None,
350
+ debug: bool = False
351
+ ) -> None:
352
+ """
353
+ Overview:
354
+ Collect data with PPO (offpolicy) algorithm for ``n_episode`` episodes \
355
+ with ``env_num`` collector environments. \
356
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
357
+ ``exp_name/demo_data``.
358
+ Arguments:
359
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
360
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
361
+ If not specified, the data will be saved in ``exp_name/demo_data``.
362
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
363
+ If not specified, ``n_episode`` must be specified.
364
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
365
+ If not specified, ``n_sample`` must be specified.
366
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
367
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
368
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
369
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
370
+ subprocess environment manager will be used.
371
+ """
372
+
373
+ if debug:
374
+ logging.getLogger().setLevel(logging.DEBUG)
375
+ if n_episode is not None:
376
+ raise NotImplementedError
377
+ # define env and policy
378
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
379
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
380
+
381
+ if save_data_path is None:
382
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
383
+
384
+ # main execution task
385
+ with task.start(ctx=OnlineRLContext()):
386
+ task.use(
387
+ StepCollector(
388
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
389
+ )
390
+ )
391
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
392
+ task.run(max_step=1)
393
+ logging.info(
394
+ f'PPOOffPolicy collecting is finished, more than {n_sample} '
395
+ f'samples are collected and saved in `{save_data_path}`'
396
+ )
397
+
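+ # Usage sketch (illustrative only): collect 10000 samples into exp_name/demo_data as
+ # an HDF5 file; n_episode is not implemented yet, so only n_sample can be used.
+ #     agent.collect_data(env_num=8, n_sample=10000)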
398
+ def batch_evaluate(
399
+ self,
400
+ env_num: int = 4,
401
+ n_evaluator_episode: int = 4,
402
+ context: Optional[str] = None,
403
+ debug: bool = False
404
+ ) -> EvalReturn:
405
+ """
406
+ Overview:
407
+ Evaluate the agent with PPO (offpolicy) algorithm for ``n_evaluator_episode`` episodes \
408
+ with ``env_num`` evaluator environments. The evaluation result will be returned.
409
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
410
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
411
+ will only create one evaluator environment to evaluate the agent and save the replay video.
412
+ Arguments:
413
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
414
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
415
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
416
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
417
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
418
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
419
+ subprocess environment manager will be used.
420
+ Returns:
421
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
422
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
423
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
424
+ """
425
+
426
+ if debug:
427
+ logging.getLogger().setLevel(logging.DEBUG)
428
+ # define env and policy
429
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
430
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
431
+
432
+ # reset first to make sure the env is in the initial state
433
+ # env will be reset again in the main loop
434
+ env.launch()
435
+ env.reset()
436
+
437
+ # deep copy so that setting n_evaluator_episode does not permanently mutate self.cfg
+ evaluate_cfg = copy.deepcopy(self.cfg)
438
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
439
+
440
+ # main execution task
441
+ with task.start(ctx=OnlineRLContext()):
442
+ task.use(interaction_evaluator(evaluate_cfg, self.policy.eval_mode, env))
443
+ task.run(max_step=1)
444
+
445
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
446
+
447
+ @property
448
+ def best(self) -> 'PPOOffPolicyAgent':
449
+ """
450
+ Overview:
451
+ Load the best model from the checkpoint directory, \
452
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
453
+ The return value is the agent with the best model.
454
+ Returns:
455
+ - (:obj:`PPOOffPolicyAgent`): The agent with the best model.
456
+ Examples:
457
+ >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
458
+ >>> agent.train()
459
+ >>> agent = agent.best
460
+
461
+ .. note::
462
+ The best model is the model with the highest evaluation return. When this property is accessed, the current \
463
+ model will be replaced by the best model.
464
+ """
465
+
466
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
467
+ # Load best model if it exists
468
+ if os.path.exists(best_model_file_path):
469
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
470
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
471
+ return self
DI-engine/ding/bonus/ppof.py ADDED
@@ -0,0 +1,509 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ from functools import partial
5
+ import os
6
+ import gym
7
+ import gymnasium
8
+ import numpy as np
9
+ import torch
10
+ from ding.framework import task, OnlineRLContext
11
+ from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
12
+ wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
13
+ from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
14
+ from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
15
+ from ding.utils import set_pkg_seed
16
+ from ding.utils import get_env_fps, render
17
+ from ding.config import save_config_py
18
+ from .model import PPOFModel
19
+ from .config import get_instance_config, get_instance_env, get_hybrid_shape
20
+ from ding.bonus.common import TrainingReturn, EvalReturn
21
+
22
+
23
+ class PPOF:
24
+ """
25
+ Overview:
26
+ Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
27
+ Proximal Policy Optimization (PPO).
28
+ For more information about the system design of RL agent, please refer to \
29
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
30
+ Interface:
31
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
32
+ """
33
+
34
+ supported_env_list = [
35
+ # common
36
+ 'LunarLander-v2',
37
+ 'LunarLanderContinuous-v2',
38
+ 'BipedalWalker-v3',
39
+ 'Pendulum-v1',
40
+ 'acrobot',
41
+ # ch2: action
42
+ 'rocket_landing',
43
+ 'drone_fly',
44
+ 'hybrid_moving',
45
+ # ch3: obs
46
+ 'evogym_carrier',
47
+ 'mario',
48
+ 'di_sheep',
49
+ 'procgen_bigfish',
50
+ # ch4: reward
51
+ 'minigrid_fourroom',
52
+ 'metadrive',
53
+ # atari
54
+ 'BowlingNoFrameskip-v4',
55
+ 'BreakoutNoFrameskip-v4',
56
+ 'GopherNoFrameskip-v4',
57
+ 'KangarooNoFrameskip-v4',
58
+ 'PongNoFrameskip-v4',
59
+ 'QbertNoFrameskip-v4',
60
+ 'SpaceInvadersNoFrameskip-v4',
61
+ # mujoco
62
+ 'Hopper-v3',
63
+ 'HalfCheetah-v3',
64
+ 'Walker2d-v3',
65
+ ]
66
+ """
67
+ Overview:
68
+ List of supported envs.
69
+ Examples:
70
+ >>> from ding.bonus.ppof import PPOF
71
+ >>> print(PPOF.supported_env_list)
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ env_id: str = None,
77
+ env: BaseEnv = None,
78
+ seed: int = 0,
79
+ exp_name: str = None,
80
+ model: Optional[torch.nn.Module] = None,
81
+ cfg: Optional[Union[EasyDict, dict]] = None,
82
+ policy_state_dict: str = None
83
+ ) -> None:
84
+ """
85
+ Overview:
86
+ Initialize agent for PPO algorithm.
87
+ Arguments:
88
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
89
+ If ``env_id`` is not specified, ``env_id`` in ``cfg`` must be specified. \
90
+ If ``env_id`` is specified, ``env_id`` in ``cfg`` will be ignored. \
91
+ ``env_id`` should be one of the supported envs, which can be found in ``PPOF.supported_env_list``.
92
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
93
+ If ``env`` is not specified, ``env_id`` or ``cfg.env_id`` must be specified. \
94
+ ``env_id`` or ``cfg.env_id`` will be used to create environment instance. \
95
+ If ``env`` is specified, ``env_id`` and ``cfg.env_id`` will be ignored.
96
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
97
+ Default to 0.
98
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
99
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
100
+ - model (:obj:`torch.nn.Module`): The model of PPO algorithm, which should be an instance of class \
101
+ ``ding.bonus.model.PPOFModel``. \
102
+ If not specified, a default model will be generated according to the configuration.
103
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO algorithm, which is a dict. \
104
+ Default to None. If not specified, the default configuration will be used.
105
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
106
+ If specified, the policy will be loaded from this file. Default to None.
107
+
108
+ .. note::
109
+ An RL Agent Instance can be initialized in two basic ways. \
110
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
111
+ and we want to train an agent with PPO algorithm with default configuration. \
112
+ Then we can initialize the agent in the following ways:
113
+ >>> agent = PPOF(env_id='LunarLander-v2')
114
+ or, if we want to specify the env_id in the configuration:
115
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
116
+ >>> agent = PPOF(cfg=cfg)
117
+ There are also other arguments to specify the agent when initializing.
118
+ For example, if we want to specify the environment instance:
119
+ >>> env = CustomizedEnv('LunarLander-v2')
120
+ >>> agent = PPOF(cfg=cfg, env=env)
121
+ or, if we want to specify the model:
122
+ >>> model = PPOFModel(obs_shape, action_shape, **cfg.model)
123
+ >>> agent = PPOF(cfg=cfg, model=model)
124
+ or, if we want to reload the policy from a saved policy state dict:
125
+ >>> agent = PPOF(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
126
+ Make sure that the configuration is consistent with the saved policy state dict.
127
+ """
128
+
129
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
130
+
131
+ if cfg is not None and not isinstance(cfg, EasyDict):
132
+ cfg = EasyDict(cfg)
133
+
134
+ if env_id is not None:
135
+ assert env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(PPOF.supported_env_list)
136
+ if cfg is None:
137
+ cfg = get_instance_config(env_id, algorithm="PPOF")
138
+
139
+ if not hasattr(cfg, "env_id"):
140
+ cfg.env_id = env_id
141
+ assert cfg.env_id == env_id, "env_id in cfg should be the same as env_id in args."
142
+ else:
143
+ assert hasattr(cfg, "env_id"), "Please specify env_id in cfg."
144
+ assert cfg.env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(
145
+ PPOF.supported_env_list
146
+ )
147
+
148
+ if exp_name is not None:
149
+ cfg.exp_name = exp_name
150
+ elif not hasattr(cfg, "exp_name"):
151
+ cfg.exp_name = "{}-{}".format(cfg.env_id, "PPO")
152
+ self.cfg = cfg
153
+ self.exp_name = self.cfg.exp_name
154
+
155
+ if env is None:
156
+ self.env = get_instance_env(self.cfg.env_id)
157
+ else:
158
+ self.env = env
159
+
160
+ logging.getLogger().setLevel(logging.INFO)
161
+ self.seed = seed
162
+ set_pkg_seed(self.seed, use_cuda=self.cfg.cuda)
163
+
164
+ if not os.path.exists(self.exp_name):
165
+ os.makedirs(self.exp_name)
166
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
167
+
168
+ action_space = self.env.action_space
169
+ if isinstance(action_space, (gym.spaces.Discrete, gymnasium.spaces.Discrete)):
170
+ action_shape = int(action_space.n)
171
+ elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
172
+ action_shape = get_hybrid_shape(action_space)
173
+ else:
174
+ action_shape = action_space.shape
175
+
176
+ # Four types of value normalization are supported currently
177
+ assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline']
178
+ if model is None:
179
+ if self.cfg.value_norm != 'popart':
180
+ model = PPOFModel(
181
+ self.env.observation_space.shape,
182
+ action_shape,
183
+ action_space=self.cfg.action_space,
184
+ **self.cfg.model
185
+ )
186
+ else:
187
+ model = PPOFModel(
188
+ self.env.observation_space.shape,
189
+ action_shape,
190
+ action_space=self.cfg.action_space,
191
+ popart_head=True,
192
+ **self.cfg.model
193
+ )
194
+ self.policy = PPOFPolicy(self.cfg, model=model)
195
+ if policy_state_dict is not None:
196
+ self.policy.load_state_dict(policy_state_dict)
197
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
198
+
199
+ def train(
200
+ self,
201
+ step: int = int(1e7),
202
+ collector_env_num: int = 4,
203
+ evaluator_env_num: int = 4,
204
+ n_iter_log_show: int = 500,
205
+ n_iter_save_ckpt: int = 1000,
206
+ context: Optional[str] = None,
207
+ reward_model: Optional[str] = None,
208
+ debug: bool = False,
209
+ wandb_sweep: bool = False,
210
+ ) -> TrainingReturn:
211
+ """
212
+ Overview:
213
+ Train the agent with PPO algorithm for ``step`` iterations with ``collector_env_num`` collector \
214
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
215
+ recorded and saved by wandb.
216
+ Arguments:
217
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
218
+ - collector_env_num (:obj:`int`): The number of collector environments. Default to 4.
219
+ - evaluator_env_num (:obj:`int`): The number of evaluator environments. Default to 4.
220
+ - n_iter_log_show (:obj:`int`): The frequency of logging every training iteration. Default to 500.
221
+ - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
222
+ Default to 1000.
223
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
224
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
225
+ - reward_model (:obj:`str`): The reward model name. Default to None. This argument is not supported yet.
226
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
227
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
228
+ subprocess environment manager will be used.
229
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
230
+ which is a hyper-parameter optimization process for seeking the best configurations. \
231
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
232
+ Returns:
233
+ - (:obj:`TrainingReturn`): The training result, of which the attributes are:
234
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
235
+ """
236
+
237
+ if debug:
238
+ logging.getLogger().setLevel(logging.DEBUG)
239
+ logging.debug(self.policy._model)
240
+ # define env and policy
241
+ collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
242
+ evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
243
+
244
+ if reward_model is not None:
245
+ # self.reward_model = create_reward_model(reward_model, self.cfg.reward_model)
246
+ pass
247
+
248
+ with task.start(ctx=OnlineRLContext()):
249
+ task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
250
+ task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
251
+ task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
252
+ task.use(ppof_adv_estimator(self.policy))
253
+ task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
254
+ task.use(
255
+ wandb_online_logger(
256
+ metric_list=self.policy.monitor_vars(),
257
+ model=self.policy._model,
258
+ anonymous=True,
259
+ project_name=self.exp_name,
260
+ wandb_sweep=wandb_sweep,
261
+ )
262
+ )
263
+ task.use(termination_checker(max_env_step=step))
264
+ task.run()
265
+
266
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
267
+
268
+ def deploy(
269
+ self,
270
+ enable_save_replay: bool = False,
271
+ concatenate_all_replay: bool = False,
272
+ replay_save_path: str = None,
273
+ seed: Optional[Union[int, List]] = None,
274
+ debug: bool = False
275
+ ) -> EvalReturn:
276
+ """
277
+ Overview:
278
+ Deploy the agent with PPO algorithm by interacting with the environment, during which the replay video \
279
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
280
+ Arguments:
281
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
282
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
283
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
284
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
285
+ the replay video of each episode will be saved separately.
286
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
287
+ If not specified, the video will be saved in ``exp_name/videos``.
288
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
289
+ Default to None. If not specified, ``self.seed`` will be used. \
290
+ If ``seed`` is an integer, the agent will be deployed once. \
291
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
292
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
293
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
294
+ subprocess environment manager will be used.
295
+ Returns:
296
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
297
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
298
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
299
+ """
300
+
301
+ if debug:
302
+ logging.getLogger().setLevel(logging.DEBUG)
303
+ # define env and policy
304
+ env = self.env.clone(caller='evaluator')
305
+
306
+ if seed is not None and isinstance(seed, int):
307
+ seeds = [seed]
308
+ elif seed is not None and isinstance(seed, list):
309
+ seeds = seed
310
+ else:
311
+ seeds = [self.seed]
312
+
313
+ returns = []
314
+ images = []
315
+ if enable_save_replay:
316
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
317
+ env.enable_save_replay(replay_path=replay_save_path)
318
+ else:
319
+ logging.warning('No video would be generated during the deploy.')
320
+ if concatenate_all_replay:
321
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
322
+ concatenate_all_replay = False
323
+
324
+ forward_fn = single_env_forward_wrapper_ttorch(self.policy.eval, self.cfg.cuda)
325
+
326
+ # reset first to make sure the env is in the initial state
327
+ # env will be reset again in the main loop
328
+ env.reset()
329
+
330
+ for seed in seeds:
331
+ env.seed(seed, dynamic_seed=False)
332
+ return_ = 0.
333
+ step = 0
334
+ obs = env.reset()
335
+ if concatenate_all_replay:
+     images.append(render(env)[None])
336
+ while True:
337
+ action = forward_fn(obs)
338
+ obs, rew, done, info = env.step(action)
339
+ if concatenate_all_replay:
+     images.append(render(env)[None])
340
+ return_ += rew
341
+ step += 1
342
+ if done:
343
+ break
344
+ logging.info(f'PPO deploy is finished, final episode return with {step} steps is: {return_}')
345
+ returns.append(return_)
346
+
347
+ env.close()
348
+
349
+ if concatenate_all_replay:
350
+ images = np.concatenate(images, axis=0)
351
+ import imageio
352
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
353
+
354
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
355
+
356
+ def collect_data(
357
+ self,
358
+ env_num: int = 8,
359
+ save_data_path: Optional[str] = None,
360
+ n_sample: Optional[int] = None,
361
+ n_episode: Optional[int] = None,
362
+ context: Optional[str] = None,
363
+ debug: bool = False
364
+ ) -> None:
365
+ """
366
+ Overview:
367
+ Collect data with PPO algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
368
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
369
+ ``exp_name/demo_data``.
370
+ Arguments:
371
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
372
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
373
+ If not specified, the data will be saved in ``exp_name/demo_data``.
374
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
375
+ If not specified, ``n_episode`` must be specified.
376
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
377
+ If not specified, ``n_sample`` must be specified.
378
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
379
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
380
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
381
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
382
+ subprocess environment manager will be used.
383
+ """
384
+
385
+ if debug:
386
+ logging.getLogger().setLevel(logging.DEBUG)
387
+ if n_episode is not None:
388
+ raise NotImplementedError
389
+ # define env and policy
390
+ env = self._setup_env_manager(env_num, context, debug, 'collector')
391
+ if save_data_path is None:
392
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
393
+
394
+ # main execution task
395
+ with task.start(ctx=OnlineRLContext()):
396
+ task.use(PPOFStepCollector(self.seed, self.policy, env, n_sample))
397
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
398
+ task.run(max_step=1)
399
+ logging.info(
400
+ f'PPOF collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
401
+ )
402
+
403
+ def batch_evaluate(
404
+ self,
405
+ env_num: int = 4,
406
+ n_evaluator_episode: int = 4,
407
+ context: Optional[str] = None,
408
+ debug: bool = False,
409
+ ) -> EvalReturn:
410
+ """
411
+ Overview:
412
+ Evaluate the agent with PPO algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
413
+ environments. The evaluation result will be returned.
414
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
415
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
416
+ will only create one evaluator environment to evaluate the agent and save the replay video.
417
+ Arguments:
418
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
419
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
420
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
421
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
422
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
423
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
424
+ subprocess environment manager will be used.
425
+ Returns:
426
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
427
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
428
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
429
+ """
430
+
431
+ if debug:
432
+ logging.getLogger().setLevel(logging.DEBUG)
433
+ # define env and policy
434
+ env = self._setup_env_manager(env_num, context, debug, 'evaluator')
435
+
436
+ # reset first to make sure the env is in the initial state
437
+ # env will be reset again in the main loop
438
+ env.launch()
439
+ env.reset()
440
+
441
+ # main execution task
442
+ with task.start(ctx=OnlineRLContext()):
443
+ task.use(interaction_evaluator_ttorch(
444
+ self.seed,
445
+ self.policy,
446
+ env,
447
+ n_evaluator_episode,
448
+ ))
449
+ task.run(max_step=1)
450
+
451
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
452
+
453
+ def _setup_env_manager(
454
+ self,
455
+ env_num: int,
456
+ context: Optional[str] = None,
457
+ debug: bool = False,
458
+ caller: str = 'collector'
459
+ ) -> BaseEnvManagerV2:
460
+ """
461
+ Overview:
462
+ Setup the environment manager. The environment manager is used to manage multiple environments.
463
+ Arguments:
464
+ - env_num (:obj:`int`): The number of environments.
465
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
466
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
467
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
468
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
469
+ subprocess environment manager will be used.
470
+ - caller (:obj:`str`): The caller of the environment manager. Default to 'collector'.
471
+ Returns:
472
+ - (:obj:`BaseEnvManagerV2`): The environment manager.
473
+ """
474
+ assert caller in ['evaluator', 'collector']
475
+ if debug:
476
+ env_cls = BaseEnvManagerV2
477
+ manager_cfg = env_cls.default_config()
478
+ else:
479
+ env_cls = SubprocessEnvManagerV2
480
+ manager_cfg = env_cls.default_config()
481
+ if context is not None:
482
+ manager_cfg.context = context
483
+ return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)
484
+
485
+ @property
486
+ def best(self) -> 'PPOF':
487
+ """
488
+ Overview:
489
+ Load the best model from the checkpoint directory, \
490
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
491
+ The return value is the agent with the best model.
492
+ Returns:
493
+ - (:obj:`PPOF`): The agent with the best model.
494
+ Examples:
495
+ >>> agent = PPOF(env_id='LunarLander-v2')
496
+ >>> agent.train()
497
+ >>> agent = agent.best
498
+
499
+ .. note::
500
+ The best model is the model with the highest evaluation return. When this property is accessed, the current \
501
+ model will be replaced by the best model.
502
+ """
503
+
504
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
505
+ # Load best model if it exists
506
+ if os.path.exists(best_model_file_path):
507
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
508
+ self.policy.load_state_dict(policy_state_dict)
509
+ return self
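+ # Usage sketch (illustrative only): batch_evaluate runs several evaluator envs in
+ # parallel, while deploy runs a single env (optionally recording video).
+ #     result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
+ #     print(result.eval_value, result.eval_value_std)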
DI-engine/ding/bonus/sac.py ADDED
@@ -0,0 +1,457 @@
1
+ from typing import Optional, Union, List
2
+ from ditk import logging
3
+ from easydict import EasyDict
4
+ import os
5
+ import numpy as np
6
+ import torch
7
+ import treetensor.torch as ttorch
8
+ from ding.framework import task, OnlineRLContext
9
+ from ding.framework.middleware import CkptSaver, \
10
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
11
+ OffPolicyLearner, final_ctx_saver
12
+ from ding.envs import BaseEnv
13
+ from ding.envs import setup_ding_env_manager
14
+ from ding.policy import SACPolicy
15
+ from ding.utils import set_pkg_seed
16
+ from ding.utils import get_env_fps, render
17
+ from ding.config import save_config_py, compile_config
18
+ from ding.model import ContinuousQAC
19
+ from ding.model import model_wrap
20
+ from ding.data import DequeBuffer
21
+ from ding.bonus.common import TrainingReturn, EvalReturn
22
+ from ding.config.example.SAC import supported_env_cfg
23
+ from ding.config.example.SAC import supported_env
24
+
25
+
26
+ class SACAgent:
27
+ """
28
+ Overview:
29
+ Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
30
+ Soft Actor-Critic (SAC).
31
+ For more information about the system design of RL agent, please refer to \
32
+ <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
33
+ Interface:
34
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
35
+ """
36
+ supported_env_list = list(supported_env_cfg.keys())
37
+ """
38
+ Overview:
39
+ List of supported envs.
40
+ Examples:
41
+ >>> from ding.bonus.sac import SACAgent
42
+ >>> print(SACAgent.supported_env_list)
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ env_id: str = None,
48
+ env: BaseEnv = None,
49
+ seed: int = 0,
50
+ exp_name: str = None,
51
+ model: Optional[torch.nn.Module] = None,
52
+ cfg: Optional[Union[EasyDict, dict]] = None,
53
+ policy_state_dict: str = None,
54
+ ) -> None:
55
+ """
56
+ Overview:
57
+ Initialize agent for SAC algorithm.
58
+ Arguments:
59
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
60
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
61
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
62
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
63
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
64
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
65
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
66
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
67
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
68
+ Default to 0.
69
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
70
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
71
+ - model (:obj:`torch.nn.Module`): The model of SAC algorithm, which should be an instance of class \
72
+ :class:`ding.model.ContinuousQAC`. \
73
+ If not specified, a default model will be generated according to the configuration.
74
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SAC algorithm, which is a dict. \
75
+ Default to None. If not specified, the default configuration will be used. \
76
+ The default configuration can be found in ``ding/config/example/SAC/gym_lunarlander_v2.py``.
77
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
78
+ If specified, the policy will be loaded from this file. Default to None.
79
+
80
+ .. note::
81
+ An RL Agent Instance can be initialized in two basic ways. \
82
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
83
+ and we want to train an agent with SAC algorithm with default configuration. \
84
+ Then we can initialize the agent in the following ways:
85
+ >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
86
+ or, if we want to specify the env_id in the configuration:
87
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
88
+ >>> agent = SACAgent(cfg=cfg)
89
+ There are also other arguments to specify the agent when initializing.
90
+ For example, if we want to specify the environment instance:
91
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
92
+ >>> agent = SACAgent(cfg=cfg, env=env)
93
+ or, if we want to specify the model:
94
+ >>> model = ContinuousQAC(**cfg.policy.model)
95
+ >>> agent = SACAgent(cfg=cfg, model=model)
96
+ or, if we want to reload the policy from a saved policy state dict:
97
+ >>> agent = SACAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
98
+ Make sure that the configuration is consistent with the saved policy state dict.
99
+ """
100
+
101
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
102
+
103
+ if cfg is not None and not isinstance(cfg, EasyDict):
104
+ cfg = EasyDict(cfg)
105
+
106
+ if env_id is not None:
107
+ assert env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
108
+ SACAgent.supported_env_list
109
+ )
110
+ if cfg is None:
111
+ cfg = supported_env_cfg[env_id]
112
+ else:
113
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
114
+ else:
115
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
116
+ assert cfg.env.env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
117
+ SACAgent.supported_env_list
118
+ )
119
+ default_policy_config = EasyDict({"policy": SACPolicy.default_config()})
120
+ default_policy_config.update(cfg)
121
+ cfg = default_policy_config
122
+
123
+ if exp_name is not None:
124
+ cfg.exp_name = exp_name
125
+ self.cfg = compile_config(cfg, policy=SACPolicy)
126
+ self.exp_name = self.cfg.exp_name
127
+ if env is None:
128
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
129
+ else:
130
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
131
+ self.env = env
132
+
133
+ logging.getLogger().setLevel(logging.INFO)
134
+ self.seed = seed
135
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
136
+ if not os.path.exists(self.exp_name):
137
+ os.makedirs(self.exp_name)
138
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
139
+ if model is None:
140
+ model = ContinuousQAC(**self.cfg.policy.model)
141
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
142
+ self.policy = SACPolicy(self.cfg.policy, model=model)
143
+ if policy_state_dict is not None:
144
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
145
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
146
+
147
+ def train(
148
+ self,
149
+ step: int = int(1e7),
150
+ collector_env_num: int = None,
151
+ evaluator_env_num: int = None,
152
+ n_iter_save_ckpt: int = 1000,
153
+ context: Optional[str] = None,
154
+ debug: bool = False,
155
+ wandb_sweep: bool = False,
156
+ ) -> TrainingReturn:
157
+ """
158
+ Overview:
159
+ Train the agent with SAC algorithm for ``step`` iterations with ``collector_env_num`` collector \
160
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
161
+ recorded and saved by wandb.
162
+ Arguments:
163
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
164
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
165
+ If not specified, it will be set according to the configuration.
166
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
167
+ If not specified, it will be set according to the configuration.
168
+ - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
169
+ Default to 1000.
170
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
171
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
172
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
173
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
174
+ subprocess environment manager will be used.
175
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
176
+ which is a hyper-parameter optimization process for seeking the best configurations. \
177
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
178
+ Returns:
179
+ - (:obj:`TrainingReturn`): The training result, of which the attributes are:
180
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
181
+ """
182
+
183
+ if debug:
184
+ logging.getLogger().setLevel(logging.DEBUG)
185
+ logging.debug(self.policy._model)
186
+ # define env and policy
187
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
188
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
189
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
190
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
191
+
192
+ with task.start(ctx=OnlineRLContext()):
193
+ task.use(
194
+ interaction_evaluator(
195
+ self.cfg,
196
+ self.policy.eval_mode,
197
+ evaluator_env,
198
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
199
+ )
200
+ )
201
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
202
+ task.use(
203
+ StepCollector(
204
+ self.cfg,
205
+ self.policy.collect_mode,
206
+ collector_env,
207
+ random_collect_size=self.cfg.policy.random_collect_size
208
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
209
+ )
210
+ )
211
+ task.use(data_pusher(self.cfg, self.buffer_))
212
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
213
+ task.use(
214
+ wandb_online_logger(
215
+ metric_list=self.policy._monitor_vars_learn(),
216
+ model=self.policy._model,
217
+ anonymous=True,
218
+ project_name=self.exp_name,
219
+ wandb_sweep=wandb_sweep,
220
+ )
221
+ )
222
+ task.use(termination_checker(max_env_step=step))
223
+ task.use(final_ctx_saver(name=self.exp_name))
224
+ task.run()
225
+
226
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
227
+
228
+ def deploy(
229
+ self,
230
+ enable_save_replay: bool = False,
231
+ concatenate_all_replay: bool = False,
232
+ replay_save_path: str = None,
233
+ seed: Optional[Union[int, List]] = None,
234
+ debug: bool = False
235
+ ) -> EvalReturn:
236
+ """
237
+ Overview:
238
+ Deploy the agent with SAC algorithm by interacting with the environment, during which the replay video \
239
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
240
+ Arguments:
241
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
242
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
243
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
244
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
245
+ the replay video of each episode will be saved separately.
246
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
247
+ If not specified, the video will be saved in ``exp_name/videos``.
248
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
249
+ Default to None. If not specified, ``self.seed`` will be used. \
250
+ If ``seed`` is an integer, the agent will be deployed once. \
251
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
252
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
253
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
254
+ subprocess environment manager will be used.
255
+ Returns:
256
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
257
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
258
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
259
+ """
260
+
261
+ if debug:
262
+ logging.getLogger().setLevel(logging.DEBUG)
263
+ # define env and policy
264
+ env = self.env.clone(caller='evaluator')
265
+
266
+ if seed is not None and isinstance(seed, int):
267
+ seeds = [seed]
268
+ elif seed is not None and isinstance(seed, list):
269
+ seeds = seed
270
+ else:
271
+ seeds = [self.seed]
272
+
273
+ returns = []
274
+ images = []
275
+ if enable_save_replay:
276
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
277
+ env.enable_save_replay(replay_path=replay_save_path)
278
+ else:
279
+ logging.warning('No video would be generated during the deploy.')
280
+ if concatenate_all_replay:
281
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
282
+ concatenate_all_replay = False
283
+
284
+ def single_env_forward_wrapper(forward_fn, cuda=True):
285
+
286
+ forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
287
+
288
+ def _forward(obs):
289
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
290
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
291
+ if cuda and torch.cuda.is_available():
292
+ obs = obs.cuda()
293
+ (mu, sigma) = forward_fn(obs, mode='compute_actor')['logit']
294
+ action = torch.tanh(mu).detach().cpu().numpy()[0] # deterministic_eval
295
+ return action
296
+
297
+ return _forward
298
+
299
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
300
+
301
+ # reset first to make sure the env is in the initial state
302
+ # env will be reset again in the main loop
303
+ env.reset()
304
+
305
+ for seed in seeds:
306
+ env.seed(seed, dynamic_seed=False)
307
+ return_ = 0.
308
+ step = 0
309
+ obs = env.reset()
310
+ images.append(render(env)[None]) if concatenate_all_replay else None
311
+ while True:
312
+ action = forward_fn(obs)
313
+ obs, rew, done, info = env.step(action)
314
+ images.append(render(env)[None]) if concatenate_all_replay else None
315
+ return_ += rew
316
+ step += 1
317
+ if done:
318
+ break
319
+ logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
320
+ returns.append(return_)
321
+
322
+ env.close()
323
+
324
+ if concatenate_all_replay:
325
+ images = np.concatenate(images, axis=0)
326
+ import imageio
327
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
328
+
329
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
330
+
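A note on the deterministic evaluation above: SAC's actor outputs a Gaussian ``(mu, sigma)`` and the collector samples from it, while ``_forward`` deploys the squashed mean ``tanh(mu)``. A minimal sketch of the contrast, with illustrative tensor values (not DI-engine source):

    import torch

    mu = torch.tensor([0.3, -1.2])    # illustrative actor mean, shape (A, )
    sigma = torch.tensor([0.5, 0.5])  # illustrative actor std, shape (A, )

    # collection: sample from the Gaussian, then squash into (-1, 1)
    stochastic_action = torch.tanh(mu + sigma * torch.randn_like(mu))

    # deployment (as in ``_forward`` above): drop the noise, squash the mean
    deterministic_action = torch.tanh(mu)
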
+     def collect_data(
+         self,
+         env_num: int = 8,
+         save_data_path: Optional[str] = None,
+         n_sample: Optional[int] = None,
+         n_episode: Optional[int] = None,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> None:
+         """
+         Overview:
+             Collect data with the SAC algorithm: ``n_sample`` transition samples are collected with ``env_num`` \
+             collector environments (episode-based collection via ``n_episode`` is not yet implemented). \
+             The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+             ``exp_name/demo_data``.
+         Arguments:
+             - env_num (:obj:`int`): The number of collector environments. Default to 8.
+             - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+                 If not specified, the data will be saved in ``exp_name/demo_data``.
+             - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+                 If not specified, ``n_episode`` must be specified.
+             - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+                 If not specified, ``n_sample`` must be specified.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         if n_episode is not None:
+             raise NotImplementedError
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.collector_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+         if save_data_path is None:
+             save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 StepCollector(
+                     self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+                 )
+             )
+             task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+             task.run(max_step=1)
+         logging.info(
+             f'SAC collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+         )
+
+     def batch_evaluate(
+         self,
+         env_num: int = 4,
+         n_evaluator_episode: int = 4,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Evaluate the agent with the SAC algorithm for ``n_evaluator_episode`` episodes with ``env_num`` \
+             evaluator environments. The evaluation result will be returned.
+             The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+             multiple evaluator environments to evaluate the agent and get an average performance, while ``deploy`` \
+             will only create one evaluator environment to evaluate the agent and save the replay video.
+         Arguments:
+             - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+             - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.launch()
+         env.reset()
+
+         evaluate_cfg = self.cfg
+         evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+             task.run(max_step=1)
+
+         return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+     @property
+     def best(self) -> 'SACAgent':
+         """
+         Overview:
+             Load the best model from the checkpoint directory; \
+             by default it is saved at ``exp_name/ckpt/eval.pth.tar``. \
+             The return value is the agent with the best model.
+         Returns:
+             - (:obj:`SACAgent`): The agent with the best model.
+         Examples:
+             >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+             >>> agent.train()
+             >>> agent = agent.best
+
+         .. note::
+             The best model is the model with the highest evaluation return. If this method is called, the current \
+             model will be replaced by the best model.
+         """
+
+         best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+         # Load best model if it exists
+         if os.path.exists(best_model_file_path):
+             policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         return self
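A minimal end-to-end usage sketch of the ``SACAgent`` API defined above; the experiment name and step budget are illustrative, not recommended values:

    from ding.bonus.sac import SACAgent

    agent = SACAgent(env_id='LunarLanderContinuous-v2', exp_name='sac-demo', seed=0)
    train_result = agent.train(step=100000)                   # TrainingReturn(wandb_url=...)
    eval_result = agent.best.deploy(enable_save_replay=True)  # EvalReturn(eval_value, eval_value_std)
    print(train_result.wandb_url, eval_result.eval_value, eval_result.eval_value_std)
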
DI-engine/ding/bonus/sql.py ADDED
@@ -0,0 +1,461 @@
+ from typing import Optional, Union, List
+ from ditk import logging
+ from easydict import EasyDict
+ import os
+ import numpy as np
+ import torch
+ import treetensor.torch as ttorch
+ from ding.framework import task, OnlineRLContext
+ from ding.framework.middleware import CkptSaver, \
+     wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+     OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
+ from ding.envs import BaseEnv
+ from ding.envs import setup_ding_env_manager
+ from ding.policy import SQLPolicy
+ from ding.utils import set_pkg_seed
+ from ding.utils import get_env_fps, render
+ from ding.config import save_config_py, compile_config
+ from ding.model import DQN
+ from ding.model import model_wrap
+ from ding.data import DequeBuffer
+ from ding.bonus.common import TrainingReturn, EvalReturn
+ from ding.config.example.SQL import supported_env_cfg
+ from ding.config.example.SQL import supported_env
+
+
+ class SQLAgent:
+     """
+     Overview:
+         Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+         Soft Q-Learning (SQL).
+         For more information about the system design of RL agent, please refer to \
+         <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
+     Interface:
+         ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+     """
+     supported_env_list = list(supported_env_cfg.keys())
+     """
+     Overview:
+         List of supported envs.
+     Examples:
+         >>> from ding.bonus.sql import SQLAgent
+         >>> print(SQLAgent.supported_env_list)
+     """
+
+     def __init__(
+         self,
+         env_id: str = None,
+         env: BaseEnv = None,
+         seed: int = 0,
+         exp_name: str = None,
+         model: Optional[torch.nn.Module] = None,
+         cfg: Optional[Union[EasyDict, dict]] = None,
+         policy_state_dict: str = None,
+     ) -> None:
+         """
+         Overview:
+             Initialize agent for SQL algorithm.
+         Arguments:
+             - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+                 If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+                 If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+                 ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+             - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                 If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+                 ``env_id`` or ``cfg.env.env_id`` will be used to create the environment instance. \
+                 If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+             - seed (:obj:`int`): The random seed, which is set before running the program. \
+                 Default to 0.
+             - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+                 log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+             - model (:obj:`torch.nn.Module`): The model of SQL algorithm, which should be an instance of class \
+                 :class:`ding.model.DQN`. \
+                 If not specified, a default model will be generated according to the configuration.
+             - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SQL algorithm, which is a dict. \
+                 Default to None. If not specified, the default configuration will be used. \
+                 The default configuration can be found in ``ding/config/example/SQL/gym_lunarlander_v2.py``.
+             - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch in a local file. \
+                 If specified, the policy will be loaded from this file. Default to None.
+
+         .. note::
+             An RL agent instance can be initialized in two basic ways. \
+             For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+             and we want to train an agent with the SQL algorithm with default configuration. \
+             Then we can initialize the agent in the following ways:
+             >>> agent = SQLAgent(env_id='LunarLander-v2')
+             or, we can specify the env_id in the configuration:
+             >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+             >>> agent = SQLAgent(cfg=cfg)
+             There are also other arguments to specify the agent when initializing.
+             For example, if we want to specify the environment instance:
+             >>> env = CustomizedEnv('LunarLander-v2')
+             >>> agent = SQLAgent(cfg=cfg, env=env)
+             or, if we want to specify the model:
+             >>> model = DQN(**cfg.policy.model)
+             >>> agent = SQLAgent(cfg=cfg, model=model)
+             or, if we want to reload the policy from a saved policy state dict:
+             >>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+             Make sure that the configuration is consistent with the saved policy state dict.
+         """
+
+         assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+         if cfg is not None and not isinstance(cfg, EasyDict):
+             cfg = EasyDict(cfg)
+
+         if env_id is not None:
+             assert env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
+                 SQLAgent.supported_env_list
+             )
+             if cfg is None:
+                 cfg = supported_env_cfg[env_id]
+             else:
+                 assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+         else:
+             assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+             assert cfg.env.env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
+                 SQLAgent.supported_env_list
+             )
+         default_policy_config = EasyDict({"policy": SQLPolicy.default_config()})
+         default_policy_config.update(cfg)
+         cfg = default_policy_config
+
+         if exp_name is not None:
+             cfg.exp_name = exp_name
+         self.cfg = compile_config(cfg, policy=SQLPolicy)
+         self.exp_name = self.cfg.exp_name
+         if env is None:
+             self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+         else:
+             assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+             self.env = env
+
+         logging.getLogger().setLevel(logging.INFO)
+         self.seed = seed
+         set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+         if not os.path.exists(self.exp_name):
+             os.makedirs(self.exp_name)
+         save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+         if model is None:
+             model = DQN(**self.cfg.policy.model)
+         self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+         self.policy = SQLPolicy(self.cfg.policy, model=model)
+         if policy_state_dict is not None:
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+     def train(
+         self,
+         step: int = int(1e7),
+         collector_env_num: int = None,
+         evaluator_env_num: int = None,
+         n_iter_save_ckpt: int = 1000,
+         context: Optional[str] = None,
+         debug: bool = False,
+         wandb_sweep: bool = False,
+     ) -> TrainingReturn:
+         """
+         Overview:
+             Train the agent with the SQL algorithm for ``step`` environment steps with ``collector_env_num`` \
+             collector environments and ``evaluator_env_num`` evaluator environments. Information during training \
+             will be recorded and saved by wandb.
+         Arguments:
+             - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+             - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+                 If not specified, it will be set according to the configuration.
+             - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+                 If not specified, it will be set according to the configuration.
+             - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
+                 Default to 1000.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+             - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+                 which is a hyper-parameter optimization process for seeking the best configurations. \
+                 Default to False. If True, the wandb sweep id will be used as the experiment name.
+         Returns:
+             - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                 - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+             logging.debug(self.policy._model)
+         # define env and policy
+         collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+         evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+         collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+         evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 interaction_evaluator(
+                     self.cfg,
+                     self.policy.eval_mode,
+                     evaluator_env,
+                     render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+                 )
+             )
+             task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+             task.use(eps_greedy_handler(self.cfg))
+             task.use(
+                 StepCollector(
+                     self.cfg,
+                     self.policy.collect_mode,
+                     collector_env,
+                     random_collect_size=self.cfg.policy.random_collect_size
+                     if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+                 )
+             )
+             if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
+                 task.use(nstep_reward_enhancer(self.cfg))
+             task.use(data_pusher(self.cfg, self.buffer_))
+             task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+             task.use(
+                 wandb_online_logger(
+                     metric_list=self.policy._monitor_vars_learn(),
+                     model=self.policy._model,
+                     anonymous=True,
+                     project_name=self.exp_name,
+                     wandb_sweep=wandb_sweep,
+                 )
+             )
+             task.use(termination_checker(max_env_step=step))
+             task.use(final_ctx_saver(name=self.exp_name))
+             task.run()
+
+         return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
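The ``nstep_reward_enhancer`` wired in above (only when ``cfg.policy.nstep > 1``) turns one-step rewards into discounted n-step returns before the data reaches the learner. An illustrative sketch of the quantity being computed, not the middleware's actual implementation (``gamma`` and the reward values are made up):

    def nstep_return(rewards, gamma=0.99, nstep=3):
        # r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}
        return sum((gamma ** i) * r for i, r in enumerate(rewards[:nstep]))

    print(nstep_return([1.0, 0.0, 2.0, 5.0]))  # only the first 3 rewards contribute
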
+     def deploy(
+         self,
+         enable_save_replay: bool = False,
+         concatenate_all_replay: bool = False,
+         replay_save_path: str = None,
+         seed: Optional[Union[int, List]] = None,
+         debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Deploy the agent with the SQL algorithm by interacting with the environment, during which the replay \
+             video can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+         Arguments:
+             - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+             - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+                 Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+                 If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+                 the replay video of each episode will be saved separately.
+             - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+                 If not specified, the video will be saved in ``exp_name/videos``.
+             - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+                 Default to None. If not specified, ``self.seed`` will be used. \
+                 If ``seed`` is an integer, the agent will be deployed once. \
+                 If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env = self.env.clone(caller='evaluator')
+
+         if seed is not None and isinstance(seed, int):
+             seeds = [seed]
+         elif seed is not None and isinstance(seed, list):
+             seeds = seed
+         else:
+             seeds = [self.seed]
+
+         returns = []
+         images = []
+         if enable_save_replay:
+             replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+             env.enable_save_replay(replay_path=replay_save_path)
+         else:
+             logging.warning('No video would be generated during the deploy.')
+             if concatenate_all_replay:
+                 logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+                 concatenate_all_replay = False
+
+         def single_env_forward_wrapper(forward_fn, cuda=True):
+
+             forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+
+             def _forward(obs):
+                 # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+                 obs = ttorch.as_tensor(obs).unsqueeze(0)
+                 if cuda and torch.cuda.is_available():
+                     obs = obs.cuda()
+                 action = forward_fn(obs)["action"]
+                 # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+                 action = action.squeeze(0).detach().cpu().numpy()
+                 return action
+
+             return _forward
+
+         forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.reset()
+
+         for seed in seeds:
+             env.seed(seed, dynamic_seed=False)
+             return_ = 0.
+             step = 0
+             obs = env.reset()
+             if concatenate_all_replay:
+                 images.append(render(env)[None])
+             while True:
+                 action = forward_fn(obs)
+                 obs, rew, done, info = env.step(action)
+                 if concatenate_all_replay:
+                     images.append(render(env)[None])
+                 return_ += rew
+                 step += 1
+                 if done:
+                     break
+             logging.info(f'SQL deploy is finished, final episode return with {step} steps is: {return_}')
+             returns.append(return_)
+
+         env.close()
+
+         if concatenate_all_replay:
+             images = np.concatenate(images, axis=0)
+             import imageio
+             imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+         return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+     def collect_data(
+         self,
+         env_num: int = 8,
+         save_data_path: Optional[str] = None,
+         n_sample: Optional[int] = None,
+         n_episode: Optional[int] = None,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> None:
+         """
+         Overview:
+             Collect data with the SQL algorithm: ``n_sample`` transition samples are collected with ``env_num`` \
+             collector environments (episode-based collection via ``n_episode`` is not yet implemented). \
+             The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+             ``exp_name/demo_data``.
+         Arguments:
+             - env_num (:obj:`int`): The number of collector environments. Default to 8.
+             - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+                 If not specified, the data will be saved in ``exp_name/demo_data``.
+             - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+                 If not specified, ``n_episode`` must be specified.
+             - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+                 If not specified, ``n_sample`` must be specified.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         if n_episode is not None:
+             raise NotImplementedError
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.collector_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+         if save_data_path is None:
+             save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 StepCollector(
+                     self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+                 )
+             )
+             task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+             task.run(max_step=1)
+         logging.info(
+             f'SQL collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+         )
+
+     def batch_evaluate(
+         self,
+         env_num: int = 4,
+         n_evaluator_episode: int = 4,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Evaluate the agent with the SQL algorithm for ``n_evaluator_episode`` episodes with ``env_num`` \
+             evaluator environments. The evaluation result will be returned.
+             The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+             multiple evaluator environments to evaluate the agent and get an average performance, while ``deploy`` \
+             will only create one evaluator environment to evaluate the agent and save the replay video.
+         Arguments:
+             - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+             - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.launch()
+         env.reset()
+
+         evaluate_cfg = self.cfg
+         evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+             task.run(max_step=1)
+
+         return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+     @property
+     def best(self) -> 'SQLAgent':
+         """
+         Overview:
+             Load the best model from the checkpoint directory; \
+             by default it is saved at ``exp_name/ckpt/eval.pth.tar``. \
+             The return value is the agent with the best model.
+         Returns:
+             - (:obj:`SQLAgent`): The agent with the best model.
+         Examples:
+             >>> agent = SQLAgent(env_id='LunarLander-v2')
+             >>> agent.train()
+             >>> agent = agent.best
+
+         .. note::
+             The best model is the model with the highest evaluation return. If this method is called, the current \
+             model will be replaced by the best model.
+         """
+
+         best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+         # Load best model if it exists
+         if os.path.exists(best_model_file_path):
+             policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         return self
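A minimal usage sketch of the ``SQLAgent`` API above; the experiment name and step budget are illustrative. Unlike the continuous-control agents in this diff, exploration comes from the ``eps_greedy_handler`` middleware that ``train`` wires in before the collector:

    from ding.bonus.sql import SQLAgent

    agent = SQLAgent(env_id='LunarLander-v2', exp_name='sql-demo')
    agent.train(step=50000)
    result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
    print(result.eval_value, result.eval_value_std)
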
DI-engine/ding/bonus/td3.py ADDED
@@ -0,0 +1,455 @@
+ from typing import Optional, Union, List
+ from ditk import logging
+ from easydict import EasyDict
+ import os
+ import numpy as np
+ import torch
+ import treetensor.torch as ttorch
+ from ding.framework import task, OnlineRLContext
+ from ding.framework.middleware import CkptSaver, \
+     wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+     OffPolicyLearner, final_ctx_saver
+ from ding.envs import BaseEnv
+ from ding.envs import setup_ding_env_manager
+ from ding.policy import TD3Policy
+ from ding.utils import set_pkg_seed
+ from ding.utils import get_env_fps, render
+ from ding.config import save_config_py, compile_config
+ from ding.model import ContinuousQAC
+ from ding.data import DequeBuffer
+ from ding.bonus.common import TrainingReturn, EvalReturn
+ from ding.config.example.TD3 import supported_env_cfg
+ from ding.config.example.TD3 import supported_env
+
+
+ class TD3Agent:
+     """
+     Overview:
+         Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+         Twin Delayed Deep Deterministic Policy Gradient (TD3).
+         For more information about the system design of RL agent, please refer to \
+         <https://di-engine-docs.readthedocs.io/en/latest/03_system/agent.html>.
+     Interface:
+         ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+     """
+     supported_env_list = list(supported_env_cfg.keys())
+     """
+     Overview:
+         List of supported envs.
+     Examples:
+         >>> from ding.bonus.td3 import TD3Agent
+         >>> print(TD3Agent.supported_env_list)
+     """
+
+     def __init__(
+         self,
+         env_id: str = None,
+         env: BaseEnv = None,
+         seed: int = 0,
+         exp_name: str = None,
+         model: Optional[torch.nn.Module] = None,
+         cfg: Optional[Union[EasyDict, dict]] = None,
+         policy_state_dict: str = None,
+     ) -> None:
+         """
+         Overview:
+             Initialize agent for TD3 algorithm.
+         Arguments:
+             - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+                 If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+                 If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+                 ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+             - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                 If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+                 ``env_id`` or ``cfg.env.env_id`` will be used to create the environment instance. \
+                 If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+             - seed (:obj:`int`): The random seed, which is set before running the program. \
+                 Default to 0.
+             - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+                 log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+             - model (:obj:`torch.nn.Module`): The model of TD3 algorithm, which should be an instance of class \
+                 :class:`ding.model.ContinuousQAC`. \
+                 If not specified, a default model will be generated according to the configuration.
+             - cfg (:obj:`Union[EasyDict, dict]`): The configuration of TD3 algorithm, which is a dict. \
+                 Default to None. If not specified, the default configuration will be used. \
+                 The default configuration can be found in ``ding/config/example/TD3/gym_lunarlander_v2.py``.
+             - policy_state_dict (:obj:`str`): The path of a policy state dict saved by PyTorch in a local file. \
+                 If specified, the policy will be loaded from this file. Default to None.
+
+         .. note::
+             An RL agent instance can be initialized in two basic ways. \
+             For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+             and we want to train an agent with the TD3 algorithm with default configuration. \
+             Then we can initialize the agent in the following ways:
+             >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+             or, we can specify the env_id in the configuration:
+             >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+             >>> agent = TD3Agent(cfg=cfg)
+             There are also other arguments to specify the agent when initializing.
+             For example, if we want to specify the environment instance:
+             >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+             >>> agent = TD3Agent(cfg=cfg, env=env)
+             or, if we want to specify the model:
+             >>> model = ContinuousQAC(**cfg.policy.model)
+             >>> agent = TD3Agent(cfg=cfg, model=model)
+             or, if we want to reload the policy from a saved policy state dict:
+             >>> agent = TD3Agent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+             Make sure that the configuration is consistent with the saved policy state dict.
+         """
+
+         assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+         if cfg is not None and not isinstance(cfg, EasyDict):
+             cfg = EasyDict(cfg)
+
+         if env_id is not None:
+             assert env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
+                 TD3Agent.supported_env_list
+             )
+             if cfg is None:
+                 cfg = supported_env_cfg[env_id]
+             else:
+                 assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+         else:
+             assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+             assert cfg.env.env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
+                 TD3Agent.supported_env_list
+             )
+         default_policy_config = EasyDict({"policy": TD3Policy.default_config()})
+         default_policy_config.update(cfg)
+         cfg = default_policy_config
+
+         if exp_name is not None:
+             cfg.exp_name = exp_name
+         self.cfg = compile_config(cfg, policy=TD3Policy)
+         self.exp_name = self.cfg.exp_name
+         if env is None:
+             self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+         else:
+             assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+             self.env = env
+
+         logging.getLogger().setLevel(logging.INFO)
+         self.seed = seed
+         set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+         if not os.path.exists(self.exp_name):
+             os.makedirs(self.exp_name)
+         save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+         if model is None:
+             model = ContinuousQAC(**self.cfg.policy.model)
+         self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+         self.policy = TD3Policy(self.cfg.policy, model=model)
+         if policy_state_dict is not None:
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+     def train(
+         self,
+         step: int = int(1e7),
+         collector_env_num: int = None,
+         evaluator_env_num: int = None,
+         n_iter_save_ckpt: int = 1000,
+         context: Optional[str] = None,
+         debug: bool = False,
+         wandb_sweep: bool = False,
+     ) -> TrainingReturn:
+         """
+         Overview:
+             Train the agent with the TD3 algorithm for ``step`` environment steps with ``collector_env_num`` \
+             collector environments and ``evaluator_env_num`` evaluator environments. Information during training \
+             will be recorded and saved by wandb.
+         Arguments:
+             - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+             - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+                 If not specified, it will be set according to the configuration.
+             - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+                 If not specified, it will be set according to the configuration.
+             - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, in training iterations. \
+                 Default to 1000.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+             - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+                 which is a hyper-parameter optimization process for seeking the best configurations. \
+                 Default to False. If True, the wandb sweep id will be used as the experiment name.
+         Returns:
+             - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                 - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+             logging.debug(self.policy._model)
+         # define env and policy
+         collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+         evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+         collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+         evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 interaction_evaluator(
+                     self.cfg,
+                     self.policy.eval_mode,
+                     evaluator_env,
+                     render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+                 )
+             )
+             task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+             task.use(
+                 StepCollector(
+                     self.cfg,
+                     self.policy.collect_mode,
+                     collector_env,
+                     random_collect_size=self.cfg.policy.random_collect_size
+                     if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+                 )
+             )
+             task.use(data_pusher(self.cfg, self.buffer_))
+             task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+             task.use(
+                 wandb_online_logger(
+                     metric_list=self.policy._monitor_vars_learn(),
+                     model=self.policy._model,
+                     anonymous=True,
+                     project_name=self.exp_name,
+                     wandb_sweep=wandb_sweep,
+                 )
+             )
+             task.use(termination_checker(max_env_step=step))
+             task.use(final_ctx_saver(name=self.exp_name))
+             task.run()
+
+         return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+     def deploy(
+         self,
+         enable_save_replay: bool = False,
+         concatenate_all_replay: bool = False,
+         replay_save_path: str = None,
+         seed: Optional[Union[int, List]] = None,
+         debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Deploy the agent with the TD3 algorithm by interacting with the environment, during which the replay \
+             video can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+         Arguments:
+             - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+             - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+                 Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+                 If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+                 the replay video of each episode will be saved separately.
+             - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+                 If not specified, the video will be saved in ``exp_name/videos``.
+             - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+                 Default to None. If not specified, ``self.seed`` will be used. \
+                 If ``seed`` is an integer, the agent will be deployed once. \
+                 If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env = self.env.clone(caller='evaluator')
+
+         if seed is not None and isinstance(seed, int):
+             seeds = [seed]
+         elif seed is not None and isinstance(seed, list):
+             seeds = seed
+         else:
+             seeds = [self.seed]
+
+         returns = []
+         images = []
+         if enable_save_replay:
+             replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+             env.enable_save_replay(replay_path=replay_save_path)
+         else:
+             logging.warning('No video would be generated during the deploy.')
+             if concatenate_all_replay:
+                 logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+                 concatenate_all_replay = False
+
+         def single_env_forward_wrapper(forward_fn, cuda=True):
+
+             def _forward(obs):
+                 # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+                 obs = ttorch.as_tensor(obs).unsqueeze(0)
+                 if cuda and torch.cuda.is_available():
+                     obs = obs.cuda()
+                 action = forward_fn(obs, mode='compute_actor')["action"]
+                 # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+                 action = action.squeeze(0).detach().cpu().numpy()
+                 return action
+
+             return _forward
+
+         forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.reset()
+
+         for seed in seeds:
+             env.seed(seed, dynamic_seed=False)
+             return_ = 0.
+             step = 0
+             obs = env.reset()
+             if concatenate_all_replay:
+                 images.append(render(env)[None])
+             while True:
+                 action = forward_fn(obs)
+                 obs, rew, done, info = env.step(action)
+                 if concatenate_all_replay:
+                     images.append(render(env)[None])
+                 return_ += rew
+                 step += 1
+                 if done:
+                     break
+             logging.info(f'TD3 deploy is finished, final episode return with {step} steps is: {return_}')
+             returns.append(return_)
+
+         env.close()
+
+         if concatenate_all_replay:
+             images = np.concatenate(images, axis=0)
+             import imageio
+             imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+         return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+     def collect_data(
+         self,
+         env_num: int = 8,
+         save_data_path: Optional[str] = None,
+         n_sample: Optional[int] = None,
+         n_episode: Optional[int] = None,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> None:
+         """
+         Overview:
+             Collect data with the TD3 algorithm: ``n_sample`` transition samples are collected with ``env_num`` \
+             collector environments (episode-based collection via ``n_episode`` is not yet implemented). \
+             The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+             ``exp_name/demo_data``.
+         Arguments:
+             - env_num (:obj:`int`): The number of collector environments. Default to 8.
+             - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+                 If not specified, the data will be saved in ``exp_name/demo_data``.
+             - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+                 If not specified, ``n_episode`` must be specified.
+             - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+                 If not specified, ``n_sample`` must be specified.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         if n_episode is not None:
+             raise NotImplementedError
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.collector_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+         if save_data_path is None:
+             save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(
+                 StepCollector(
+                     self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+                 )
+             )
+             task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+             task.run(max_step=1)
+         logging.info(
+             f'TD3 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+         )
+
+     def batch_evaluate(
+         self,
+         env_num: int = 4,
+         n_evaluator_episode: int = 4,
+         context: Optional[str] = None,
+         debug: bool = False
+     ) -> EvalReturn:
+         """
+         Overview:
+             Evaluate the agent with the TD3 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` \
+             evaluator environments. The evaluation result will be returned.
+             The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+             multiple evaluator environments to evaluate the agent and get an average performance, while ``deploy`` \
+             will only create one evaluator environment to evaluate the agent and save the replay video.
+         Arguments:
+             - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+             - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+             - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+                 It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+             - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+                 If set to True, base environment manager will be used for easy debugging. Otherwise, \
+                 subprocess environment manager will be used.
+         Returns:
+             - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+                 - eval_value (:obj:`np.float32`): The mean of evaluation return.
+                 - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+         """
+
+         if debug:
+             logging.getLogger().setLevel(logging.DEBUG)
+         # define env and policy
+         env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+         env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+         # reset first to make sure the env is in the initial state
+         # env will be reset again in the main loop
+         env.launch()
+         env.reset()
+
+         evaluate_cfg = self.cfg
+         evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+         # main execution task
+         with task.start(ctx=OnlineRLContext()):
+             task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+             task.run(max_step=1)
+
+         return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+     @property
+     def best(self) -> 'TD3Agent':
+         """
+         Overview:
+             Load the best model from the checkpoint directory; \
+             by default it is saved at ``exp_name/ckpt/eval.pth.tar``. \
+             The return value is the agent with the best model.
+         Returns:
+             - (:obj:`TD3Agent`): The agent with the best model.
+         Examples:
+             >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+             >>> agent.train()
+             >>> agent = agent.best
+
+         .. note::
+             The best model is the model with the highest evaluation return. If this method is called, the current \
+             model will be replaced by the best model.
+         """
+
+         best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+         # Load best model if it exists
+         if os.path.exists(best_model_file_path):
+             policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+             self.policy.learn_mode.load_state_dict(policy_state_dict)
+         return self
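A minimal sketch of multi-seed deployment with the ``TD3Agent`` API above; the experiment name, step budget and seed list are illustrative. Passing a list of seeds runs one episode per seed and aggregates them into ``EvalReturn``:

    from ding.bonus.td3 import TD3Agent

    agent = TD3Agent(env_id='LunarLanderContinuous-v2', exp_name='td3-demo')
    agent.train(step=100000)
    result = agent.best.deploy(seed=[0, 1, 2, 3])
    print(result.eval_value, result.eval_value_std)  # mean/std over the four seeds
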
DI-engine/ding/compatibility.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+
+
+ def torch_ge_131():
+     # Whether the installed torch version is >= 1.3.1, compared by joining the
+     # digits of ``torch.__version__`` into one integer.
+     return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
+
+
+ def torch_ge_180():
+     # Whether the installed torch version is >= 1.8.0, using the same digit-join trick.
+     return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180
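The digit-join comparison above works for common release strings but can mis-order versions: e.g. ``"1.2.10"`` becomes ``1210 >= 131`` although ``1.2.10 < 1.3.1``. A hedged alternative sketch (not DI-engine code) using tuple comparison over the numeric components:

    def torch_version_ge(target: str, version: str) -> bool:
        # "1.13.1+cu117" -> (1, 13, 1); the local suffix after '+' is ignored
        def as_tuple(v: str) -> tuple:
            return tuple(int(x) for x in v.split('+')[0].split('.')[:3] if x.isdigit())
        return as_tuple(version) >= as_tuple(target)

    assert torch_version_ge("1.3.1", "1.13.1+cu117")
    assert not torch_version_ge("1.3.1", "1.2.10")
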
DI-engine/ding/config/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \
+     read_config_with_system, save_config_py
+ from .utils import parallel_transform, parallel_transform_slurm
+ from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3
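A minimal sketch of how these exports are combined in practice, mirroring the bonus agents earlier in this diff; the env id key is an assumption (check ``supported_env_cfg.keys()``) and the experiment name is illustrative:

    import os
    from ding.config import compile_config, save_config_py
    from ding.config.example.SQL import supported_env_cfg
    from ding.policy import SQLPolicy

    cfg = supported_env_cfg['LunarLander-v2']  # assumed key
    cfg.exp_name = 'sql-demo'
    cfg = compile_config(cfg, policy=SQLPolicy)
    os.makedirs(cfg.exp_name, exist_ok=True)
    save_config_py(cfg, os.path.join(cfg.exp_name, 'policy_config.py'))
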
DI-engine/ding/config/config.py ADDED
@@ -0,0 +1,579 @@
1
+ import os
2
+ import os.path as osp
3
+ import yaml
4
+ import json
5
+ import shutil
6
+ import sys
7
+ import time
8
+ import tempfile
9
+ import subprocess
10
+ import datetime
11
+ from importlib import import_module
12
+ from typing import Optional, Tuple
13
+ from easydict import EasyDict
14
+ from copy import deepcopy
15
+
16
+ from ding.utils import deep_merge_dicts, get_rank
17
+ from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
18
+ from ding.policy import get_policy_cls
19
+ from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
20
+ AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \
21
+ get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator
22
+ from ding.reward_model import get_reward_model_cls
23
+ from ding.world_model import get_world_model_cls
24
+ from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
25
+
26
+
27
+ class Config(object):
28
+ r"""
29
+ Overview:
30
+ Base class for config.
31
+ Interface:
32
+ __init__, file_to_dict
33
+ Property:
34
+ cfg_dict
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ cfg_dict: Optional[dict] = None,
40
+ cfg_text: Optional[str] = None,
41
+ filename: Optional[str] = None
42
+ ) -> None:
43
+ """
44
+ Overview:
45
+ Init method. Create config including dict type config and text type config.
46
+ Arguments:
47
+ - cfg_dict (:obj:`Optional[dict]`): dict type config
48
+ - cfg_text (:obj:`Optional[str]`): text type config
49
+ - filename (:obj:`Optional[str]`): config file name
50
+ """
51
+ if cfg_dict is None:
52
+ cfg_dict = {}
53
+ if not isinstance(cfg_dict, dict):
54
+ raise TypeError("invalid type for cfg_dict: {}".format(type(cfg_dict)))
55
+ self._cfg_dict = cfg_dict
56
+ if cfg_text:
57
+ text = cfg_text
58
+ elif filename:
59
+ with open(filename, 'r') as f:
60
+ text = f.read()
61
+ else:
62
+ text = '.'
63
+ self._text = text
64
+ self._filename = filename
65
+
66
+ @staticmethod
67
+ def file_to_dict(filename: str) -> 'Config': # noqa
68
+ """
69
+ Overview:
70
+ Read config file and create config.
71
+ Arguments:
72
+ - filename (:obj:`str`): config file name.
73
+ Returns:
74
+ - cfg_dict (:obj:`Config`): config class
75
+ """
76
+ cfg_dict, cfg_text = Config._file_to_dict(filename)
77
+ return Config(cfg_dict, cfg_text, filename=filename)
78
+
79
+ @staticmethod
80
+ def _file_to_dict(filename: str) -> Tuple[dict, str]:
81
+ """
82
+ Overview:
83
+ Read config file and convert the config file to dict type config and text type config.
84
+ Arguments:
85
+ - filename (:obj:`str`): config file name.
86
+ Returns:
87
+ - cfg_dict (:obj:`Optional[dict]`): dict type config
88
+ - cfg_text (:obj:`Optional[str]`): text type config
89
+ """
90
+ filename = osp.abspath(osp.expanduser(filename))
91
+ # TODO check exist
92
+ # TODO check suffix
93
+ ext_name = osp.splitext(filename)[-1]
94
+ with tempfile.TemporaryDirectory() as temp_config_dir:
95
+ temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=ext_name)
96
+ temp_config_name = osp.basename(temp_config_file.name)
97
+ temp_config_file.close()
98
+ shutil.copyfile(filename, temp_config_file.name)
99
+
100
+ temp_module_name = osp.splitext(temp_config_name)[0]
101
+ sys.path.insert(0, temp_config_dir)
102
+ # TODO validate py syntax
103
+ module = import_module(temp_module_name)
104
+ cfg_dict = {k: v for k, v in module.__dict__.items() if not k.startswith('_')}
105
+ del sys.modules[temp_module_name]
106
+ sys.path.pop(0)
107
+
108
+ cfg_text = filename + '\n'
109
+ with open(filename, 'r') as f:
110
+ cfg_text += f.read()
111
+
112
+ return cfg_dict, cfg_text
113
+
114
+ @property
115
+ def cfg_dict(self) -> dict:
116
+ return self._cfg_dict
117
+
118
+
119
+ def read_config_yaml(path: str) -> EasyDict:
120
+ """
121
+ Overview:
122
+ read configuration from path
123
+ Arguments:
124
+ - path (:obj:`str`): Path of source yaml
125
+ Returns:
126
+ - (:obj:`EasyDict`): Config data from this file with dict type
127
+ """
128
+ with open(path, "r") as f:
129
+ config_ = yaml.safe_load(f)
130
+
131
+ return EasyDict(config_)
132
+
133
+
134
+ def save_config_yaml(config_: dict, path: str) -> None:
135
+ """
136
+ Overview:
137
+ save configuration to path
138
+ Arguments:
139
+ - config (:obj:`dict`): Config dict
140
+ - path (:obj:`str`): Path of target yaml
141
+ """
142
+ config_string = json.dumps(config_)
143
+ with open(path, "w") as f:
144
+ yaml.safe_dump(json.loads(config_string), f)
145
+
146
+
147
+ def save_config_py(config_: dict, path: str) -> None:
148
+ """
149
+ Overview:
150
+ save configuration to python file
151
+ Arguments:
152
+ - config (:obj:`dict`): Config dict
153
+ - path (:obj:`str`): Path of target python file
154
+ """
155
+ # config_string = json.dumps(config_, indent=4)
156
+ config_string = str(config_)
157
+ from yapf.yapflib.yapf_api import FormatCode
158
+ config_string, _ = FormatCode(config_string)
159
+ config_string = config_string.replace('inf,', 'float("inf"),')
160
+ with open(path, "w") as f:
161
+ f.write('exp_config = ' + config_string)
162
+
163
+
164
+ def read_config_directly(path: str) -> dict:
165
+ """
166
+ Overview:
167
+ Read configuration from a file path (currently only Python files are supported) and return the result directly.
168
+ Arguments:
169
+ - path (:obj:`str`): Path of configuration file
170
+ Returns:
171
+ - cfg (:obj:`dict`): Configuration dict.
172
+ """
173
+ suffix = path.split('.')[-1]
174
+ if suffix == 'py':
175
+ return Config.file_to_dict(path).cfg_dict
176
+ else:
177
+ raise KeyError("invalid config file suffix: {}".format(suffix))
178
+
179
+
180
+ def read_config(path: str) -> Tuple[dict, dict]:
181
+ """
182
+ Overview:
183
+ Read configuration from a file path (currently only Python files are supported) and split it into the proper parts.
184
+ Arguments:
185
+ - path (:obj:`str`): Path of configuration file
186
+ Returns:
187
+ - cfg (:obj:`Tuple[dict, dict]`): A tuple of two configuration dicts, divided into `main_config` and \
188
+ `create_config`.
189
+ """
190
+ suffix = path.split('.')[-1]
191
+ if suffix == 'py':
192
+ cfg = Config.file_to_dict(path).cfg_dict
193
+ assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
194
+ assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
195
+ return cfg['main_config'], cfg['create_config']
196
+ else:
197
+ raise KeyError("invalid config file suffix: {}".format(suffix))
198
+
199
+
200
+ def read_config_with_system(path: str) -> Tuple[dict, dict, dict]:
201
+ """
202
+ Overview:
203
+ Read configuration from a file path (currently only Python files are supported) and split it into the proper parts.
204
+ Arguments:
205
+ - path (:obj:`str`): Path of configuration file
206
+ Returns:
207
+ - cfg (:obj:`Tuple[dict, dict, dict]`): A tuple of three configuration dicts, divided into `main_config`, \
208
+ `create_config` and `system_config`.
209
+ """
210
+ suffix = path.split('.')[-1]
211
+ if suffix == 'py':
212
+ cfg = Config.file_to_dict(path).cfg_dict
213
+ assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
214
+ assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
215
+ assert "system_config" in cfg, "Please make sure a 'system_config' variable is declared in config python file!"
216
+ return cfg['main_config'], cfg['create_config'], cfg['system_config']
217
+ else:
218
+ raise KeyError("invalid config file suffix: {}".format(suffix))
219
+
220
+
221
+ def save_config(config_: dict, path: str, type_: str = 'py', save_formatted: bool = False) -> None:
222
+ """
223
+ Overview:
224
+ save configuration to python file or yaml file
225
+ Arguments:
226
+ - config (:obj:`dict`): Config dict
227
+ - path (:obj:`str`): Path of target yaml or target python file
228
+ - type (:obj:`str`): If type is ``yaml`` , save configuration to yaml file. If type is ``py`` , save\
229
+ configuration to python file.
230
+ - save_formatted (:obj:`bool`): If save_formatted is true, save formatted config to path.\
231
+ Formatted config can be read by serial_pipeline directly.
232
+ """
233
+ assert type_ in ['yaml', 'py'], type_
234
+ if type_ == 'yaml':
235
+ save_config_yaml(config_, path)
236
+ elif type_ == 'py':
237
+ save_config_py(config_, path)
238
+ if save_formatted:
239
+ formatted_path = osp.join(osp.dirname(path), 'formatted_' + osp.basename(path))
240
+ save_config_formatted(config_, formatted_path)
241
+
242
+
243
+ def compile_buffer_config(policy_cfg: EasyDict, user_cfg: EasyDict, buffer_cls: 'IBuffer') -> EasyDict: # noqa
244
+
245
+ def _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls):
246
+
247
+ if buffer_cls is None:
248
+ assert 'type' in policy_buffer_cfg, "please indicate buffer type in create_cfg"
249
+ buffer_cls = get_buffer_cls(policy_buffer_cfg)
250
+ buffer_cfg = deep_merge_dicts(buffer_cls.default_config(), policy_buffer_cfg)
251
+ buffer_cfg = deep_merge_dicts(buffer_cfg, user_buffer_cfg)
252
+ return buffer_cfg
253
+
254
+ policy_multi_buffer = policy_cfg.other.replay_buffer.get('multi_buffer', False)
255
+ user_multi_buffer = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('multi_buffer', False)
256
+ assert not user_multi_buffer or user_multi_buffer == policy_multi_buffer, "For multi_buffer, \
257
+ user_cfg({}) and policy_cfg({}) must be in accordance".format(user_multi_buffer, policy_multi_buffer)
258
+ multi_buffer = policy_multi_buffer
259
+ if not multi_buffer:
260
+ policy_buffer_cfg = policy_cfg.other.replay_buffer
261
+ user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {})
262
+ return _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls)
263
+ else:
264
+ return_cfg = EasyDict()
265
+ for buffer_name in policy_cfg.other.replay_buffer: # Only traverse keys in policy_cfg
266
+ if buffer_name == 'multi_buffer':
267
+ continue
268
+ policy_buffer_cfg = policy_cfg.other.replay_buffer[buffer_name]
269
+ user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get(buffer_name, {})  # look up by the buffer_name variable, not a literal string
270
+ if buffer_cls is None:
271
+ return_cfg[buffer_name] = _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, None)
272
+ else:
273
+ return_cfg[buffer_name] = _compile_buffer_config(
274
+ policy_buffer_cfg, user_buffer_cfg, buffer_cls[buffer_name]
275
+ )
276
+ return_cfg[buffer_name].name = buffer_name
277
+ return return_cfg
278
+
279
+
280
+ def compile_collector_config(
281
+ policy_cfg: EasyDict,
282
+ user_cfg: EasyDict,
283
+ collector_cls: 'ISerialCollector' # noqa
284
+ ) -> EasyDict:
285
+ policy_collector_cfg = policy_cfg.collect.collector
286
+ user_collector_cfg = user_cfg.policy.get('collect', {}).get('collector', {})
287
+ # step1: get collector class
288
+ # two sources: the create_cfg already merged into policy_cfg, or an explicitly passed collector class; the class has higher priority
289
+ if collector_cls is None:
290
+ assert 'type' in policy_collector_cfg, "please indicate collector type in create_cfg"
291
+ # use type to get collector_cls
292
+ collector_cls = get_serial_collector_cls(policy_collector_cfg)
293
+ # step2: policy collector cfg merge to collector cfg
294
+ collector_cfg = deep_merge_dicts(collector_cls.default_config(), policy_collector_cfg)
295
+ # step3: user collector cfg merge to the step2 config
296
+ collector_cfg = deep_merge_dicts(collector_cfg, user_collector_cfg)
297
+
298
+ return collector_cfg
299
+
300
+
301
+ policy_config_template = dict(
302
+ model=dict(),
303
+ learn=dict(learner=dict()),
304
+ collect=dict(collector=dict()),
305
+ eval=dict(evaluator=dict()),
306
+ other=dict(replay_buffer=dict()),
307
+ )
308
+ policy_config_template = EasyDict(policy_config_template)
309
+ env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4)
310
+ env_config_template = EasyDict(env_config_template)
311
+
312
+
313
+ def save_project_state(exp_name: str) -> None:
314
+
315
+ def _fn(cmd: str):
316
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
317
+
318
+ if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
319
+ short_sha = _fn("git describe --always")
320
+ log = _fn("git log --stat -n 5")
321
+ diff = _fn("git diff")
322
+ with open(os.path.join(exp_name, "git_log.txt"), "w", encoding='utf-8') as f:
323
+ f.write(short_sha + '\n\n' + log)
324
+ with open(os.path.join(exp_name, "git_diff.txt"), "w", encoding='utf-8') as f:
325
+ f.write(diff)
326
+
327
+
328
+ def compile_config(
329
+ cfg: EasyDict,
330
+ env_manager: type = None,
331
+ policy: type = None,
332
+ learner: type = BaseLearner,
333
+ collector: type = None,
334
+ evaluator: type = InteractionSerialEvaluator,
335
+ buffer: type = None,
336
+ env: type = None,
337
+ reward_model: type = None,
338
+ world_model: type = None,
339
+ seed: int = 0,
340
+ auto: bool = False,
341
+ create_cfg: dict = None,
342
+ save_cfg: bool = True,
343
+ save_path: str = 'total_config.py',
344
+ renew_dir: bool = True,
345
+ ) -> EasyDict:
346
+ """
347
+ Overview:
348
+ Combine the input config with the other given components, and compile it into a complete
349
+ config that can be consumed directly by the following pipeline.
350
+ Arguments:
351
+ - cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline
352
+ - env_manager (:obj:`type`): Env_manager class which is to be used in the following pipeline
353
+ - policy (:obj:`type`): Policy class which is to be used in the following pipeline
354
+ - learner (:obj:`type`): Input learner class, defaults to BaseLearner
355
+ - collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector
356
+ - evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator
357
+ - buffer (:obj:`type`): Input buffer class, defaults to IBuffer
358
+ - env (:obj:`type`): Environment class which is to be used in the following pipeline
359
+ - reward_model (:obj:`type`): Reward model class which provides an additional learned or shaped reward
360
+ - seed (:obj:`int`): Random number seed
361
+ - auto (:obj:`bool`): Compile create_config dict or not
362
+ - create_cfg (:obj:`dict`): Input create config dict
363
+ - save_cfg (:obj:`bool`): Save config or not
364
+ - save_path (:obj:`str`): Path of saving file
365
+ - renew_dir (:obj:`bool`): Whether to create a new directory for saving the config.
366
+ Returns:
367
+ - cfg (:obj:`EasyDict`): Config after compiling
368
+ """
369
+ cfg, create_cfg = deepcopy(cfg), deepcopy(create_cfg)
370
+ if auto:
371
+ assert create_cfg is not None
372
+ # for compatibility
373
+ if 'collector' not in create_cfg:
374
+ create_cfg.collector = EasyDict(dict(type='sample'))
375
+ if 'replay_buffer' not in create_cfg:
376
+ create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
377
+ buffer = AdvancedReplayBuffer
378
+ if env is None:
379
+ if 'env' in create_cfg:
380
+ env = get_env_cls(create_cfg.env)
381
+ else:
382
+ env = None
383
+ create_cfg.env = {'type': 'ding_env_wrapper_generated'}
384
+ if env_manager is None:
385
+ env_manager = get_env_manager_cls(create_cfg.env_manager)
386
+ if policy is None:
387
+ policy = get_policy_cls(create_cfg.policy)
388
+ if 'default_config' in dir(env):
389
+ env_config = env.default_config()
390
+ else:
391
+ env_config = EasyDict() # env does not have default_config
392
+ env_config = deep_merge_dicts(env_config_template, env_config)
393
+ env_config.update(create_cfg.env)
394
+ env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
395
+ env_config.manager.update(create_cfg.env_manager)
396
+ policy_config = policy.default_config()
397
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
398
+ policy_config.update(create_cfg.policy)
399
+ policy_config.collect.collector.update(create_cfg.collector)
400
+ if 'evaluator' in create_cfg:
401
+ policy_config.eval.evaluator.update(create_cfg.evaluator)
402
+ policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
403
+
404
+ policy_config.other.commander = BaseSerialCommander.default_config()
405
+ if 'reward_model' in create_cfg:
406
+ reward_model = get_reward_model_cls(create_cfg.reward_model)
407
+ reward_model_config = reward_model.default_config()
408
+ else:
409
+ reward_model_config = EasyDict()
410
+ if 'world_model' in create_cfg:
411
+ world_model = get_world_model_cls(create_cfg.world_model)
412
+ world_model_config = world_model.default_config()
413
+ world_model_config.update(create_cfg.world_model)
414
+ else:
415
+ world_model_config = EasyDict()
416
+ else:
417
+ if 'default_config' in dir(env):
418
+ env_config = env.default_config()
419
+ else:
420
+ env_config = EasyDict() # env does not have default_config
421
+ env_config = deep_merge_dicts(env_config_template, env_config)
422
+ if env_manager is None:
423
+ env_manager = BaseEnvManager # for compatibility
424
+ env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
425
+ policy_config = policy.default_config()
426
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
427
+ if reward_model is None:
428
+ reward_model_config = EasyDict()
429
+ else:
430
+ reward_model_config = reward_model.default_config()
431
+ if world_model is None:
432
+ world_model_config = EasyDict()
433
+ else:
434
+ world_model_config = world_model.default_config()
435
+ world_model_config.update(create_cfg.world_model if create_cfg is not None and 'world_model' in create_cfg else {})  # guard: create_cfg may be absent in this branch
436
+ policy_config.learn.learner = deep_merge_dicts(
437
+ learner.default_config(),
438
+ policy_config.learn.learner,
439
+ )
440
+ if create_cfg is not None or collector is not None:
441
+ policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
442
+ if evaluator:
443
+ policy_config.eval.evaluator = deep_merge_dicts(
444
+ evaluator.default_config(),
445
+ policy_config.eval.evaluator,
446
+ )
447
+ if create_cfg is not None or buffer is not None:
448
+ policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
449
+ default_config = EasyDict({'env': env_config, 'policy': policy_config})
450
+ if len(reward_model_config) > 0:
451
+ default_config['reward_model'] = reward_model_config
452
+ if len(world_model_config) > 0:
453
+ default_config['world_model'] = world_model_config
454
+ cfg = deep_merge_dicts(default_config, cfg)
455
+ if 'unroll_len' in cfg.policy:
456
+ cfg.policy.collect.unroll_len = cfg.policy.unroll_len
457
+ cfg.seed = seed
458
+ # check important key in config
459
+ if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
460
+ cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
461
+ cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
462
+ if 'exp_name' not in cfg:
463
+ cfg.exp_name = 'default_experiment'
464
+ if save_cfg and get_rank() == 0:
465
+ if os.path.exists(cfg.exp_name) and renew_dir:
466
+ cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
467
+ try:
468
+ os.makedirs(cfg.exp_name)
469
+ except FileExistsError:
470
+ pass
471
+ save_project_state(cfg.exp_name)
472
+ save_path = os.path.join(cfg.exp_name, save_path)
473
+ save_config(cfg, save_path, save_formatted=True)
474
+ return cfg
475
+
476
+
477
+ def compile_config_parallel(
478
+ cfg: EasyDict,
479
+ create_cfg: EasyDict,
480
+ system_cfg: EasyDict,
481
+ seed: int = 0,
482
+ save_cfg: bool = True,
483
+ save_path: str = 'total_config.py',
484
+ platform: str = 'local',
485
+ coordinator_host: Optional[str] = None,
486
+ learner_host: Optional[str] = None,
487
+ collector_host: Optional[str] = None,
488
+ coordinator_port: Optional[int] = None,
489
+ learner_port: Optional[int] = None,
490
+ collector_port: Optional[int] = None,
491
+ ) -> EasyDict:
492
+ """
493
+ Overview:
494
+ Combine the input parallel-mode config with the other given information, and compile it into a\
495
+ complete config that can be consumed directly by other programs.
496
+ Arguments:
497
+ - cfg (:obj:`EasyDict`): Input main config dict
498
+ - create_cfg (:obj:`dict`): Input create config dict, including type parameters, such as environment type
499
+ - system_cfg (:obj:`dict`): Input system config dict, including system parameters, such as file path,\
500
+ communication mode, use multiple GPUs or not
501
+ - seed (:obj:`int`): Random number seed
502
+ - save_cfg (:obj:`bool`): Save config or not
503
+ - save_path (:obj:`str`): Path of saving file
504
+ - platform (:obj:`str`): Where to run the program, 'local', 'slurm' or 'k8s'
505
+ - coordinator_host (:obj:`Optional[str]`): Input coordinator's host when platform is slurm
506
+ - learner_host (:obj:`Optional[str]`): Input learner's host when platform is slurm
507
+ - collector_host (:obj:`Optional[str]`): Input collector's host when platform is slurm
508
+ Returns:
509
+ - cfg (:obj:`EasyDict`): Config after compiling
510
+ """
511
+ # for compatibility
512
+ if 'replay_buffer' not in create_cfg:
513
+ create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
514
+ # env
515
+ env = get_env_cls(create_cfg.env)
516
+ if 'default_config' in dir(env):
517
+ env_config = env.default_config()
518
+ else:
519
+ env_config = EasyDict() # env does not have default_config
520
+ env_config = deep_merge_dicts(env_config_template, env_config)
521
+ env_config.update(create_cfg.env)
522
+
523
+ env_manager = get_env_manager_cls(create_cfg.env_manager)
524
+ env_config.manager = env_manager.default_config()
525
+ env_config.manager.update(create_cfg.env_manager)
526
+
527
+ # policy
528
+ policy = get_policy_cls(create_cfg.policy)
529
+ policy_config = policy.default_config()
530
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
531
+ cfg.policy.update(create_cfg.policy)
532
+
533
+ collector = get_parallel_collector_cls(create_cfg.collector)
534
+ policy_config.collect.collector = collector.default_config()
535
+ policy_config.collect.collector.update(create_cfg.collector)
536
+ policy_config.learn.learner = BaseLearner.default_config()
537
+ policy_config.learn.learner.update(create_cfg.learner)
538
+ commander = get_parallel_commander_cls(create_cfg.commander)
539
+ policy_config.other.commander = commander.default_config()
540
+ policy_config.other.commander.update(create_cfg.commander)
541
+ policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
542
+ policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, None)
543
+
544
+ default_config = EasyDict({'env': env_config, 'policy': policy_config})
545
+ cfg = deep_merge_dicts(default_config, cfg)
546
+
547
+ cfg.policy.other.commander.path_policy = system_cfg.path_policy # league may use 'path_policy'
548
+
549
+ # system
550
+ for k in ['comm_learner', 'comm_collector']:
551
+ system_cfg[k] = create_cfg[k]
552
+ if platform == 'local':
553
+ cfg = parallel_transform(EasyDict({'main': cfg, 'system': system_cfg}))
554
+ elif platform == 'slurm':
555
+ cfg = parallel_transform_slurm(
556
+ EasyDict({
557
+ 'main': cfg,
558
+ 'system': system_cfg
559
+ }), coordinator_host, learner_host, collector_host
560
+ )
561
+ elif platform == 'k8s':
562
+ cfg = parallel_transform_k8s(
563
+ EasyDict({
564
+ 'main': cfg,
565
+ 'system': system_cfg
566
+ }),
567
+ coordinator_port=coordinator_port,
568
+ learner_port=learner_port,
569
+ collector_port=collector_port
570
+ )
571
+ else:
572
+ raise KeyError("not support platform type: {}".format(platform))
573
+ cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
574
+ # seed
575
+ cfg.seed = seed
576
+
577
+ if save_cfg:
578
+ save_config(cfg, save_path)
579
+ return cfg
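
Taken together, a typical serial-pipeline entry point compiles a user config roughly as below (a sketch; the config path is illustrative, and the env/policy/collector classes are resolved from the registered types in `create_config`):

    from ding.config import read_config, compile_config

    main_config, create_config = read_config('cartpole_dqn_config.py')  # illustrative path
    cfg = compile_config(
        main_config,
        seed=0,
        auto=True,                 # infer env/policy/collector classes from create_config
        create_cfg=create_config,
        save_cfg=True,             # writes total_config.py and formatted_total_config.py under cfg.exp_name
    )

With save_cfg=True on rank 0, compile_config also snapshots the current git state (git_log.txt / git_diff.txt) into the experiment directory via save_project_state.
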
DI-engine/ding/config/example/A2C/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ from easydict import EasyDict
2
+ from . import gym_bipedalwalker_v3
3
+ from . import gym_lunarlander_v2
4
+
5
+ supported_env_cfg = {
6
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
7
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
8
+ }
9
+
10
+ supported_env_cfg = EasyDict(supported_env_cfg)
11
+
12
+ supported_env = {
13
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
14
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
15
+ }
16
+
17
+ supported_env = EasyDict(supported_env)
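
These registries are keyed by gym `env_id`, so the tuned config and the matching env factory resolve with a single lookup (a sketch):

    from ding.config.example.A2C import supported_env_cfg, supported_env

    cfg = supported_env_cfg['LunarLander-v2']  # tuned A2C hyper-parameters for this env
    env_fn = supported_env['LunarLander-v2']   # matching env factory (ding.envs.gym_env.env)
    print(cfg.policy.learn.learning_rate)      # 3e-4, per the LunarLander config file below
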
DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py ADDED
@@ -0,0 +1,43 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='Bipedalwalker-v3-A2C',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='BipedalWalker-v3',
9
+ collector_env_num=8,
10
+ evaluator_env_num=8,
11
+ n_evaluator_episode=8,
12
+ act_scale=True,
13
+ rew_clip=True,
14
+ ),
15
+ policy=dict(
16
+ cuda=True,
17
+ action_space='continuous',
18
+ model=dict(
19
+ action_space='continuous',
20
+ obs_shape=24,
21
+ action_shape=4,
22
+ ),
23
+ learn=dict(
24
+ batch_size=64,
25
+ learning_rate=0.0003,
26
+ value_weight=0.7,
27
+ entropy_weight=0.0005,
28
+ discount_factor=0.99,
29
+ adv_norm=True,
30
+ ),
31
+ collect=dict(
32
+ n_sample=64,
33
+ discount_factor=0.99,
34
+ ),
35
+ ),
36
+ wandb_logger=dict(
37
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
38
+ ),
39
+ )
40
+
41
+ cfg = EasyDict(cfg)
42
+
43
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py ADDED
@@ -0,0 +1,38 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='LunarLander-v2-A2C',
6
+ env=dict(
7
+ collector_env_num=8,
8
+ evaluator_env_num=8,
9
+ env_id='LunarLander-v2',
10
+ n_evaluator_episode=8,
11
+ stop_value=260,
12
+ ),
13
+ policy=dict(
14
+ cuda=True,
15
+ model=dict(
16
+ obs_shape=8,
17
+ action_shape=4,
18
+ ),
19
+ learn=dict(
20
+ batch_size=64,
21
+ learning_rate=3e-4,
22
+ entropy_weight=0.001,
23
+ adv_norm=True,
24
+ ),
25
+ collect=dict(
26
+ n_sample=64,
27
+ discount_factor=0.99,
28
+ gae_lambda=0.95,
29
+ ),
30
+ ),
31
+ wandb_logger=dict(
32
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
33
+ ),
34
+ )
35
+
36
+ cfg = EasyDict(cfg)
37
+
38
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/C51/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict
2
+ from . import gym_lunarlander_v2
3
+ from . import gym_pongnoframeskip_v4
4
+ from . import gym_qbertnoframeskip_v4
5
+ from . import gym_spaceInvadersnoframeskip_v4
6
+
7
+ supported_env_cfg = {
8
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
9
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12
+ }
13
+
14
+ supported_env_cfg = EasyDict(supported_env_cfg)
15
+
16
+ supported_env = {
17
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21
+ }
22
+
23
+ supported_env = EasyDict(supported_env)
DI-engine/ding/config/example/C51/gym_lunarlander_v2.py ADDED
@@ -0,0 +1,52 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='lunarlander_c51',
6
+ seed=0,
7
+ env=dict(
8
+ collector_env_num=8,
9
+ evaluator_env_num=8,
10
+ env_id='LunarLander-v2',
11
+ n_evaluator_episode=8,
12
+ stop_value=260,
13
+ ),
14
+ policy=dict(
15
+ cuda=False,
16
+ model=dict(
17
+ obs_shape=8,
18
+ action_shape=4,
19
+ encoder_hidden_size_list=[512, 64],
20
+ v_min=-30,
21
+ v_max=30,
22
+ n_atom=51,
23
+ ),
24
+ discount_factor=0.99,
25
+ nstep=3,
26
+ learn=dict(
27
+ update_per_collect=10,
28
+ batch_size=64,
29
+ learning_rate=0.001,
30
+ target_update_freq=100,
31
+ ),
32
+ collect=dict(
33
+ n_sample=64,
34
+ unroll_len=1,
35
+ ),
36
+ other=dict(
37
+ eps=dict(
38
+ type='exp',
39
+ start=0.95,
40
+ end=0.1,
41
+ decay=50000,
42
+ ), replay_buffer=dict(replay_buffer_size=100000, )
43
+ ),
44
+ ),
45
+ wandb_logger=dict(
46
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
47
+ ),
48
+ )
49
+
50
+ cfg = EasyDict(cfg)
51
+
52
+ env = ding.envs.gym_env.env
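
The `eps` block in the config above specifies an exponential epsilon-greedy anneal; assuming the usual form for `type='exp'` (end + (start - end) * exp(-step / decay), a sketch rather than DI-engine's internals), the schedule for these values looks like:

    import math

    def eps_exp(step, start=0.95, end=0.1, decay=50000):
        # exponential anneal from `start` toward `end` with time constant `decay`
        return end + (start - end) * math.exp(-step / decay)

    print(round(eps_exp(0), 3))       # 0.95
    print(round(eps_exp(50000), 3))   # ~0.413 (one time constant)
    print(round(eps_exp(200000), 3))  # ~0.116, approaching the 0.1 floor
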
DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py ADDED
@@ -0,0 +1,54 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='PongNoFrameskip-v4-C51',
6
+ seed=0,
7
+ env=dict(
8
+ collector_env_num=8,
9
+ evaluator_env_num=8,
10
+ n_evaluator_episode=8,
11
+ stop_value=30,
12
+ env_id='PongNoFrameskip-v4',
13
+ frame_stack=4,
14
+ env_wrapper='atari_default',
15
+ ),
16
+ policy=dict(
17
+ cuda=True,
18
+ priority=False,
19
+ model=dict(
20
+ obs_shape=[4, 84, 84],
21
+ action_shape=6,
22
+ encoder_hidden_size_list=[128, 128, 512],
23
+ v_min=-10,
24
+ v_max=10,
25
+ n_atom=51,
26
+ ),
27
+ nstep=3,
28
+ discount_factor=0.99,
29
+ learn=dict(
30
+ update_per_collect=10,
31
+ batch_size=32,
32
+ learning_rate=0.0001,
33
+ target_update_freq=500,
34
+ ),
35
+ collect=dict(n_sample=100, ),
36
+ eval=dict(evaluator=dict(eval_freq=4000, )),
37
+ other=dict(
38
+ eps=dict(
39
+ type='exp',
40
+ start=1.,
41
+ end=0.05,
42
+ decay=250000,
43
+ ),
44
+ replay_buffer=dict(replay_buffer_size=100000, ),
45
+ ),
46
+ ),
47
+ wandb_logger=dict(
48
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49
+ ),
50
+ )
51
+
52
+ cfg = EasyDict(cfg)
53
+
54
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py ADDED
@@ -0,0 +1,54 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='QbertNoFrameskip-v4-C51',
6
+ seed=0,
7
+ env=dict(
8
+ collector_env_num=8,
9
+ evaluator_env_num=8,
10
+ n_evaluator_episode=8,
11
+ stop_value=30000,
12
+ env_id='QbertNoFrameskip-v4',
13
+ frame_stack=4,
14
+ env_wrapper='atari_default',
15
+ ),
16
+ policy=dict(
17
+ cuda=True,
18
+ priority=True,
19
+ model=dict(
20
+ obs_shape=[4, 84, 84],
21
+ action_shape=6,
22
+ encoder_hidden_size_list=[128, 128, 512],
23
+ v_min=-10,
24
+ v_max=10,
25
+ n_atom=51,
26
+ ),
27
+ nstep=3,
28
+ discount_factor=0.99,
29
+ learn=dict(
30
+ update_per_collect=10,
31
+ batch_size=32,
32
+ learning_rate=0.0001,
33
+ target_update_freq=500,
34
+ ),
35
+ collect=dict(n_sample=100, ),
36
+ eval=dict(evaluator=dict(eval_freq=4000, )),
37
+ other=dict(
38
+ eps=dict(
39
+ type='exp',
40
+ start=1.,
41
+ end=0.05,
42
+ decay=1000000,
43
+ ),
44
+ replay_buffer=dict(replay_buffer_size=400000, ),
45
+ ),
46
+ ),
47
+ wandb_logger=dict(
48
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49
+ ),
50
+ )
51
+
52
+ cfg = EasyDict(cfg)
53
+
54
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py ADDED
@@ -0,0 +1,54 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='SpaceInvadersNoFrameskip-v4-C51',
6
+ seed=0,
7
+ env=dict(
8
+ collector_env_num=8,
9
+ evaluator_env_num=8,
10
+ n_evaluator_episode=8,
11
+ stop_value=10000000000,
12
+ env_id='SpaceInvadersNoFrameskip-v4',
13
+ frame_stack=4,
14
+ env_wrapper='atari_default',
15
+ ),
16
+ policy=dict(
17
+ cuda=True,
18
+ priority=False,
19
+ model=dict(
20
+ obs_shape=[4, 84, 84],
21
+ action_shape=6,
22
+ encoder_hidden_size_list=[128, 128, 512],
23
+ v_min=-10,
24
+ v_max=10,
25
+ n_atom=51,
26
+ ),
27
+ nstep=3,
28
+ discount_factor=0.99,
29
+ learn=dict(
30
+ update_per_collect=10,
31
+ batch_size=32,
32
+ learning_rate=0.0001,
33
+ target_update_freq=500,
34
+ ),
35
+ collect=dict(n_sample=100, ),
36
+ eval=dict(evaluator=dict(eval_freq=4000, )),
37
+ other=dict(
38
+ eps=dict(
39
+ type='exp',
40
+ start=1.,
41
+ end=0.05,
42
+ decay=1000000,
43
+ ),
44
+ replay_buffer=dict(replay_buffer_size=400000, ),
45
+ ),
46
+ ),
47
+ wandb_logger=dict(
48
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
49
+ ),
50
+ )
51
+
52
+ cfg = EasyDict(cfg)
53
+
54
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/DDPG/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ from easydict import EasyDict
2
+ from . import gym_bipedalwalker_v3
3
+ from . import gym_halfcheetah_v3
4
+ from . import gym_hopper_v3
5
+ from . import gym_lunarlandercontinuous_v2
6
+ from . import gym_pendulum_v1
7
+ from . import gym_walker2d_v3
8
+
9
+ supported_env_cfg = {
10
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
11
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
12
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
13
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
14
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
15
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
16
+ }
17
+
18
+ supported_env_cfg = EasyDict(supported_env_cfg)
19
+
20
+ supported_env = {
21
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
22
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
23
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
24
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
25
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
26
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
27
+ }
28
+
29
+ supported_env = EasyDict(supported_env)
DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py ADDED
@@ -0,0 +1,45 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='Bipedalwalker-v3-DDPG',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='BipedalWalker-v3',
9
+ collector_env_num=8,
10
+ evaluator_env_num=5,
11
+ n_evaluator_episode=5,
12
+ act_scale=True,
13
+ rew_clip=True,
14
+ ),
15
+ policy=dict(
16
+ cuda=True,
17
+ random_collect_size=10000,
18
+ model=dict(
19
+ obs_shape=24,
20
+ action_shape=4,
21
+ twin_critic=False,
22
+ action_space='regression',
23
+ actor_head_hidden_size=400,
24
+ critic_head_hidden_size=400,
25
+ ),
26
+ learn=dict(
27
+ update_per_collect=64,
28
+ batch_size=256,
29
+ learning_rate_actor=0.0003,
30
+ learning_rate_critic=0.0003,
31
+ target_theta=0.005,
32
+ discount_factor=0.99,
33
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
34
+ ),
35
+ collect=dict(n_sample=64, ),
36
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
37
+ ),
38
+ wandb_logger=dict(
39
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
40
+ ),
41
+ )
42
+
43
+ cfg = EasyDict(cfg)
44
+
45
+ env = ding.envs.gym_env.env
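
`target_theta=0.005` in the config above controls the soft (Polyak) target-network update used by DDPG; a minimal sketch of that rule in its standard form (not DI-engine's exact implementation):

    import torch

    def soft_update(target: torch.nn.Module, source: torch.nn.Module, theta: float = 0.005):
        # target <- theta * source + (1 - theta) * target, applied parameter-wise
        with torch.no_grad():
            for tp, sp in zip(target.parameters(), source.parameters()):
                tp.mul_(1.0 - theta).add_(sp, alpha=theta)
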
DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py ADDED
@@ -0,0 +1,53 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='HalfCheetah-v3-DDPG',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='HalfCheetah-v3',
9
+ norm_obs=dict(use_norm=False, ),
10
+ norm_reward=dict(use_norm=False, ),
11
+ collector_env_num=1,
12
+ evaluator_env_num=8,
13
+ n_evaluator_episode=8,
14
+ stop_value=11000,
15
+ env_wrapper='mujoco_default',
16
+ ),
17
+ policy=dict(
18
+ cuda=True,
19
+ random_collect_size=25000,
20
+ model=dict(
21
+ obs_shape=17,
22
+ action_shape=6,
23
+ twin_critic=False,
24
+ actor_head_hidden_size=256,
25
+ critic_head_hidden_size=256,
26
+ action_space='regression',
27
+ ),
28
+ learn=dict(
29
+ update_per_collect=1,
30
+ batch_size=256,
31
+ learning_rate_actor=1e-3,
32
+ learning_rate_critic=1e-3,
33
+ ignore_done=True,
34
+ target_theta=0.005,
35
+ discount_factor=0.99,
36
+ actor_update_freq=1,
37
+ noise=False,
38
+ ),
39
+ collect=dict(
40
+ n_sample=1,
41
+ unroll_len=1,
42
+ noise_sigma=0.1,
43
+ ),
44
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45
+ ),
46
+ wandb_logger=dict(
47
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48
+ ),
49
+ )
50
+
51
+ cfg = EasyDict(cfg)
52
+
53
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/DDPG/gym_hopper_v3.py ADDED
@@ -0,0 +1,53 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='Hopper-v3-DDPG',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='Hopper-v3',
9
+ norm_obs=dict(use_norm=False, ),
10
+ norm_reward=dict(use_norm=False, ),
11
+ collector_env_num=1,
12
+ evaluator_env_num=8,
13
+ n_evaluator_episode=8,
14
+ stop_value=6000,
15
+ env_wrapper='mujoco_default',
16
+ ),
17
+ policy=dict(
18
+ cuda=True,
19
+ random_collect_size=25000,
20
+ model=dict(
21
+ obs_shape=11,
22
+ action_shape=3,
23
+ twin_critic=False,
24
+ actor_head_hidden_size=256,
25
+ critic_head_hidden_size=256,
26
+ action_space='regression',
27
+ ),
28
+ learn=dict(
29
+ update_per_collect=1,
30
+ batch_size=256,
31
+ learning_rate_actor=1e-3,
32
+ learning_rate_critic=1e-3,
33
+ ignore_done=False,
34
+ target_theta=0.005,
35
+ discount_factor=0.99,
36
+ actor_update_freq=1,
37
+ noise=False,
38
+ ),
39
+ collect=dict(
40
+ n_sample=1,
41
+ unroll_len=1,
42
+ noise_sigma=0.1,
43
+ ),
44
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45
+ ),
46
+ wandb_logger=dict(
47
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48
+ ),
49
+ )
50
+
51
+ cfg = EasyDict(cfg)
52
+
53
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py ADDED
@@ -0,0 +1,60 @@
1
+ from easydict import EasyDict
2
+ from functools import partial
3
+ import ding.envs.gym_env
4
+
5
+ cfg = dict(
6
+ exp_name='LunarLanderContinuous-V2-DDPG',
7
+ seed=0,
8
+ env=dict(
9
+ env_id='LunarLanderContinuous-v2',
10
+ collector_env_num=8,
11
+ evaluator_env_num=8,
12
+ n_evaluator_episode=8,
13
+ stop_value=260,
14
+ act_scale=True,
15
+ ),
16
+ policy=dict(
17
+ cuda=True,
18
+ random_collect_size=0,
19
+ model=dict(
20
+ obs_shape=8,
21
+ action_shape=2,
22
+ twin_critic=True,
23
+ action_space='regression',
24
+ ),
25
+ learn=dict(
26
+ update_per_collect=2,
27
+ batch_size=128,
28
+ learning_rate_actor=0.001,
29
+ learning_rate_critic=0.001,
30
+ ignore_done=False, # TODO(pu)
31
+ # (int) When critic network updates once, how many times will actor network update.
32
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
33
+ # Default 1 for DDPG, 2 for TD3.
34
+ actor_update_freq=1,
35
+ # (bool) Whether to add noise on target network's action.
36
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
37
+ # Default True for TD3, False for DDPG.
38
+ noise=False,
39
+ noise_sigma=0.1,
40
+ noise_range=dict(
41
+ min=-0.5,
42
+ max=0.5,
43
+ ),
44
+ ),
45
+ collect=dict(
46
+ n_sample=48,
47
+ noise_sigma=0.1,
48
+ collector=dict(collect_print_freq=1000, ),
49
+ ),
50
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
51
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
52
+ ),
53
+ wandb_logger=dict(
54
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
55
+ ),
56
+ )
57
+
58
+ cfg = EasyDict(cfg)
59
+
60
+ env = partial(ding.envs.gym_env.env, continuous=True)
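
Unlike the other entries, this env factory pre-binds continuous=True through functools.partial; the binding can be inspected before the factory is used (a sketch):

    from ding.config.example.DDPG.gym_lunarlandercontinuous_v2 import env

    print(env.func)      # ding.envs.gym_env.env
    print(env.keywords)  # {'continuous': True}
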
DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py ADDED
@@ -0,0 +1,52 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='Pendulum-v1-DDPG',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='Pendulum-v1',
9
+ collector_env_num=8,
10
+ evaluator_env_num=5,
11
+ n_evaluator_episode=5,
12
+ stop_value=-250,
13
+ act_scale=True,
14
+ ),
15
+ policy=dict(
16
+ cuda=False,
17
+ priority=False,
18
+ random_collect_size=800,
19
+ model=dict(
20
+ obs_shape=3,
21
+ action_shape=1,
22
+ twin_critic=False,
23
+ action_space='regression',
24
+ ),
25
+ learn=dict(
26
+ update_per_collect=2,
27
+ batch_size=128,
28
+ learning_rate_actor=0.001,
29
+ learning_rate_critic=0.001,
30
+ ignore_done=True,
31
+ actor_update_freq=1,
32
+ noise=False,
33
+ ),
34
+ collect=dict(
35
+ n_sample=48,
36
+ noise_sigma=0.1,
37
+ collector=dict(collect_print_freq=1000, ),
38
+ ),
39
+ eval=dict(evaluator=dict(eval_freq=100, )),
40
+ other=dict(replay_buffer=dict(
41
+ replay_buffer_size=20000,
42
+ max_use=16,
43
+ ), ),
44
+ ),
45
+ wandb_logger=dict(
46
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
47
+ ),
48
+ )
49
+
50
+ cfg = EasyDict(cfg)
51
+
52
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py ADDED
@@ -0,0 +1,53 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='Walker2d-v3-DDPG',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='Walker2d-v3',
9
+ norm_obs=dict(use_norm=False, ),
10
+ norm_reward=dict(use_norm=False, ),
11
+ collector_env_num=1,
12
+ evaluator_env_num=8,
13
+ n_evaluator_episode=8,
14
+ stop_value=6000,
15
+ env_wrapper='mujoco_default',
16
+ ),
17
+ policy=dict(
18
+ cuda=True,
19
+ random_collect_size=25000,
20
+ model=dict(
21
+ obs_shape=17,
22
+ action_shape=6,
23
+ twin_critic=False,
24
+ actor_head_hidden_size=256,
25
+ critic_head_hidden_size=256,
26
+ action_space='regression',
27
+ ),
28
+ learn=dict(
29
+ update_per_collect=1,
30
+ batch_size=256,
31
+ learning_rate_actor=1e-3,
32
+ learning_rate_critic=1e-3,
33
+ ignore_done=False,
34
+ target_theta=0.005,
35
+ discount_factor=0.99,
36
+ actor_update_freq=1,
37
+ noise=False,
38
+ ),
39
+ collect=dict(
40
+ n_sample=1,
41
+ unroll_len=1,
42
+ noise_sigma=0.1,
43
+ ),
44
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
45
+ ),
46
+ wandb_logger=dict(
47
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48
+ ),
49
+ )
50
+
51
+ cfg = EasyDict(cfg)
52
+
53
+ env = ding.envs.gym_env.env
DI-engine/ding/config/example/DQN/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict
2
+ from . import gym_lunarlander_v2
3
+ from . import gym_pongnoframeskip_v4
4
+ from . import gym_qbertnoframeskip_v4
5
+ from . import gym_spaceInvadersnoframeskip_v4
6
+
7
+ supported_env_cfg = {
8
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
9
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
10
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
11
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
12
+ }
13
+
14
+ supported_env_cfg = EasyDict(supported_env_cfg)
15
+
16
+ supported_env = {
17
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
18
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
19
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
20
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
21
+ }
22
+
23
+ supported_env = EasyDict(supported_env)
DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py ADDED
@@ -0,0 +1,53 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='LunarLander-v2-DQN',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='LunarLander-v2',
9
+ collector_env_num=8,
10
+ evaluator_env_num=8,
11
+ n_evaluator_episode=8,
12
+ stop_value=260,
13
+ ),
14
+ policy=dict(
15
+ cuda=True,
16
+ random_collect_size=25000,
17
+ discount_factor=0.99,
18
+ nstep=3,
19
+ learn=dict(
20
+ update_per_collect=10,
21
+ batch_size=64,
22
+ learning_rate=0.001,
23
+ # Frequency of target network update.
24
+ target_update_freq=100,
25
+ ),
26
+ model=dict(
27
+ obs_shape=8,
28
+ action_shape=4,
29
+ encoder_hidden_size_list=[512, 64],
30
+ # Whether to use dueling head.
31
+ dueling=True,
32
+ ),
33
+ collect=dict(
34
+ n_sample=64,
35
+ unroll_len=1,
36
+ ),
37
+ other=dict(
38
+ eps=dict(
39
+ type='exp',
40
+ start=0.95,
41
+ end=0.1,
42
+ decay=50000,
43
+ ), replay_buffer=dict(replay_buffer_size=100000, )
44
+ ),
45
+ ),
46
+ wandb_logger=dict(
47
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
48
+ ),
49
+ )
50
+
51
+ cfg = EasyDict(cfg)
52
+
53
+ env = ding.envs.gym_env.env
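
These example configs feed the high-level agents in ding.bonus; a sketch of that path, assuming the DQNAgent interface from ding/bonus/dqn.py (the step budget here is illustrative):

    from ding.bonus import DQNAgent

    agent = DQNAgent(env_id='LunarLander-v2', exp_name='LunarLander-v2-DQN')  # resolves the config above by env_id
    agent.train(step=int(1e6))  # illustrative training budget
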
DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py ADDED
@@ -0,0 +1,50 @@
1
+ from easydict import EasyDict
2
+ import ding.envs.gym_env
3
+
4
+ cfg = dict(
5
+ exp_name='PongNoFrameskip-v4-DQN',
6
+ seed=0,
7
+ env=dict(
8
+ env_id='PongNoFrameskip-v4',
9
+ collector_env_num=8,
10
+ evaluator_env_num=8,
11
+ n_evaluator_episode=8,
12
+ stop_value=30,
13
+ frame_stack=4,
14
+ env_wrapper='atari_default',
15
+ ),
16
+ policy=dict(
17
+ cuda=True,
18
+ priority=False,
19
+ discount_factor=0.99,
20
+ nstep=3,
21
+ learn=dict(
22
+ update_per_collect=10,
23
+ batch_size=32,
24
+ learning_rate=0.0001,
25
+ # Frequency of target network update.
26
+ target_update_freq=500,
27
+ ),
28
+ model=dict(
29
+ obs_shape=[4, 84, 84],
30
+ action_shape=6,
31
+ encoder_hidden_size_list=[128, 128, 512],
32
+ ),
33
+ collect=dict(n_sample=96, ),
34
+ other=dict(
35
+ eps=dict(
36
+ type='exp',
37
+ start=1.,
38
+ end=0.05,
39
+ decay=250000,
40
+ ), replay_buffer=dict(replay_buffer_size=100000, )
41
+ ),
42
+ ),
43
+ wandb_logger=dict(
44
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
45
+ ),
46
+ )
47
+
48
+ cfg = EasyDict(cfg)
49
+
50
+ env = ding.envs.gym_env.env