nsthorat's picture
Push
2ad28d6
"""Configurations for a dataset run."""
import json
import pathlib
from typing import TYPE_CHECKING, Any, Optional, Union
import yaml
if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
from pydantic import BaseModel, Extra, ValidationError, validator
from .schema import Path, PathTuple, normalize_path
from .signal import Signal, TextEmbeddingSignal, get_signal_by_type, resolve_signal
from .sources.source import Source
from .sources.source_registry import resolve_source
CONFIG_FILENAME = 'config.yml'
def _serializable_path(path: PathTuple) -> Union[str, list]:
if len(path) == 1:
return path[0]
return list(path)
class SignalConfig(BaseModel):
"""Configures a signal on a source path."""
path: PathTuple
signal: Signal
class Config:
extra = Extra.forbid
@validator('path', pre=True)
def parse_path(cls, path: Path) -> PathTuple:
"""Parse a path."""
return normalize_path(path)
@validator('signal', pre=True)
def parse_signal(cls, signal: dict) -> Signal:
"""Parse a signal to its specific subclass instance."""
return resolve_signal(signal)
def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> dict[str, Any]:
"""Override the default dict method to simplify the path tuples.
This is required to remove the python-specific tuple dump in the yaml file.
"""
res = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
res['path'] = _serializable_path(res['path'])
return res
class EmbeddingConfig(BaseModel):
"""Configures an embedding on a source path."""
path: PathTuple
embedding: str
class Config:
extra = Extra.forbid
def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> dict[str, Any]:
"""Override the default dict method to simplify the path tuples.
This is required to remove the python-specific tuple dump in the yaml file.
"""
res = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
res['path'] = _serializable_path(res['path'])
return res
@validator('path', pre=True)
def parse_path(cls, path: Path) -> PathTuple:
"""Parse a path."""
return normalize_path(path)
@validator('embedding', pre=True)
def validate_embedding(cls, embedding: str) -> str:
"""Validate the embedding is registered."""
get_signal_by_type(embedding, TextEmbeddingSignal)
return embedding
class DatasetUISettings(BaseModel):
"""The UI persistent settings for a dataset."""
media_paths: list[PathTuple] = []
markdown_paths: list[PathTuple] = []
class Config:
extra = Extra.forbid
@validator('media_paths', pre=True)
def parse_media_paths(cls, media_paths: list) -> list:
"""Parse a path, ensuring it is a tuple."""
return [normalize_path(path) for path in media_paths]
def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> dict[str, Any]:
"""Override the default dict method to simplify the path tuples.
This is required to remove the python-specific tuple dump in the yaml file.
"""
# TODO(nsthorat): Migrate this to @field_serializer when we upgrade to pydantic v2.
res = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
if 'media_paths' in res:
res['media_paths'] = [_serializable_path(path) for path in res['media_paths']]
if 'markdown_paths' in res:
res['markdown_paths'] = [_serializable_path(path) for path in res['markdown_paths']]
return res
class DatasetSettings(BaseModel):
"""The persistent settings for a dataset."""
ui: Optional[DatasetUISettings] = None
preferred_embedding: Optional[str] = None
class Config:
extra = Extra.forbid
class DatasetConfig(BaseModel):
"""Configures a dataset with a source and transformations."""
# The namespace and name of the dataset.
namespace: str
name: str
# Tags to organize datasets.
tags: list[str] = []
# The source configuration.
source: Source
# Model configuration: embeddings and signals on paths.
embeddings: list[EmbeddingConfig] = []
# When defined, uses this list of signals instead of running all signals.
signals: list[SignalConfig] = []
# Dataset settings, default embeddings and UI settings like media paths.
settings: Optional[DatasetSettings] = None
class Config:
extra = Extra.forbid
@validator('source', pre=True)
def parse_source(cls, source: dict) -> Source:
"""Parse a source to its specific subclass instance."""
return resolve_source(source)
class Config(BaseModel):
"""Configures a set of datasets for a lilac instance."""
datasets: list[DatasetConfig]
# When defined, uses this list of signals to run over every dataset, over all media paths, unless
# signals is overridden by a specific dataset.
signals: list[Signal] = []
# A list of embeddings to compute the model caches for, for all concepts.
concept_model_cache_embeddings: list[str] = []
class Config:
extra = Extra.forbid
@validator('signals', pre=True)
def parse_signal(cls, signals: list[dict]) -> list[Signal]:
"""Parse alist of signals to their specific subclass instances."""
return [resolve_signal(signal) for signal in signals]
def read_config(config_path: str) -> Config:
"""Reads a config file.
The config file can either be a `Config` or a `DatasetConfig`.
The result is always a `Config` object. If the input is a `DatasetConfig`, the config will just
contain a single dataset.
"""
config_ext = pathlib.Path(config_path).suffix
if config_ext in ['.yml', '.yaml']:
with open(config_path, 'r') as f:
config_dict = yaml.safe_load(f)
elif config_ext in ['.json']:
with open(config_path, 'r') as f:
config_dict = json.load(f)
else:
raise ValueError(f'Unsupported config file extension: {config_ext}')
config: Optional[Config] = None
is_config = True
try:
config = Config(**config_dict)
except ValidationError:
is_config = False
if not is_config:
try:
dataset_config = DatasetConfig(**config_dict)
config = Config(datasets=[dataset_config])
except ValidationError as error:
raise ValidationError(
'Config is not a valid `Config` or `DatasetConfig`', model=DatasetConfig) from error
assert config is not None
return config