Spaces:
Runtime error
Runtime error
"""Configurations for a dataset run.""" | |
import json | |
import pathlib | |
from typing import TYPE_CHECKING, Any, Optional, Union | |
import yaml | |
if TYPE_CHECKING: | |
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny | |
from pydantic import BaseModel, Extra, ValidationError, validator | |
from .schema import Path, PathTuple, normalize_path | |
from .signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal | |
from .sources.source import Source | |
from .sources.source_registry import resolve_source | |
CONFIG_FILENAME = 'config.yml'


def _serializable_path(path: PathTuple) -> Union[str, list]:
  """Convert a path tuple into a YAML/JSON-friendly value.

  A single-component path collapses to its lone string; longer paths become plain
  lists so the serialized file avoids python-specific tuple tags.
  """
  return path[0] if len(path) == 1 else list(path)
class SignalConfig(BaseModel):
  """Configures a signal on a source path."""
  path: PathTuple
  signal: Signal

  class Config:
    extra = Extra.forbid

  # NOTE: These were written as pydantic validators (`cls` first argument) but the
  # decorators were missing, so they were never invoked during model construction.
  @validator('path', pre=True)
  def parse_path(cls, path: Path) -> PathTuple:
    """Parse a path, normalizing it to a tuple."""
    return normalize_path(path)

  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance."""
    return resolve_signal(signal)

  def dict(
    self,
    *,
    include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
  ) -> dict[str, Any]:
    """Override the default dict method to simplify the path tuples.

    This is required to remove the python-specific tuple dump in the yaml file.
    """
    res = super().dict(
      include=include,
      exclude=exclude,
      by_alias=by_alias,
      skip_defaults=skip_defaults,
      exclude_unset=exclude_unset,
      exclude_defaults=exclude_defaults,
      exclude_none=exclude_none)
    res['path'] = _serializable_path(res['path'])
    return res
class EmbeddingConfig(BaseModel):
  """Configures an embedding on a source path."""
  path: PathTuple
  embedding: str

  class Config:
    extra = Extra.forbid

  def dict(
    self,
    *,
    include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
  ) -> dict[str, Any]:
    """Override the default dict method to simplify the path tuples.

    This is required to remove the python-specific tuple dump in the yaml file.
    """
    res = super().dict(
      include=include,
      exclude=exclude,
      by_alias=by_alias,
      skip_defaults=skip_defaults,
      exclude_unset=exclude_unset,
      exclude_defaults=exclude_defaults,
      exclude_none=exclude_none)
    res['path'] = _serializable_path(res['path'])
    return res

  # NOTE: These were written as pydantic validators (`cls` first argument) but the
  # decorators were missing, so they were never invoked during model construction.
  @validator('path', pre=True)
  def parse_path(cls, path: Path) -> PathTuple:
    """Parse a path, normalizing it to a tuple."""
    return normalize_path(path)

  @validator('embedding')
  def validate_embedding(cls, embedding: str) -> str:
    """Validate the embedding is registered; raises if the name is unknown."""
    get_signal_by_type(embedding, TextEmbeddingSignal)
    return embedding
class DatasetUISettings(BaseModel):
  """The UI persistent settings for a dataset."""
  media_paths: list[PathTuple] = []
  markdown_paths: list[PathTuple] = []

  class Config:
    extra = Extra.forbid

  # NOTE: This was written as a pydantic validator (`cls` first argument) but the
  # decorator was missing, so it was never invoked during model construction.
  @validator('media_paths', pre=True)
  def parse_media_paths(cls, media_paths: list) -> list:
    """Parse a path, ensuring it is a tuple."""
    return [normalize_path(path) for path in media_paths]

  def dict(
    self,
    *,
    include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
    by_alias: bool = False,
    skip_defaults: Optional[bool] = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
  ) -> dict[str, Any]:
    """Override the default dict method to simplify the path tuples.

    This is required to remove the python-specific tuple dump in the yaml file.
    """
    # TODO(nsthorat): Migrate this to @field_serializer when we upgrade to pydantic v2.
    res = super().dict(
      include=include,
      exclude=exclude,
      by_alias=by_alias,
      skip_defaults=skip_defaults,
      exclude_unset=exclude_unset,
      exclude_defaults=exclude_defaults,
      exclude_none=exclude_none)
    # Keys may be absent when exclude/include filters removed them.
    if 'media_paths' in res:
      res['media_paths'] = [_serializable_path(path) for path in res['media_paths']]
    if 'markdown_paths' in res:
      res['markdown_paths'] = [_serializable_path(path) for path in res['markdown_paths']]
    return res
class DatasetSettings(BaseModel):
  """The persistent settings for a dataset."""
  # UI configuration (media / markdown paths); absent when never customized.
  ui: Optional[DatasetUISettings] = None
  # Name of the embedding preferred for this dataset, if any.
  preferred_embedding: Optional[str] = None

  class Config:
    extra = Extra.forbid
class DatasetConfig(BaseModel):
  """Configures a dataset with a source and transformations."""
  # The namespace and name of the dataset.
  namespace: str
  name: str
  # Tags to organize datasets.
  tags: list[str] = []

  # The source configuration.
  source: Source

  # Model configuration: embeddings and signals on paths.
  embeddings: list[EmbeddingConfig] = []
  # When defined, uses this list of signals instead of running all signals.
  signals: list[SignalConfig] = []

  # Dataset settings, default embeddings and UI settings like media paths.
  settings: Optional[DatasetSettings] = None

  class Config:
    extra = Extra.forbid

  # NOTE: This was written as a pydantic validator (`cls` first argument) but the
  # decorator was missing, so it was never invoked during model construction.
  @validator('source', pre=True)
  def parse_source(cls, source: dict) -> Source:
    """Parse a source to its specific subclass instance."""
    return resolve_source(source)
class Config(BaseModel):
  """Configures a set of datasets for a lilac instance."""
  datasets: list[DatasetConfig]

  # When defined, uses this list of signals to run over every dataset, over all media paths, unless
  # signals is overridden by a specific dataset.
  signals: list[Signal] = []

  # A list of embeddings to compute the model caches for, for all concepts.
  concept_model_cache_embeddings: list[str] = []

  class Config:
    extra = Extra.forbid

  # NOTE: This was written as a pydantic validator (`cls` first argument) but the
  # decorator was missing, so it was never invoked during model construction.
  @validator('signals', pre=True)
  def parse_signal(cls, signals: list[dict]) -> list[Signal]:
    """Parse a list of signals to their specific subclass instances."""
    return [resolve_signal(signal) for signal in signals]
def read_config(config_path: str) -> Config:
  """Reads a config file.

  The config file can either be a `Config` or a `DatasetConfig`.

  The result is always a `Config` object. If the input is a `DatasetConfig`, the config will just
  contain a single dataset.

  Args:
    config_path: Path to a `.yml`/`.yaml` or `.json` config file.

  Returns:
    The parsed `Config`.

  Raises:
    ValueError: If the file extension is unsupported, or the contents are neither a valid
      `Config` nor a valid `DatasetConfig`.
  """
  config_ext = pathlib.Path(config_path).suffix
  if config_ext in ('.yml', '.yaml'):
    with open(config_path, 'r') as f:
      config_dict = yaml.safe_load(f)
  elif config_ext == '.json':
    with open(config_path, 'r') as f:
      config_dict = json.load(f)
  else:
    raise ValueError(f'Unsupported config file extension: {config_ext}')

  # Try the full `Config` shape first; fall back to a single `DatasetConfig`.
  try:
    return Config(**config_dict)
  except ValidationError:
    pass

  try:
    return Config(datasets=[DatasetConfig(**config_dict)])
  except ValidationError as error:
    # BUG FIX: the previous code raised `ValidationError('<message>', model=DatasetConfig)`, but
    # pydantic v1's ValidationError expects a sequence of ErrorWrappers as its first argument;
    # passing a string produced an exception whose str()/errors() crashed with RuntimeError.
    # ValidationError subclasses ValueError, so callers catching ValueError are unaffected.
    raise ValueError('Config is not a valid `Config` or `DatasetConfig`') from error