# DI-engine: dizoo/atari/entry/atari_dt_main.py
import torch.nn as nn
import torch.distributed as dist
from ditk import logging
from ding.model import DecisionTransformer
from ding.policy import DTPolicy
from ding.envs import SubprocessEnvManagerV2
from ding.envs import AllinObsWrapper
from ding.data import create_dataset
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OfflineRLContext
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, \
OfflineMemoryDataFetcher
from ding.utils import set_pkg_seed, DDPContext, to_ddp_config
from dizoo.atari.envs import AtariEnv
from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config


def main():
    # If you don't have offline data, you need to prepare it first and set the data_path in config.
    # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
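    # A minimal sketch of pointing the pipeline at an existing dataset at runtime;
    # the exact key (e.g. `dataset.data_dir_prefix`) is an assumption and depends
    # on the dataset type configured in pong_dt_config:
    #   main_config.dataset.data_dir_prefix = '/path/to/offline_data'  # hypothetical key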
logging.getLogger().setLevel(logging.INFO)
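    # DDPContext manages the torch.distributed process group for multi-GPU runs,
    # and to_ddp_config adapts the config for data-parallel training (a brief
    # note; see ding.utils for the exact behavior).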
with DDPContext():
cmain_config = to_ddp_config(main_config)
cfg = compile_config(cmain_config, create_cfg=create_config, auto=True)
ding_init(cfg)
with task.start(async_mode=False, ctx=OfflineRLContext()):
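            # Evaluation envs run in subprocesses; AllinObsWrapper packs the
            # reward into the observation so the Decision Transformer evaluator
            # can track returns-to-go during rollouts.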
evaluator_env = SubprocessEnvManagerV2(
env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
cfg=cfg.env.manager
)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
dataset = create_dataset(cfg)
cfg.policy.model.max_timestep = dataset.get_max_timestep()
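            # DQN-style convolutional encoder for stacked 4x84x84 Atari frames:
            # spatial size shrinks 84 -> 20 -> 9 -> 7 through the three convs,
            # so the flattened feature is 64 * 7 * 7 = 3136, then projected to
            # the transformer hidden size h_dim.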
state_encoder = nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), nn.Conv2d(32, 64, 4, stride=2, padding=0),
nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), nn.Flatten(),
nn.Linear(3136, cfg.policy.model.h_dim), nn.Tanh()
)
model = DecisionTransformer(**cfg.policy.model, state_encoder=state_encoder)
# model.parallelize()
policy = DTPolicy(cfg.policy, model=model)
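            # Middleware pipeline: evaluate periodically, fetch offline batches
            # from the dataset, take a train step, stop after max_train_iter,
            # checkpoint every 100 train iterations, and log offline-RL metrics.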
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(OfflineMemoryDataFetcher(cfg, dataset))
task.use(trainer(cfg, policy.learn_mode))
task.use(termination_checker(max_train_iter=3e4))
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
task.use(offline_logger())
task.run()


if __name__ == "__main__":
main()
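
# Usage sketch (an assumption about the launch setup, not part of the original
# file): with DDPContext enabled, this entry is typically started under a
# torch.distributed launcher, e.g.
#   torchrun --nproc_per_node=2 atari_dt_main.py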