"""Serves the Lilac server.""" import logging import os from typing import Any from fastapi import APIRouter, FastAPI from fastapi.responses import ORJSONResponse from fastapi.routing import APIRoute from fastapi.staticfiles import StaticFiles from . import router_concept, router_data_loader, router_dataset, router_signal, router_tasks from .config import CONFIG, data_path from .router_utils import RouteErrorHandler from .tasks import task_manager DIST_PATH = os.path.abspath(os.path.join('web', 'blueprint', 'build')) tags_metadata: list[dict[str, Any]] = [{ 'name': 'datasets', 'description': 'API for querying a dataset.', }, { 'name': 'concepts', 'description': 'API for managing concepts.', }, { 'name': 'data_loaders', 'description': 'API for loading data.', }, { 'name': 'signals', 'description': 'API for managing signals.', }] def custom_generate_unique_id(route: APIRoute) -> str: """Generate the name for the API endpoint.""" return route.name app = FastAPI( default_response_class=ORJSONResponse, generate_unique_id_function=custom_generate_unique_id, openapi_tags=tags_metadata) v1_router = APIRouter(route_class=RouteErrorHandler) v1_router.include_router(router_dataset.router, prefix='/datasets', tags=['datasets']) v1_router.include_router(router_concept.router, prefix='/concepts', tags=['concepts']) v1_router.include_router(router_data_loader.router, prefix='/data_loaders', tags=['data_loaders']) v1_router.include_router(router_signal.router, prefix='/signals', tags=['signals']) v1_router.include_router(router_tasks.router, prefix='/tasks', tags=['tasks']) app.include_router(v1_router, prefix='/api/v1') # Serve static files in production mode. 
# html=True makes StaticFiles serve index.html for directory requests;
# check_dir=False skips the directory-existence check at construction time, so
# the app can still be created when the web build directory is absent.
app.mount('/', StaticFiles(directory=DIST_PATH, html=True, check_dir=False))

# Startup diagnostics.
# NOTE(review): CONFIG is printed verbatim; confirm it never holds secrets.
print('data_path=', data_path(), CONFIG)
try:
    print('list:', os.listdir(data_path()))
except OSError as e:
    # A missing or unreadable data directory must not abort server startup just
    # because of a diagnostic print.
    print('list:', f'<unavailable: {e}>')


@app.on_event('shutdown')
async def shutdown_event() -> None:
    """Stop the task manager when FastAPI shuts down."""
    await task_manager().stop()


class GetTasksFilter(logging.Filter):
    """Logging filter that drops records whose message mentions the tasks endpoint."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False for records containing '/api/v1/tasks/', True otherwise.

        Args:
            record: The log record being considered.

        Returns:
            Whether the record should be kept in the log output.
        """
        return '/api/v1/tasks/' not in record.getMessage()


# Keep requests to the tasks endpoint out of uvicorn's access log.
# NOTE(review): presumably because the UI polls that endpoint often; confirm.
logging.getLogger('uvicorn.access').addFilter(GetTasksFilter())