"""A script to load a dataset or set of datasets from a config for a Lilac instance.
Usage:
poetry run python -m lilac.load \
--output_dir=demo_data \
--config_path=lilac_hf_space.yml
"""
import gc
import os
import shutil

import click
import dask
import psutil
from distributed import Client

from .concepts.db_concept import DiskConceptDB, DiskConceptModelDB
from .config import EmbeddingConfig, SignalConfig, read_config
from .data.dataset_duckdb import DatasetDuckDB
from .data_loader import process_source
from .db_manager import get_dataset, list_datasets, remove_dataset_from_cache
from .schema import ROWID, PathTuple
from .tasks import TaskManager, TaskStepId
from .utils import DebugTimer, get_datasets_dir


@click.command()
@click.option(
  '--output_dir', required=True, type=str, help='The output directory to write files to.')
@click.option(
  '--config_path',
  required=True,
  type=str,
  help='The path to a json or yml file describing the configuration. '
  'The file contents should be an instance of `lilac.Config` or `lilac.DatasetConfig`.')
@click.option(
  '--overwrite',
  help='When True, runs all data from scratch, overwriting existing data. When False, only '
  'loads new datasets, embeddings, and signals.',
  type=bool,
  is_flag=True,
  default=False)
def load_command(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Run the source loader as a binary."""
  load(output_dir, config_path, overwrite)


def load(output_dir: str, config_path: str, overwrite: bool) -> None:
  """Run the source loader as a binary."""
  old_data_path = os.environ.get('LILAC_DATA_PATH')
  os.environ['LILAC_DATA_PATH'] = output_dir
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']
  # Use views to avoid loading duckdb tables into RAM since we aren't query heavy.
  os.environ['DUCKDB_USE_VIEWS'] = '1'

  config = read_config(config_path)

  # Explicitly create a dask client in sync mode.
  dask.config.set({'distributed.worker.daemon': False})
  # Give the dask workers roughly two thirds of system RAM.
  total_memory_gb = psutil.virtual_memory().total / (1024**3) * 2 / 3
  task_manager = TaskManager(Client(memory_limit=f'{total_memory_gb} GB'))
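  # Illustrative arithmetic (assumed machine size, not measured): with 24 GiB
  # of RAM, total / 1024**3 * 2 / 3 == 16.0, so the client is created with
  # memory_limit='16.0 GB'.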

  if overwrite:
    shutil.rmtree(get_datasets_dir(output_dir), ignore_errors=True)

  existing_datasets = [f'{d.namespace}/{d.dataset_name}' for d in list_datasets(output_dir)]

  print()
  print('*** Load datasets ***')
  if overwrite:
    datasets_to_load = config.datasets
  else:
    datasets_to_load = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' not in existing_datasets
    ]
    skipped_datasets = [
      d for d in config.datasets if f'{d.namespace}/{d.name}' in existing_datasets
    ]
    print('Skipping loaded datasets:', ', '.join([d.name for d in skipped_datasets]))

  with DebugTimer(f'Loading datasets: {", ".join([d.name for d in datasets_to_load])}'):
    for d in datasets_to_load:
      shutil.rmtree(os.path.join(output_dir, d.name), ignore_errors=True)
      task_id = task_manager.task_id(f'Load dataset {d.namespace}/{d.name}')
      task_manager.execute(task_id, process_source, output_dir, d, (task_id, 0))

    task_manager.wait()

  print()
  total_num_rows = 0
  for d in datasets_to_load:
    dataset = DatasetDuckDB(d.namespace, d.name)
    num_rows = dataset.select_rows([ROWID], limit=1).total_num_rows
    print(f'{d.namespace}/{d.name} loaded with {num_rows:,} rows.')

    # Free up RAM.
    del dataset

    total_num_rows += num_rows
  print(f'Done loading {len(datasets_to_load)} datasets with {total_num_rows:,} rows.')

  print('*** Dataset settings ***')
  for d in config.datasets:
    if d.settings:
      dataset = DatasetDuckDB(d.namespace, d.name)
      dataset.update_settings(d.settings)

  print()
  print('*** Compute embeddings ***')
  with DebugTimer('Loading embeddings'):
    for d in config.datasets:
      dataset = DatasetDuckDB(d.namespace, d.name)

      # If embeddings are explicitly set, use only those.
      embeddings = d.embeddings or []
      # If embeddings are not explicitly set, use the media paths and preferred embedding from
      # settings.
      if not embeddings:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            if d.settings.preferred_embedding:
              embeddings.append(
                EmbeddingConfig(path=path, embedding=d.settings.preferred_embedding))
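
      # For illustration (assumed values): with no explicit embeddings,
      # media_paths ['question', 'answer'] and preferred_embedding 'gte-small'
      # expand to EmbeddingConfig(path='question', embedding='gte-small') and
      # EmbeddingConfig(path='answer', embedding='gte-small').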
      for e in embeddings:
        if e not in dataset.config().embeddings:
          print('scheduling', e)
          task_id = task_manager.task_id(f'Compute embedding {e.embedding} on {d.name}:{e.path}')
          task_manager.execute(task_id, _compute_embedding, d.namespace, d.name, e, output_dir,
                               overwrite, (task_id, 0))
        else:
          print(f'Embedding {e.embedding} already exists for {d.name}:{e.path}. Skipping.')

      del dataset
      # Wait for all embeddings for each dataset to reduce the memory pressure.
      task_manager.wait()

  print()
  print('*** Compute signals ***')
  with DebugTimer('Computing signals'):
    for d in config.datasets:
      dataset = DatasetDuckDB(d.namespace, d.name)

      # If signals are explicitly set, use only those.
      signals = d.signals or []
      # If signals are not explicitly set, use the media paths and config.signals.
      if not signals:
        if d.settings and d.settings.ui:
          for path in d.settings.ui.media_paths or []:
            for signal in config.signals or []:
              signals.append(SignalConfig(path=path, signal=signal))

      # Separate signals by path to avoid computing the same signal in parallel, which can cause
      # issues with taking too much RAM.
      path_signals: dict[PathTuple, list[SignalConfig]] = {}
      for s in signals:
        path_signals.setdefault(s.path, []).append(s)
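
      # For illustration (assumed values): signals s1 and s2 on path ('text',)
      # and s3 on ('title',) group as {('text',): [s1, s2], ('title',): [s3]},
      # so each path's signals are scheduled and awaited together below.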
      for path, signals in path_signals.items():
        for s in signals:
          if s not in dataset.config().signals:
            task_id = task_manager.task_id(f'Compute signal {s.signal} on {d.name}:{s.path}')
            task_manager.execute(task_id, _compute_signal, d.namespace, d.name, s, output_dir,
                                 overwrite, (task_id, 0))
          else:
            print(f'Signal {s.signal} already exists for {d.name}:{s.path}. Skipping.')
        # Wait for all signals for each path to reduce the memory pressure.
        task_manager.wait()

      del dataset

  print()
  print('*** Compute model caches ***')
  with DebugTimer('Computing model caches'):
    concept_db = DiskConceptDB(output_dir)
    concept_model_db = DiskConceptModelDB(concept_db)
    if config.concept_model_cache_embeddings:
      for concept_info in concept_db.list():
        for embedding in config.concept_model_cache_embeddings:
          concept_model_db.sync(
            concept_info.namespace, concept_info.name, embedding_name=embedding, create=True)

  print()
  print('Done!')

  if old_data_path:
    os.environ['LILAC_DATA_PATH'] = old_data_path


def _compute_signal(namespace: str, name: str, signal_config: SignalConfig, output_dir: str,
                    overwrite: bool, task_step_id: TaskStepId) -> None:
  # This may run in a dask worker process, so set the environment again.
  os.environ['LILAC_DATA_PATH'] = output_dir
  os.environ['DUCKDB_USE_VIEWS'] = '1'
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  compute_signal = False
  if overwrite:
    compute_signal = True

  dataset = get_dataset(namespace, name)
  if not compute_signal:
    # Recompute when the signal's output field is missing, or was produced by a
    # different signal configuration.
    field = dataset.manifest().data_schema.get_field(signal_config.path)
    signal_field = (field.fields or {}).get(signal_config.signal.key())
    if not signal_field or signal_field.signal != signal_config.signal.dict():
      compute_signal = True
  if compute_signal:
    dataset.compute_signal(signal_config.signal, signal_config.path, task_step_id)

  # Free up RAM.
  remove_dataset_from_cache(namespace, name)
  del dataset
  gc.collect()


def _compute_embedding(namespace: str, name: str, embedding_config: EmbeddingConfig,
                       output_dir: str, overwrite: bool, task_step_id: TaskStepId) -> None:
  # This may run in a dask worker process, so set the environment again.
  os.environ['LILAC_DATA_PATH'] = output_dir
  os.environ['DUCKDB_USE_VIEWS'] = '1'
  # Turn off debug logging.
  if 'DEBUG' in os.environ:
    del os.environ['DEBUG']

  compute_embedding = False
  if overwrite:
    compute_embedding = True

  dataset = get_dataset(namespace, name)
  if not compute_embedding:
    # Recompute only when the embedding's output field is missing.
    field = dataset.manifest().data_schema.get_field(embedding_config.path)
    embedding_field = (field.fields or {}).get(embedding_config.embedding)
    if not embedding_field:
      compute_embedding = True

  if compute_embedding:
    dataset.compute_embedding(embedding_config.embedding, embedding_config.path, task_step_id)

  # Free up RAM.
  remove_dataset_from_cache(namespace, name)
  del dataset
  gc.collect()


if __name__ == '__main__':
  load_command()