# nikhil_no_persistent / lilac / router_signal.py
# nsthorat-lilac's picture
# Duplicate from lilacai/nikhil_staging
# bfc0ec6
"""Router for the signal registry."""
import math
from typing import Annotated, Any, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel, validator
from .auth import UserInfo, get_session_user
from .router_utils import RouteErrorHandler, server_compute_concept
from .schema import Field, SignalInputType
from .signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal
from .signals.concept_scorer import ConceptSignal
# Router for the signal endpoints; every route registered on it uses RouteErrorHandler as its
# route class.
router = APIRouter(route_class=RouteErrorHandler)
# Embedding names in listing order. get_embeddings() sorts by position in this list and places
# embeddings that are not named here after the ones that are.
EMBEDDING_SORT_PRIORITIES = ['gte-small', 'gte-base', 'openai', 'sbert']
class SignalInfo(BaseModel):
    """Information about a signal."""
    # The `name` attribute of the signal class.
    name: str
    # The `input_type` attribute of the signal class.
    input_type: SignalInputType
    # The value returned by calling `schema()` on the signal class.
    json_schema: dict[str, Any]
@router.get('/', response_model_exclude_none=True)
def get_signals() -> list[SignalInfo]:
    """Return a `SignalInfo` for each registered signal that is not a text embedding.

    Classes in `SIGNAL_REGISTRY` that subclass `TextEmbeddingSignal` are left out of this
    listing.

    Returns:
      One `SignalInfo` per remaining signal class, in registry iteration order.
    """
    infos: list[SignalInfo] = []
    for signal_cls in SIGNAL_REGISTRY.values():
        # Text embeddings are excluded from this listing.
        if issubclass(signal_cls, TextEmbeddingSignal):
            continue
        infos.append(
            SignalInfo(
                name=signal_cls.name,
                input_type=signal_cls.input_type,
                json_schema=signal_cls.schema(),
            ))
    return infos
@router.get('/embeddings', response_model_exclude_none=True)
def get_embeddings() -> list[SignalInfo]:
    """Return a `SignalInfo` for each registered text embedding, ordered by priority.

    Embeddings named in `EMBEDDING_SORT_PRIORITIES` come first, in that list's order. All
    other embeddings follow, keeping their registry iteration order relative to one another
    (the sort is stable).

    Returns:
      One `SignalInfo` per `TextEmbeddingSignal` subclass found in `SIGNAL_REGISTRY`.
    """

    def _rank(info: SignalInfo) -> float:
        """Return the embedding's position in the priority list, or infinity if unlisted."""
        try:
            return EMBEDDING_SORT_PRIORITIES.index(info.name)
        except ValueError:
            # Not a prioritized embedding: sort it after every prioritized one.
            return math.inf

    infos: list[SignalInfo] = []
    for signal_cls in SIGNAL_REGISTRY.values():
        if not issubclass(signal_cls, TextEmbeddingSignal):
            continue
        infos.append(
            SignalInfo(
                name=signal_cls.name,
                input_type=signal_cls.input_type,
                json_schema=signal_cls.schema(),
            ))

    infos.sort(key=_rank)
    return infos
class SignalComputeOptions(BaseModel):
    """The request for the standalone compute signal endpoint."""
    # The signal to run. The `parse_signal` validator below passes the incoming value through
    # `resolve_signal` before pydantic validates the field.
    signal: Signal
    # The input strings the compute endpoint processes with the signal.
    inputs: list[str]

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Parse a signal to its specific subclass instance.

        Registered with `pre=True`, so it receives the value before the field's own validation.
        """
        return resolve_signal(signal)
class SignalComputeResponse(BaseModel):
    """The response for the standalone compute signal endpoint."""
    # The results produced by the compute endpoint; individual entries may be None.
    items: list[Optional[Any]]
@router.post('/compute', response_model_exclude_none=True)
def compute(
    options: SignalComputeOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)],
) -> SignalComputeResponse:
    """Compute a signal over the request's inputs and return the results.

    A `ConceptSignal` is delegated to `server_compute_concept` together with the session user.
    Any other signal has `setup()` called and is then run with `compute()` over the inputs,
    with its output collected into a list.

    Args:
      options: The resolved signal and the input strings to run it over.
      user: The value supplied by the `get_session_user` dependency; may be None.

    Returns:
      A `SignalComputeResponse` whose `items` holds the computed results.
    """
    requested = options.signal

    if not isinstance(requested, ConceptSignal):
        # Signals other than concepts are set up and computed directly in this handler.
        requested.setup()
        return SignalComputeResponse(items=list(requested.compute(options.inputs)))

    # Concept signals are computed through the server helper, which also receives the user.
    return SignalComputeResponse(items=server_compute_concept(requested, options.inputs, user))
class SignalSchemaOptions(BaseModel):
    """The request for the signal schema endpoint."""
    # The signal whose fields are requested. The `parse_signal` validator below passes the
    # incoming value through `resolve_signal` before pydantic validates the field.
    signal: Signal

    @validator('signal', pre=True)
    def parse_signal(cls, signal: dict) -> Signal:
        """Parse a signal to its specific subclass instance.

        Registered with `pre=True`, so it receives the value before the field's own validation.
        """
        return resolve_signal(signal)
class SignalSchemaResponse(BaseModel):
    """The response for the signal schema endpoint."""
    # The value returned by the requested signal's `fields()` method.
    fields: Field
@router.post('/schema', response_model_exclude_none=True)
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
    """Return the result of the requested signal's `fields()` call.

    Args:
      options: The request carrying the resolved signal.

    Returns:
      A `SignalSchemaResponse` wrapping the value returned by `signal.fields()`.
    """
    return SignalSchemaResponse(fields=options.signal.fields())