Spaces:
Running
Running
import yaml | |
import random | |
import asyncio | |
from typing import Literal | |
from functools import cached_property | |
import httpx | |
from pydantic import BaseModel | |
from text_generation import AsyncClient | |
from spitfight.log import get_logger | |
# Module-level logger obtained from the project's `spitfight.log.get_logger`.
logger = get_logger(__name__)
class Worker(BaseModel):
    """A worker that serves a model.

    `audit` and `check_status` query the worker's `/info` endpoint and update
    `status` in place.
    """

    # Worker's container name, since we're using Overlay networks.
    hostname: str
    # For TGI, this would always be 80.
    port: int
    # User-friendly model name, e.g. "metaai/llama2-13b-chat".
    model_name: str
    # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
    model_id: str
    # Whether the model worker container is good.
    status: Literal["up", "down"]

    class Config:
        # Make pydantic leave `cached_property` attributes alone instead of
        # treating them as model fields; required by `url` below.
        keep_untouched = (cached_property,)

    @cached_property
    def url(self) -> str:
        """Base URL of the worker's HTTP server, computed once and cached."""
        return f"http://{self.hostname}:{self.port}"

    def get_client(self) -> AsyncClient:
        """Return a new text-generation `AsyncClient` pointed at this worker."""
        return AsyncClient(base_url=self.url)

    def audit(self) -> None:
        """Make sure the worker is running and information is as expected.

        Assumed to be called on app startup when workers are initialized.
        This method will just raise `ValueError`s if audit fails in order to
        prevent the controller from starting if anything is wrong.

        Raises:
            ValueError: If the worker cannot be reached, `/info` does not
                return 200, the response is malformed, or the reported model
                ID differs from `self.model_id`.
        """
        try:
            response = httpx.get(self.url + "/info")
        except httpx.TransportError as e:
            # TransportError covers ConnectError and TimeoutException as well
            # as read/protocol errors from a half-started container.
            raise ValueError(f"Could not connect to {self!r}: {e!r}") from e
        if response.status_code != 200:
            raise ValueError(f"Could not get /info from {self!r}.")
        try:
            reported_model_id = response.json()["model_id"]
        except (ValueError, KeyError, TypeError) as e:
            # Non-JSON body (JSONDecodeError is a ValueError), missing key, or
            # non-object JSON: keep the documented ValueError-only contract.
            raise ValueError(f"Malformed /info response from {self!r}: {e!r}") from e
        if reported_model_id != self.model_id:
            raise ValueError(f"Model name mismatch: {reported_model_id} != {self.model_id}")
        self.status = "up"
        logger.info("%s is up.", repr(self))

    async def check_status(self) -> None:
        """Check worker status and update `self.status` accordingly.

        Problems on the worker side (unreachable, non-200, malformed body,
        model ID mismatch) mark the worker "down" and log a warning instead of
        raising, so one bad worker cannot abort a batch of concurrent checks.
        """
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.url + "/info")
            except httpx.TransportError as e:
                self.status = "down"
                logger.warning("%s is down: %s", repr(self), repr(e))
                return

        if response.status_code != 200:
            self.status = "down"
            # Log the raw body: error responses are not guaranteed to be JSON,
            # and calling `response.json()` here could raise.
            logger.warning(
                "GET /info from %s returned %s: %s",
                repr(self),
                response.status_code,
                response.text,
            )
            return

        try:
            reported_model_id = response.json()["model_id"]
        except (ValueError, KeyError, TypeError) as e:
            self.status = "down"
            logger.warning("Malformed /info response from %s: %s", repr(self), repr(e))
            return

        if reported_model_id != self.model_id:
            self.status = "down"
            logger.warning(
                "Model name mismatch for %s: %s != %s",
                repr(self),
                reported_model_id,
                self.model_id,
            )
            return

        self.status = "up"
        logger.info("%s is up.", repr(self))
class WorkerService:
    """A service that manages model serving workers.

    Worker objects are only created once and shared across the
    entire application. Especially, changing the status of a worker
    will immediately take effect on the result of `choose_two`.

    Attributes:
        workers (list[Worker]): The list of workers.
    """

    def __init__(self, compose_files: list[str]) -> None:
        """Initialize the worker service.

        Reads each Docker Compose file, creates one `Worker` per service and
        calls `Worker.audit` on it, so construction fails if any worker fails
        its audit.

        Args:
            compose_files: Paths to Docker Compose YAML files. Each service's
                key is used as the model name, its `container_name` as the
                hostname, and the `--model-id` argument of its `command` as
                the model ID.

        Raises:
            ValueError: If a service command has no model ID or model names
                are not unique across all files (plus anything raised by
                `Worker.audit`).
        """
        self.workers: list[Worker] = []
        seen_model_names = set()
        for compose_file in compose_files:
            # Use a context manager so the file handle is closed promptly.
            with open(compose_file) as f:
                spec = yaml.safe_load(f)
            for model_name, service_spec in spec["services"].items():
                # Keys are unique within one file, but the same service name
                # may appear in several files; fail before auditing it.
                if model_name in seen_model_names:
                    raise ValueError("Model names must be unique.")
                seen_model_names.add(model_name)
                worker = Worker(
                    hostname=service_spec["container_name"],
                    port=80,
                    model_name=model_name,
                    model_id=self._find_model_id(service_spec["command"]),
                    status="down",
                )
                worker.audit()
                self.workers.append(worker)

    @staticmethod
    def _find_model_id(command) -> str:
        """Return the value of the `--model-id` flag in a Compose `command`.

        `command` may be a list of arguments or a single shell-style string;
        both forms are accepted by Docker Compose. `--model-id X` and
        `--model-id=X` are both recognized; the first occurrence wins.

        Raises:
            ValueError: If the flag is absent or is the last argument.
        """
        import shlex

        args = shlex.split(command) if isinstance(command, str) else [str(a) for a in command]
        for i, arg in enumerate(args):
            if arg == "--model-id":
                # Guard against a trailing flag with no value (was IndexError).
                if i + 1 < len(args):
                    return args[i + 1]
                break
            if arg.startswith("--model-id="):
                return arg[len("--model-id="):]
        raise ValueError(f"Could not find model ID in {command!r}")

    def get_worker(self, model_name: str) -> Worker:
        """Get a worker by model name.

        Raises:
            RuntimeError: If the worker exists but is currently down.
            ValueError: If no worker has this model name.
        """
        for worker in self.workers:
            if worker.model_name == model_name:
                if worker.status == "down":
                    # This is an unfortunate case where, when the two models were chosen,
                    # the worker was up, but after that went down before the request
                    # completed. We'll just raise a 500 internal error and have the user
                    # try again. This won't be common.
                    raise RuntimeError(f"The worker with model name {model_name} is down.")
                return worker
        raise ValueError(f"Worker with model name {model_name} does not exist.")

    def choose_two(self) -> tuple[Worker, Worker]:
        """Choose two different workers.

        Only workers whose status is "up" are candidates; the pair is drawn
        uniformly at random without replacement.

        Good place to use the Strategy Pattern when we want to
        implement different strategies for choosing workers.

        Raises:
            ValueError: If fewer than two workers are up.
        """
        live_workers = [worker for worker in self.workers if worker.status == "up"]
        if len(live_workers) < 2:
            raise ValueError("Not enough live workers to choose from.")
        worker_a, worker_b = random.sample(live_workers, 2)
        return worker_a, worker_b

    async def check_workers(self) -> None:
        """Check the status of all workers concurrently."""
        await asyncio.gather(*[worker.check_status() for worker in self.workers])