Spaces:
Running
Running
File size: 4,705 Bytes
8ff63e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import json
import uvicorn
from pydantic import BaseSettings
from fastapi import FastAPI, Depends
from fastapi.responses import StreamingResponse
from fastapi.exceptions import HTTPException
from text_generation.errors import OverloadedError, UnknownError, ValidationError
from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
from spitfight.colosseum.common import (
COLOSSEUM_PROMPT_ROUTE,
COLOSSEUM_RESP_VOTE_ROUTE,
COLOSSEUM_ENERGY_VOTE_ROUTE,
COLOSSEUM_HEALTH_ROUTE,
PromptRequest,
ResponseVoteRequest,
ResponseVoteResponse,
EnergyVoteRequest,
EnergyVoteResponse,
)
from spitfight.colosseum.controller.controller import (
Controller,
init_global_controller,
get_global_controller,
)
from spitfight.utils import prepend_generator
class ControllerConfig(BaseSettings):
    """Runtime configuration for the Colosseum controller.

    Every field can be overridden through an environment variable of the same
    name. Pydantic's ``BaseSettings`` matches names case-insensitively by
    default, and no ``env_prefix`` is configured on this class.
    """

    # ---- Controller behaviour ----
    # NOTE(review): both durations below are presumably in seconds — confirm
    # against the Controller implementation.
    background_task_interval: int = 5 * 60
    max_num_req_states: int = 10_000
    req_state_expiration_time: int = 10 * 60
    compose_files: list[str] = [
        "deployment/docker-compose-0.yaml",
        "deployment/docker-compose-1.yaml",
    ]

    # ---- Logging ----
    # The *_log_file values are bare file names; startup_event joins each one
    # onto log_dir.
    log_dir: str = "/logs"
    controller_log_file: str = "controller.log"
    request_log_file: str = "requests.log"
    uvicorn_log_file: str = "uvicorn.log"

    # ---- Text generation ----
    # Sampling parameters; presumably forwarded to TGI by the Controller,
    # which receives this whole settings object.
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 1.0
    repetition_penalty: float = 1.0
    top_k: int = 50
    top_p: float = 0.95
# Settings are read from the environment once, at import time, and shared by
# every handler below.
settings = ControllerConfig()
# The ASGI application that all routes and lifecycle hooks register on.
app = FastAPI()
logger = get_logger("spitfight.colosseum.controller.router")
@app.on_event("startup")
async def startup_event():
    """Set up file logging, then create the process-wide Controller.

    Three queued root loggers are initialised, each writing to its own file
    under ``settings.log_dir``. The global controller is created afterwards
    from the same settings object.
    """
    # (logger name, log file name) — kept in the original initialisation order.
    log_targets = (
        ("uvicorn", settings.uvicorn_log_file),
        ("spitfight.colosseum.controller", settings.controller_log_file),
        ("colosseum_requests", settings.request_log_file),
    )
    for logger_name, file_name in log_targets:
        init_queued_root_logger(logger_name, os.path.join(settings.log_dir, file_name))
    init_global_controller(settings)
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the global Controller, then shut down the queued root loggers.

    Logger shutdown now runs in ``finally``. Previously, an exception from
    ``Controller.shutdown()`` skipped ``shutdown_queued_root_loggers()``
    entirely, which presumably also drops log records still waiting in the
    queue. A controller failure is logged first, while the loggers are still
    running, and is then re-raised so the server still reports it.
    """
    try:
        get_global_controller().shutdown()
    except Exception:
        logger.exception("Error while shutting down the controller.")
        raise
    finally:
        shutdown_queued_root_loggers()
@app.post(COLOSSEUM_PROMPT_ROUTE)
async def prompt(
    request: PromptRequest,
    controller: Controller = Depends(get_global_controller),
):
    # Stream one model's answer to a Colosseum prompt.
    #
    # The controller's async generator is advanced once *before* building the
    # StreamingResponse. Once streaming starts, the HTTP status line has
    # already been sent, so TGI errors must surface on the first token to be
    # mapped to a status code:
    #   OverloadedError -> 429, ValidationError -> 422, UnknownError -> 500.
    # An empty generation (immediate StopAsyncIteration) is not an error: a
    # single NUL-terminated JSON string telling the user so is streamed instead.
    # (Comments rather than a docstring, so the OpenAPI description is unchanged.)
    generator = controller.prompt(request.request_id, request.prompt, request.model_index)
    # First try to get the first token in order to catch TGI errors.
    try:
        first_token = await generator.__anext__()
    except OverloadedError as e:
        name = controller.request_states[request.request_id].model_names[request.model_index]
        logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
        raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.") from e
    except ValidationError as e:
        logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=422, detail=str(e)) from e
    except StopAsyncIteration:
        logger.info("TGI returned empty response. Failed request: %s", repr(request))
        return StreamingResponse(
            iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
        )
    except UnknownError as e:
        logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Re-attach the token consumed above so the client receives the full stream.
    return StreamingResponse(prepend_generator(first_token, generator))
@app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
async def response_vote(
    request: ResponseVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    # Record the user's preferred response (victory_index), then reveal the
    # battle's energy consumptions and model names.
    # Controller.response_vote returns None when it no longer holds state for
    # this request; the 410 detail attributes that to session expiry.
    state = controller.response_vote(request.request_id, request.victory_index)
    if state is not None:
        return ResponseVoteResponse(
            energy_consumptions=state.energy_consumptions,
            model_names=state.model_names,
        )
    raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
@app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
async def energy_vote(
    request: EnergyVoteRequest,
    controller: Controller = Depends(get_global_controller),
):
    # Record whether the user judged the response worth its energy cost
    # (is_worth), then return the battle's model names.
    # A None state from the controller means the request is no longer
    # tracked; it is reported as 410 Gone.
    state = controller.energy_vote(request.request_id, request.is_worth)
    if state is not None:
        return EnergyVoteResponse(model_names=state.model_names)
    raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
@app.get(COLOSSEUM_HEALTH_ROUTE)
async def health():
    # Liveness probe: always answers with the literal "OK" and touches no
    # controller state.
    status = "OK"
    return status
if __name__ == "__main__":
    # Serve on all interfaces. log_config=None tells uvicorn not to apply its
    # default logging config, so the "uvicorn" logger keeps the queued file
    # handler that startup_event installs.
    uvicorn.run(
        app,
        log_config=None,
        host="0.0.0.0",
    )
|