File size: 5,887 Bytes
8ff63e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import asyncio
import random
import shlex
from functools import cached_property
from typing import Literal

import httpx
import yaml
from pydantic import BaseModel
from text_generation import AsyncClient

from spitfight.log import get_logger

# Module-level logger obtained from `spitfight.log.get_logger`, keyed by this
# module's `__name__`; used by `Worker` and `WorkerService` below.
logger = get_logger(__name__)


class Worker(BaseModel):
    """A worker that serves a model.

    Instances are mutable: `audit` and `check_status` update `status` in place.
    """
    # Worker's container name, since we're using Overlay networks.
    hostname: str
    # For TGI, this would always be 80.
    port: int
    # User-friendly model name, e.g. "metaai/llama2-13b-chat".
    model_name: str
    # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
    model_id: str
    # Whether the model worker container is good.
    status: Literal["up", "down"]

    class Config:
        # Keep pydantic from treating `cached_property` attributes as fields.
        keep_untouched = (cached_property,)

    @cached_property
    def url(self) -> str:
        """Base URL of the worker's HTTP server: "http://<hostname>:<port>"."""
        return f"http://{self.hostname}:{self.port}"

    def get_client(self) -> AsyncClient:
        """Return a new `text_generation.AsyncClient` pointed at this worker."""
        return AsyncClient(base_url=self.url)

    @staticmethod
    def _parse_model_id(response: httpx.Response) -> str:
        """Return the `model_id` field of a `GET /info` JSON response body.

        Raises:
            ValueError: If the body is not valid JSON (`json.JSONDecodeError`
                is a `ValueError` subclass), is not a JSON object, or has no
                `model_id` key.
        """
        info = response.json()
        if not isinstance(info, dict) or "model_id" not in info:
            raise ValueError(f"expected a JSON object with 'model_id', got {info!r}")
        return info["model_id"]

    def audit(self) -> None:
        """Make sure the worker is running and information is as expected.

        Assumed to be called on app startup when workers are initialized.
        This method will just raise `ValueError`s if audit fails in order to
        prevent the controller from starting if anything is wrong. On success,
        `status` is set to "up".

        Raises:
            ValueError: If the worker is unreachable, `GET /info` does not
                return 200, the payload is malformed, or the reported model ID
                differs from `self.model_id`.
        """
        try:
            response = httpx.get(self.url + "/info")
        except httpx.RequestError as e:
            # `RequestError` covers connect/timeout as well as read and protocol errors.
            raise ValueError(f"Could not connect to {self!r}: {e!r}") from e
        if response.status_code != 200:
            raise ValueError(f"Could not get /info from {self!r}.")
        try:
            reported_id = self._parse_model_id(response)
        except ValueError as e:
            raise ValueError(f"Malformed /info response from {self!r}: {e}") from e
        if reported_id != self.model_id:
            raise ValueError(f"Model ID mismatch: {reported_id} != {self.model_id}")
        self.status = "up"
        logger.info("%s is up.", repr(self))

    async def check_status(self) -> None:
        """Check worker status and update `self.status` accordingly.

        Request errors, non-200 responses, malformed `/info` payloads, and
        model ID mismatches all set `status` to "down" and log a warning
        instead of raising. This keeps one bad worker from aborting a batch
        of concurrent checks. Otherwise `status` is set to "up".
        """
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.url + "/info")
            except httpx.RequestError as e:
                self.status = "down"
                logger.warning("%s is down: %s", repr(self), repr(e))
                return
            if response.status_code != 200:
                self.status = "down"
                # Don't assume an error body is JSON (e.g. a proxy's HTML error page);
                # truncate it so a large page doesn't flood the log.
                logger.warning(
                    "GET /info from %s returned HTTP %d: %s",
                    repr(self),
                    response.status_code,
                    response.text[:200],
                )
                return
            try:
                reported_id = self._parse_model_id(response)
            except ValueError as e:
                self.status = "down"
                logger.warning("Malformed /info response from %s: %s", repr(self), e)
                return
            if reported_id != self.model_id:
                self.status = "down"
                logger.warning(
                    "Model ID mismatch for %s: %s != %s",
                    repr(self),
                    reported_id,
                    self.model_id,
                )
                return
        self.status = "up"
        logger.info("%s is up.", repr(self))


class WorkerService:
    """A service that manages model serving workers.

    Worker objects are only created once and shared across the
    entire application. Especially, changing the status of a worker
    will immediately take effect on the result of `choose_two`.

    Attributes:
        workers (list[Worker]): The list of workers.
    """

    def __init__(self, compose_files: list[str]) -> None:
        """Initialize the worker service from Docker Compose files.

        Every entry under `services` in every compose file becomes one
        `Worker`. The `model_name` is the service key and `hostname` is its
        `container_name`. The `model_id` is the value of `--model-id` in its
        `command`. Each worker is audited (`Worker.audit`) before it is added.

        Args:
            compose_files: Paths to Docker Compose YAML files.

        Raises:
            ValueError: If a service has no `--model-id` value, if a model
                name appears more than once (across all files), or if a
                worker fails its audit.
        """
        self.workers: list[Worker] = []
        seen_model_names: set[str] = set()
        for compose_file in compose_files:
            # Use a context manager so the file handle is closed promptly.
            with open(compose_file) as f:
                spec = yaml.safe_load(f)
            for model_name, service_spec in spec["services"].items():
                # Reject duplicates before auditing (network I/O) any further workers.
                if model_name in seen_model_names:
                    raise ValueError(f"Model names must be unique. Duplicate: {model_name!r}")
                seen_model_names.add(model_name)
                worker = Worker(
                    hostname=service_spec["container_name"],
                    port=80,
                    model_name=model_name,
                    model_id=self._find_model_id(service_spec["command"]),
                    status="down",
                )
                worker.audit()
                self.workers.append(worker)

    @staticmethod
    def _find_model_id(command: "list[str] | str") -> str:
        """Return the value of the `--model-id` option in a compose `command`.

        Accepts both `--model-id ID` and `--model-id=ID`. A string-form
        `command` is split shell-style first. List items are converted to
        `str`, since YAML may parse bare numbers as ints.

        Raises:
            ValueError: If `--model-id` is absent or has no value.
        """
        if isinstance(command, str):
            args = shlex.split(command)
        else:
            args = [str(arg) for arg in command]
        for i, arg in enumerate(args):
            if arg == "--model-id":
                # Guard against `--model-id` being the last argument.
                if i + 1 < len(args) and args[i + 1]:
                    return args[i + 1]
                break
            if arg.startswith("--model-id="):
                value = arg.partition("=")[2]
                if value:
                    return value
                break
        raise ValueError(f"Could not find model ID in {command!r}")

    def get_worker(self, model_name: str) -> Worker:
        """Get a worker by model name.

        Raises:
            RuntimeError: If the worker exists but its status is "down".
            ValueError: If no worker has the given model name.
        """
        for worker in self.workers:
            if worker.model_name == model_name:
                if worker.status == "down":
                    # This is an unfortunate case where, when the two models were chosen,
                    # the worker was up, but after that went down before the request
                    # completed. We'll just raise a 500 internal error and have the user
                    # try again. This won't be common.
                    raise RuntimeError(f"The worker with model name {model_name} is down.")
                return worker
        raise ValueError(f"Worker with model name {model_name} does not exist.")

    def choose_two(self) -> tuple[Worker, Worker]:
        """Choose two different workers uniformly at random among those "up".

        Good place to use the Strategy Pattern when we want to
        implement different strategies for choosing workers.

        Raises:
            ValueError: If fewer than two workers are "up".
        """
        live_workers = [worker for worker in self.workers if worker.status == "up"]
        if len(live_workers) < 2:
            raise ValueError("Not enough live workers to choose from.")
        worker_a, worker_b = random.sample(live_workers, 2)
        return worker_a, worker_b

    async def check_workers(self) -> None:
        """Check the status of all workers concurrently.

        If one worker's check raises, the other checks still complete. The
        error is logged and that worker is conservatively marked "down",
        because its status could not be refreshed.
        """
        workers = list(self.workers)
        results = await asyncio.gather(
            *(worker.check_status() for worker in workers),
            return_exceptions=True,
        )
        # `gather` returns results in the same order as its awaitables.
        for worker, result in zip(workers, results):
            if isinstance(result, Exception):
                worker.status = "down"
                logger.error("Status check for %s failed: %r", repr(worker), result)