|
import asyncio |
|
import threading |
|
import time |
|
from typing import List |
|
|
|
from fastapi import FastAPI, Request, BackgroundTasks |
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
import requests |
|
|
|
from fastchat.constants import WORKER_HEART_BEAT_INTERVAL |
|
from fastchat.conversation import Conversation |
|
from fastchat.utils import pretty_print_semaphore, build_logger |
|
|
|
|
|
# Process-wide singletons. BaseModelWorker.__init__ fills in `worker` and
# `logger` the first time a worker is constructed; the route handlers in this
# module read `worker` to serve requests.
worker = None
logger = None

# FastAPI application on which the worker HTTP endpoints are registered.
app = FastAPI()
|
|
|
|
|
def heart_beat_worker(obj):
    """Send heart beats for ``obj`` forever, one per heart-beat interval.

    Each iteration sleeps ``WORKER_HEART_BEAT_INTERVAL`` and then calls
    ``obj.send_heart_beat()``. The loop never returns, so this function is
    used as the target of a background thread (see
    ``BaseModelWorker.init_heart_beat``).

    Args:
        obj: Any object exposing a ``send_heart_beat()`` method.
    """
    while True:
        # Sleep first: the first heart beat goes out one interval after start.
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        obj.send_heart_beat()
|
|
|
|
|
class BaseModelWorker:
    """Base class for model workers that report to a controller.

    It stores the worker's identity and conversation template, registers the
    worker with the controller, sends periodic heart beats that include the
    current queue length, and declares the generation / embedding hooks that
    concrete workers must override.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
    ):
        """Initialise worker state and publish the module-level singletons.

        Args:
            controller_addr: Base URL of the controller.
            worker_addr: Address this worker reports to the controller as its
                ``worker_name``.
            worker_id: Identifier used in log output and the log file name.
            model_path: Path of the model; its last path component is the
                fallback model name.
            model_names: Names to advertise. When empty/None, the last
                component of ``model_path`` is used instead.
            limit_worker_concurrency: Size of the request semaphore.
            conv_template: Optional conversation template name; when falsy the
                template is looked up from ``model_path``.
        """
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        # Drop one trailing slash so the last path component is not empty.
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_names = model_names or [model_path.split("/")[-1]]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        self.conv.sep_style = int(self.conv.sep_style)
        # Filled in by concrete workers.
        self.tokenizer = None
        self.context_len = None
        self.call_ct = 0
        # Created lazily by acquire_worker_semaphore() on the first request.
        self.semaphore = None

        self.heart_beat_thread = None

        # Only the first worker constructed in the process sets the globals.
        if logger is None:
            logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
        if worker is None:
            worker = self

    def make_conv_template(
        self,
        conv_template: str = None,
        model_path: str = None,
    ) -> Conversation:
        """Build the conversation template for this worker.

        Can be overridden to customize the conversation template for
        different model workers.

        Args:
            conv_template: Template name; when truthy it is resolved with
                ``get_conv_template``.
            model_path: Used with ``get_conversation_template`` when no
                template name is given.

        Returns:
            The resolved conversation template.
        """
        from fastchat.conversation import get_conv_template
        from fastchat.model.model_adapter import get_conversation_template

        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        return conv

    def init_heart_beat(self):
        """Register with the controller and start the heart-beat thread."""
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker,
            args=(self,),
            daemon=True,
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's address and status to ``/register_worker``.

        Raises:
            AssertionError: If the controller does not answer with HTTP 200.
        """
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
        }
        r = requests.post(url, json=data)
        # Explicit raise instead of `assert`: assert statements are removed
        # under `python -O`, which would silently skip this check. The
        # exception type is kept so existing callers behave the same.
        if r.status_code != 200:
            raise AssertionError(
                f"Failed to register worker {self.worker_addr} to controller: "
                f"HTTP {r.status_code}"
            )

    def send_heart_beat(self):
        """Send one heart beat, retrying every 5 seconds until it succeeds.

        The controller's reply must contain an ``exist`` flag; when it is
        falsy the worker registers itself again.
        """
        logger.info(
            f"Send heart beat. Models: {self.model_names}. "
            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
            f"call_ct: {self.call_ct}. "
            f"worker_id: {self.worker_id}. "
        )

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            # Rebuilt on every attempt so a retry reports the current queue.
            payload = {
                "worker_name": self.worker_addr,
                "queue_length": self.get_queue_length(),
            }
            try:
                ret = requests.post(url, json=payload, timeout=5)
                exist = ret.json()["exist"]
                break
            except (
                requests.exceptions.RequestException,
                KeyError,
                # A non-JSON body raises ValueError on older `requests`
                # releases, and a JSON body that is not an object makes the
                # ["exist"] lookup raise TypeError. Treat both like any other
                # bad reply instead of letting them kill the heart-beat thread.
                ValueError,
                TypeError,
            ) as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return the number of requests holding or waiting for a slot.

        Returns 0 while the semaphore has not been created yet.
        """
        if self.semaphore is None:
            return 0

        # The semaphore has no public counters, so its private fields are read.
        available = self.semaphore._value
        if available is None:
            available = self.limit_worker_concurrency
        waiters = self.semaphore._waiters
        waiting = 0 if waiters is None else len(waiters)
        return self.limit_worker_concurrency - available + waiting

    def get_status(self):
        """Return the status dict sent to the controller."""
        return {
            "model_names": self.model_names,
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def count_token(self, params):
        """Count the tokens of ``params["prompt"]``.

        Returns:
            ``{"count": <number of tokens>, "error_code": 0}``.
        """
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            # Fallback for tokenizers where the call above raises TypeError;
            # those are expected to provide num_tokens() instead.
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
            "error_code": 0,
        }
        return ret

    def get_conv_template(self):
        """Return the conversation template wrapped as ``{"conv": ...}``."""
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        """Streaming generation hook; must be overridden by subclasses."""
        raise NotImplementedError

    def generate_gate(self, params):
        """Non-streaming generation hook; must be overridden by subclasses."""
        raise NotImplementedError

    def get_embeddings(self, params):
        """Embedding hook; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
def release_worker_semaphore():
    """Give one concurrency slot back to the global worker's semaphore."""
    semaphore = worker.semaphore
    semaphore.release()
|
|
|
|
|
def acquire_worker_semaphore():
    """Return an awaitable that takes one concurrency slot of the worker.

    The semaphore is created on first use, sized by the worker's
    ``limit_worker_concurrency``, and stored on the worker. The caller must
    ``await`` the returned value.
    """
    semaphore = worker.semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
        worker.semaphore = semaphore
    return semaphore.acquire()
|
|
|
|
|
def create_background_tasks():
    """Build a ``BackgroundTasks`` whose only task releases the worker slot."""
    tasks = BackgroundTasks()
    tasks.add_task(release_worker_semaphore)
    return tasks
|
|
|
|
|
# Streaming generation endpoint: takes a concurrency slot, starts the worker's
# stream and attaches a background task to the response that releases the slot.
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        generator = worker.generate_stream_gate(params)
        background_tasks = create_background_tasks()
        response = StreamingResponse(generator, background=background_tasks)
    except Exception:
        # The releasing background task only exists once the response has
        # been built. If anything above raises (for example the base class's
        # NotImplementedError), release here so the slot is not lost forever.
        release_worker_semaphore()
        raise
    return response
|
|
|
|
|
# Non-streaming generation endpoint: takes a concurrency slot, runs the
# worker's generate_gate in a thread and returns its output as JSON.
@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        # Run off the event loop so other requests keep being served.
        output = await asyncio.to_thread(worker.generate_gate, params)
    finally:
        # Release even when generate_gate raises; otherwise every failed
        # request would permanently consume one concurrency slot.
        release_worker_semaphore()
    return JSONResponse(output)
|
|
|
|
|
# Embedding endpoint: takes a concurrency slot, calls the worker's
# get_embeddings directly (inline, not in a thread) and returns the result as JSON.
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    await acquire_worker_semaphore()
    try:
        embedding = worker.get_embeddings(params)
    finally:
        # Release even when get_embeddings raises; otherwise every failed
        # request would permanently consume one concurrency slot.
        release_worker_semaphore()
    return JSONResponse(content=embedding)
|
|
|
|
|
# Status endpoint: returns the global worker's status dict. The request body
# is not read.
@app.post("/worker_get_status")
async def api_get_status(request: Request):
    status = worker.get_status()
    return status
|
|
|
|
|
# Token-count endpoint: forwards the JSON body to the global worker's
# count_token and returns its result.
@app.post("/count_token")
async def api_count_token(request: Request):
    payload = await request.json()
    token_info = worker.count_token(payload)
    return token_info
|
|
|
|
|
# Conversation-template endpoint: returns what the global worker's
# get_conv_template produces. The request body is not read.
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    template = worker.get_conv_template()
    return template
|
|
|
|
|
# Model-details endpoint: reports the global worker's context length. The
# request body is not read.
@app.post("/model_details")
async def api_model_details(request: Request):
    details = {"context_length": worker.context_len}
    return details
|
|