Spaces:
Runtime error
Runtime error
"""Router for the signal registry.""" | |
import math | |
from typing import Annotated, Any, Optional | |
from fastapi import APIRouter, Depends | |
from pydantic import BaseModel, validator | |
from .auth import UserInfo, get_session_user | |
from .router_utils import RouteErrorHandler, server_compute_concept | |
from .schema import Field, SignalInputType | |
from .signal import SIGNAL_REGISTRY, Signal, TextEmbeddingSignal, resolve_signal | |
from .signals.concept_scorer import ConceptSignal | |
# FastAPI router for this module's endpoints, built with the project's RouteErrorHandler route
# class.
router = APIRouter(route_class=RouteErrorHandler)

# Preferred ordering of embeddings in get_embeddings(): listed names come first, in this order,
# and any embedding not listed here sorts after all of them.
EMBEDDING_SORT_PRIORITIES = [
  'gte-small',
  'gte-base',
  'openai',
  'sbert',
]
class SignalInfo(BaseModel):
  """Summary of one registered signal class, as returned by the signal-listing functions."""
  # Registered name of the signal.
  name: str
  # The kind of input the signal accepts.
  input_type: SignalInputType
  # JSON schema produced by the signal class's `schema()` method.
  json_schema: dict[str, Any]
def get_signals() -> list[SignalInfo]:
  """Return a `SignalInfo` for every registered signal that is not a text embedding.

  Text-embedding signals are excluded here; `get_embeddings` lists those instead.
  """
  # NOTE(review): no route decorator is visible on this function and `router` is otherwise
  # unused in the visible code. Confirm it is registered (e.g. via `@router.get(...)`).
  infos: list[SignalInfo] = []
  for signal_cls in SIGNAL_REGISTRY.values():
    if issubclass(signal_cls, TextEmbeddingSignal):
      continue
    infos.append(
      SignalInfo(
        name=signal_cls.name,
        input_type=signal_cls.input_type,
        json_schema=signal_cls.schema(),
      ))
  return infos
def get_embeddings() -> list[SignalInfo]:
  """Return a `SignalInfo` for every registered text-embedding signal, in preference order.

  Embeddings named in `EMBEDDING_SORT_PRIORITIES` come first, in that list's order. All others
  follow in registry iteration order, because the sort is stable.
  """
  # NOTE(review): no route decorator is visible on this function and `router` is otherwise
  # unused in the visible code. Confirm it is registered (e.g. via `@router.get(...)`).

  def rank(info: SignalInfo) -> float:
    # Unlisted embeddings rank at +inf so they land after every prioritized one.
    if info.name not in EMBEDDING_SORT_PRIORITIES:
      return math.inf
    return EMBEDDING_SORT_PRIORITIES.index(info.name)

  infos: list[SignalInfo] = []
  for signal_cls in SIGNAL_REGISTRY.values():
    if not issubclass(signal_cls, TextEmbeddingSignal):
      continue
    infos.append(
      SignalInfo(
        name=signal_cls.name,
        input_type=signal_cls.input_type,
        json_schema=signal_cls.schema(),
      ))
  infos.sort(key=rank)
  return infos
class SignalComputeOptions(BaseModel):
  """The request for the standalone compute signal endpoint."""
  # The signal to run. `parse_signal` resolves it to its concrete subclass.
  signal: Signal
  # The inputs to compute.
  inputs: list[str]

  # Without this decorator the method is never called. The payload would then stay a base
  # `Signal`, and subclass checks such as `isinstance(signal, ConceptSignal)` in `compute`
  # could never match.
  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance.

    Runs before standard field validation (`pre=True`), so it receives the raw request value
    (expected to be a dict) and hands it to `resolve_signal`.

    Args:
      signal: The raw `signal` value from the request body.

    Returns:
      The `Signal` instance built by `resolve_signal`.
    """
    return resolve_signal(signal)
class SignalComputeResponse(BaseModel):
  """Response body for the standalone compute-signal endpoint."""
  # Values computed by the signal. Individual entries may be None.
  items: list[Optional[Any]]
def compute(
    options: SignalComputeOptions,
    user: Annotated[Optional[UserInfo], Depends(get_session_user)]) -> SignalComputeResponse:
  """Run `options.signal` over `options.inputs` and wrap the results in a response.

  Concept signals are delegated to `server_compute_concept` together with the session user.
  Any other signal has `setup()` called on it and is then computed directly on the inputs.
  """
  # NOTE(review): no route decorator is visible on this function and `router` is otherwise
  # unused in the visible code. Confirm it is registered (e.g. via `@router.post(...)`).
  signal = options.signal
  if not isinstance(signal, ConceptSignal):
    signal.setup()
    return SignalComputeResponse(items=list(signal.compute(options.inputs)))
  return SignalComputeResponse(items=server_compute_concept(signal, options.inputs, user))
class SignalSchemaOptions(BaseModel):
  """The request for the signal schema endpoint."""
  # The signal whose schema is requested. `parse_signal` resolves it to its concrete subclass.
  signal: Signal

  # Without this decorator the method is never called. The payload would then stay a base
  # `Signal`, so `fields()` would not dispatch to the concrete signal's implementation.
  @validator('signal', pre=True)
  def parse_signal(cls, signal: dict) -> Signal:
    """Parse a signal to its specific subclass instance.

    Runs before standard field validation (`pre=True`), so it receives the raw request value
    (expected to be a dict) and hands it to `resolve_signal`.

    Args:
      signal: The raw `signal` value from the request body.

    Returns:
      The `Signal` instance built by `resolve_signal`.
    """
    return resolve_signal(signal)
class SignalSchemaResponse(BaseModel):
  """Response body for the signal-schema endpoint."""
  # The `Field` schema reported by the requested signal.
  fields: Field
def schema(options: SignalSchemaOptions) -> SignalSchemaResponse:
  """Return the `Field` schema reported by `options.signal.fields()`."""
  # NOTE(review): no route decorator is visible on this function and `router` is otherwise
  # unused in the visible code. Confirm it is registered (e.g. via `@router.post(...)`).
  return SignalSchemaResponse(fields=options.signal.fields())