|
import numpy as np |
|
import dataclasses |
|
import treetensor.torch as ttorch |
|
from typing import Union, Dict, List |
|
|
|
|
|
@dataclasses.dataclass
class Context:
    """
    Overview:
        A carrier of contextual data passed between middlewares, whose life cycle
        is a single training iteration. It behaves like a reflective dict, so you
        may set any property on it that you wish.
        Note that the initial value of any such property must be equal to False.
    """
    # Names of attributes that should be carried over into the next iteration.
    _kept_keys: set = dataclasses.field(default_factory=set)
    # Monotonically increasing count of training iterations.
    total_step: int = 0

    def renew(self) -> 'Context':
        """
        Overview:
            Create a fresh context of the same concrete type, copy every kept
            attribute onto it, and advance ``total_step`` by one.
        """
        successor = type(self)()
        for name in self._kept_keys:
            if self.has_attr(name):
                setattr(successor, name, getattr(self, name))
        successor.total_step = self.total_step + 1
        return successor

    def keep(self, *keys: str) -> None:
        """
        Overview:
            Mark the given attribute name(s) to survive into the next iteration.
        """
        self._kept_keys.update(keys)

    def has_attr(self, key):
        # Thin wrapper over hasattr, kept for a uniform context API.
        return hasattr(self, key)
|
|
|
|
|
|
|
@dataclasses.dataclass
class OnlineRLContext(Context):
    """
    Overview:
        Context for online RL training, renewed once per training iteration.
        The counters listed in ``__post_init__`` are kept across ``renew``;
        all other fields reset to their defaults each iteration.
    """
    # common
    total_step: int = 0  # total number of training iterations
    env_step: int = 0  # total environment steps collected so far
    env_episode: int = 0  # total environment episodes finished so far
    train_iter: int = 0  # total trainer iterations
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # collect
    collect_kwargs: Dict = dataclasses.field(default_factory=dict)
    obs: ttorch.Tensor = None  # latest observation batch (shape depends on env -- TODO confirm)
    action: List = None  # actions produced by the policy for the current step
    inference_output: Dict[int, Dict] = None  # presumably keyed by env id -- verify against caller
    trajectories: List = None
    episodes: List = None
    trajectory_end_idx: List = dataclasses.field(default_factory=list)
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    # Fixed annotation: this stores a float (it defaults to -np.inf and mirrors
    # eval_value); it was mistakenly annotated ``int``.
    last_eval_value: float = -np.inf
    # Fixed annotation: the default factory builds a dict, so annotate ``Dict``
    # (was ``List``). The runtime default is unchanged.
    eval_output: Dict = dataclasses.field(default_factory=dict)

    wandb_url: str = ""

    def __post_init__(self):
        # Persist these counters/records across renew(); everything else resets.
        self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
|
|
|
|
|
@dataclasses.dataclass
class OfflineRLContext(Context):
    """
    Overview:
        Context for offline RL training, renewed once per training iteration.
        The counters listed in ``__post_init__`` are kept across ``renew``;
        all other fields reset to their defaults each iteration.
    """
    # common
    total_step: int = 0  # total number of training iterations
    trained_env_step: int = 0  # env steps the loaded dataset corresponds to
    train_epoch: int = 0  # number of passes over the offline dataset
    train_iter: int = 0  # total trainer iterations
    train_data: Union[Dict, List] = None
    train_output: Union[Dict, List[Dict]] = None
    # eval
    eval_value: float = -np.inf
    last_eval_iter: int = -1
    # Fixed annotation: this stores a float (it defaults to -np.inf and mirrors
    # eval_value); it was mistakenly annotated ``int``.
    last_eval_value: float = -np.inf
    # Fixed annotation: the default factory builds a dict, so annotate ``Dict``
    # (was ``List``). The runtime default is unchanged.
    eval_output: Dict = dataclasses.field(default_factory=dict)

    wandb_url: str = ""

    def __post_init__(self):
        # Persist these counters/records across renew(); everything else resets.
        self.keep('trained_env_step', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
|
|