File size: 747 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import copy

import torch

from ding.envs.common import EnvElementRunner
from ding.envs.env.base_env import BaseEnv
from .gfootball_reward import GfootballReward


class GfootballRewardRunner(EnvElementRunner):
    """Element runner that reads the per-step reward from a gfootball engine.

    Each ``get`` call copies the engine's current reward, adds the raw value to
    a running total, and returns it after passing it through
    ``GfootballReward._to_agent_processor``. ``reset`` zeroes the total and
    ``cum_reward`` exposes it as a float tensor.
    """

    def _init(self, cfg, *args, **kwargs) -> None:
        """Build the reward core from ``cfg`` and zero the running total.

        NOTE(review): presumably invoked by ``EnvElementRunner``'s constructor
        as the initialization hook; confirm against the base class.
        """
        self._core = GfootballReward(cfg)
        # Running sum of raw (pre-processor) rewards since the last ``reset``.
        self._cum_reward = 0.0

    def get(self, engine: BaseEnv) -> torch.Tensor:
        """Return the engine's latest reward, processed for the agent.

        Args:
            engine: environment object exposing ``_reward_of_action``.

        Returns:
            The result of ``GfootballReward._to_agent_processor`` applied to
            the reward; presumably a tensor — TODO confirm against
            ``GfootballReward``.
        """
        # Deep-copy so the value accumulated and returned here is independent
        # of the engine's attribute.
        reward = copy.deepcopy(engine._reward_of_action)
        # Accumulate the raw reward, before the agent-side processor runs.
        self._cum_reward += reward
        return self._core._to_agent_processor(reward)

    def reset(self) -> None:
        """Zero the running reward total."""
        self._cum_reward = 0.0

    @property
    def cum_reward(self) -> torch.Tensor:
        """Accumulated raw reward since the last ``reset``, as a FloatTensor.

        The total is wrapped in a one-element list before conversion.
        """
        return torch.FloatTensor([self._cum_reward])