Spaces:
Runtime error
Runtime error
import os | |
from typing import List | |
from trlx.data.configs import TRLConfig | |
def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]:
    """Return all sub-directories under `dir` that are named `config_dir_name`.

    The tree rooted at `dir` is searched recursively with `os.walk`, so
    matches at any depth are found, including a match nested inside another
    match.

    Args:
        dir: Root directory of the search. (The name shadows the builtin but
            is kept so existing keyword callers continue to work.)
        config_dir_name: Exact directory name to look for.

    Returns:
        The matching directory paths, each joined onto the walked root, in
        `os.walk` order. Empty when nothing matches.
    """
    return [
        os.path.join(root, name)
        for root, subdirs, _ in os.walk(dir)
        for name in subdirs
        if name == config_dir_name
    ]
def _get_yaml_filepaths(dir: str) -> List[str]:
    """Return the paths of the `.yml` entries directly inside `dir`.

    Only the immediate contents of `dir` are inspected (no recursion), and
    only names ending in `.yml` are kept; `.yaml` is not matched.

    Args:
        dir: Directory to list. (The name shadows the builtin but is kept so
            existing keyword callers continue to work.)

    Returns:
        Each matching entry name joined onto `dir`, in `os.listdir` order.
    """
    return [os.path.join(dir, name) for name in os.listdir(dir) if name.endswith(".yml")]
def test_repo_trl_configs():
    """Tests to ensure all default configs in the repository are valid.

    Gathers the `.yml` files from the top-level `configs` directory and from
    every `configs` directory found under `examples` (both relative to the
    current working directory). Each file must load through
    `TRLConfig.load_yaml` and must not carry a `train.entity_name`.
    """
    config_dirs = ["configs", *_get_config_dirs("examples")]
    # Flatten the per-directory lists into one list of file paths.
    config_files = [
        filepath
        for config_dir in config_dirs
        for filepath in _get_yaml_filepaths(config_dir)
    ]
    for file in config_files:
        assert os.path.isfile(file), f"Config file {file} does not exist."
        assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
        # Only the load is guarded: keeping the entity-name check outside the
        # `try` stops its failure from being re-reported as a load error.
        try:
            config = TRLConfig.load_yaml(file)
        except Exception as e:
            raise AssertionError(f"Failed to load config file `{file}` with error `{e}`") from e
        assert (
            config.train.entity_name is None
        ), f"Unexpected entity name in config file `{file}`. Remove before pushing to repo."