from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    # Only needed for the type annotation: the scheduler just reads attributes
    # from ``cfg``, so any attribute-style config object is accepted at runtime.
    from easydict import EasyDict


def get_rollout_length_scheduler(cfg: "EasyDict") -> Callable[[int], int]:
    """
    Overview:
        Get the rollout length scheduler that adapts rollout length based
        on the current environment steps.
    Arguments:
        - cfg (:obj:`EasyDict`): Scheduler config. ``cfg.type`` selects the schedule:
            - ``'linear'``: reads ``rollout_start_step``, ``rollout_end_step``,
              ``rollout_length_min`` and ``rollout_length_max``. The rollout length is
              linearly interpolated from ``rollout_length_min`` (at ``rollout_start_step``)
              to ``rollout_length_max`` (at ``rollout_end_step``), held at those endpoint
              values outside the window, and truncated to ``int``. A decreasing schedule
              (``rollout_length_max < rollout_length_min``) is supported; when
              ``rollout_start_step == rollout_end_step`` the length switches from
              ``rollout_length_min`` to ``rollout_length_max`` at that step.
            - ``'constant'``: reads ``rollout_length`` once and always returns it as ``int``.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and
            returns the current rollout length.
    Raises:
        - KeyError: If ``cfg.type`` is neither ``'linear'`` nor ``'constant'``.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        # Clamp to the range spanned by both endpoints so that decreasing
        # schedules (y1 < y0) are handled as well as increasing ones.
        lo, hi = min(y0, y1), max(y0, y1)

        if x1 == x0:
            # Zero-width window: no slope exists, so step from y0 to y1 at x0
            # instead of dividing by zero.
            return lambda x: int(y0 if x < x0 else y1)

        def scheduler(x: int) -> int:
            # Multiply before dividing so the interpolation has a single rounding step.
            value = y0 + (y1 - y0) * (x - x0) / (x1 - x0)
            return int(min(max(value, lo), hi))

        return scheduler
    elif cfg.type == 'constant':
        # Snapshot the value now, like the linear branch, rather than re-reading cfg per call.
        length = int(cfg.rollout_length)
        return lambda x: length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))