# FIRE / src/serve/huggingface_api_worker.py (12.6 kB)
# zhangbofei - "fix: src" - 2238fe2
"""
A model worker that calls huggingface inference endpoint.
Register models in a JSON file with the following format:
{
"falcon-180b-chat": {
"model_name": "falcon-180B-chat",
"api_base": "https://api-inference.huggingface.co/models",
"model_path": "tiiuae/falcon-180B-chat",
"token": "hf_XXX",
"context_length": 2048
},
"zephyr-7b-beta": {
"model_name": "zephyr-7b-beta",
"model_path": "",
"api_base": "xxx",
"token": "hf_XXX",
"context_length": 4096
}
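}
"model_path", "api_base", "token", and "context_length" are required.
"model_names" (a string or a list of strings; defaults to the last
"/"-separated part of the entry's key) and "conv_template" are optional.
Any other key (such as "model_name" in the example above) is not read by this worker.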
"""
import argparse
import asyncio
import json
import uuid
import os
from typing import List, Optional
import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import InferenceClient
from src.constants import SERVER_ERROR_MSG, ErrorCode
from src.serve.base_model_worker import BaseModelWorker
from src.utils import build_logger
# Short identifier for this process: the first 8 characters of a random UUID4.
worker_id = str(uuid.uuid4())[:8]
# Module logger; the second argument embeds the worker id
# (presumably the log file name -- confirm against build_logger).
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Every HuggingfaceApiWorker created by create_huggingface_api_worker().
workers = []
# Maps each served model name to the worker instance that handles it.
worker_map = {}
# FastAPI application on which the worker endpoints below are registered.
app = FastAPI()
# reference to
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
def get_gen_kwargs(
    params,
    seed: Optional[int] = None,
):
    """Translate request parameters into text-generation keyword arguments.

    Args:
        params: Mapping of request parameters. The keys read here are
            ``stop``, ``echo``, ``max_new_tokens``, ``top_p``, ``temperature``,
            ``repetition_penalty`` and ``top_k``; all of them are optional.
        seed: Value stored unchanged under the ``seed`` key.

    Returns:
        A new dict of generation arguments. ``stop_sequences`` is always a
        fresh list, so callers may extend it without touching ``params``.
    """
    stop = params.get("stop", None)
    if isinstance(stop, list):
        # Copy: the caller extends this list with model-specific stop words,
        # which must not leak back into the request's own "stop" list.
        stop_sequences = list(stop)
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = []

    gen_kwargs = {
        "do_sample": True,
        "return_full_text": bool(params.get("echo", False)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "top_p": float(params.get("top_p", 1.0)),
        "temperature": float(params.get("temperature", 1.0)),
        "stop_sequences": stop_sequences,
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_k": params.get("top_k", None),
        "seed": seed,
    }

    # A top_p of exactly 1 is nudged just below 1, and a top_p of 0 is dropped
    # (presumably because the endpoint only accepts 0 < top_p < 1 -- confirm).
    if gen_kwargs["top_p"] == 1:
        gen_kwargs["top_p"] = 0.9999999
    if gen_kwargs["top_p"] == 0:
        gen_kwargs.pop("top_p")

    # Temperature 0 means greedy decoding: drop the value and disable sampling.
    if gen_kwargs["temperature"] == 0:
        gen_kwargs.pop("temperature")
        gen_kwargs["do_sample"] = False

    return gen_kwargs
def could_be_stop(text, stop):
    """Report whether ``text`` ends with a non-empty prefix of any stop string.

    A true result means a stop sequence may be in the middle of arriving, so
    the text generated so far should not be shown yet.
    """
    return any(
        text.endswith(candidate[:length])
        for candidate in stop
        for length in range(1, len(candidate) + 1)
    )
class HuggingfaceApiWorker(BaseModelWorker):
    """Model worker that generates text through a Hugging Face inference endpoint.

    Generation is delegated to ``huggingface_hub.InferenceClient``; this class
    has no tokenizer, so ``count_token`` is a stub and ``get_embeddings`` is
    not implemented.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Store the endpoint settings and optionally start the heart beat.

        Args:
            controller_addr: Forwarded to ``BaseModelWorker.__init__``.
            worker_addr: Forwarded to ``BaseModelWorker.__init__``.
            worker_id: Forwarded to ``BaseModelWorker.__init__``.
            model_path: Appended to ``api_base`` to form the endpoint URL; when
                it is the empty string, ``api_base`` alone is the URL.
            api_base: Base URL of the inference endpoint.
            token: Token handed to ``InferenceClient``.
            context_length: Stored as ``self.context_len``.
            model_names: Forwarded to ``BaseModelWorker.__init__``.
            limit_worker_concurrency: Forwarded to ``BaseModelWorker.__init__``.
            no_register: When true, ``init_heart_beat()`` is not called.
            conv_template: Forwarded to ``BaseModelWorker.__init__``.
            seed: Passed to ``get_gen_kwargs`` on every generation.
            **kwargs: Accepted and ignored.
        """
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

        if not no_register:
            self.init_heart_beat()

    def count_token(self, params):
        """Return a zero token count; there is no tokenizer in this worker."""
        # No tokenizer here
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret

    @staticmethod
    def _encode_chunk(text, finish_reason):
        """Serialize one successful stream chunk as JSON followed by a NUL byte."""
        ret = {
            "text": text,
            "error_code": 0,
            "finish_reason": finish_reason,
        }
        return json.dumps(ret).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Stream a generation as NUL-terminated JSON chunks.

        Each chunk carries the whole text generated so far (not a delta), an
        ``error_code`` and a ``finish_reason``. Text whose tail could be the
        beginning of a stop sequence is held back until it is resolved. At
        least one chunk is always yielded: either the final state of the
        generation or, if anything raises, a single error chunk.

        Args:
            params: Request parameters; ``prompt`` is required, the rest is
                read by ``get_gen_kwargs``.

        Yields:
            ``bytes`` objects, each a JSON document followed by ``b"\\0"``.
        """
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        # Work on a fresh set so the request's own "stop" list is never
        # mutated. Empty strings are dropped: every text "ends with" the empty
        # string, which would truncate the whole output on the first token.
        stop = {s for s in gen_kwargs["stop_sequences"] if s}
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.update(["\nUser:", "<|endoftext|>", " User:", "###"])
        stop = list(stop)
        gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            if self.model_path == "":
                url = f"{self.api_base}"
            else:
                url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            emitted = None  # (text, reason) of the last chunk sent, if any
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                s = next((x for x in stop if text.endswith(x)), None)
                if s is not None:
                    text = text[: -len(s)]
                    reason = "stop"
                    break

                # Record the finish reason before deciding whether to hold the
                # text back, so it is not lost when this is the last chunk.
                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                if reason not in ["stop", "length"]:
                    reason = None

                if could_be_stop(text, stop):
                    continue

                emitted = (text, reason)
                yield self._encode_chunk(text, reason)

            # Flush the final state when the loop ended on a stop sequence, on
            # held-back text, or without producing any chunk at all.
            if emitted != (text, reason):
                yield self._encode_chunk(text, reason)
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        """Run a full generation and return the last streamed chunk as a dict."""
        last_chunk = None
        for last_chunk in self.generate_stream_gate(params):
            pass
        if last_chunk is None:
            # Defensive only: generate_stream_gate always yields at least once.
            return {
                "text": SERVER_ERROR_MSG,
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
        # Strip the trailing NUL delimiter before decoding the JSON payload.
        return json.loads(last_chunk[:-1].decode())

    def get_embeddings(self, params):
        """Embeddings are not supported by this worker."""
        raise NotImplementedError()
def release_worker_semaphore(worker):
    """Give back one slot of the worker's concurrency semaphore."""
    semaphore = worker.semaphore
    semaphore.release()
def acquire_worker_semaphore(worker):
    """Return an awaitable that takes one slot of the worker's semaphore.

    The semaphore is created on first use, sized by
    ``worker.limit_worker_concurrency``, and stored on the worker.
    """
    semaphore = worker.semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
        worker.semaphore = semaphore
    return semaphore.acquire()
def create_background_tasks(worker):
    """Build background tasks that release the worker's semaphore slot."""

    def _release():
        release_worker_semaphore(worker)

    tasks = BackgroundTasks()
    tasks.add_task(_release)
    return tasks
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    # Stream a generation for the model named in the request body. The
    # semaphore slot taken here is released by the background task attached
    # to the response.
    payload = await request.json()
    worker = worker_map[payload["model"]]
    await acquire_worker_semaphore(worker)
    return StreamingResponse(
        worker.generate_stream_gate(payload),
        background=create_background_tasks(worker),
    )
@app.post("/worker_generate")
async def api_generate(request: Request):
    # Run a complete (non-streaming) generation for the model named in the
    # request body and return the final chunk as JSON.
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    try:
        output = worker.generate_gate(params)
    finally:
        # Always give the slot back; otherwise a failing request would
        # permanently reduce the worker's concurrency.
        release_worker_semaphore(worker)
    return JSONResponse(output)
@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    # Return embeddings from the worker serving the model named in the
    # request body.
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    try:
        embedding = worker.get_embeddings(params)
    finally:
        # get_embeddings may raise (HuggingfaceApiWorker raises
        # NotImplementedError); the slot must be released regardless, or
        # repeated calls would exhaust the semaphore and block every endpoint.
        release_worker_semaphore(worker)
    return JSONResponse(content=embedding)
@app.post("/worker_get_status")
async def api_get_status(request: Request):
    # Report the model names and the combined queue length of every worker
    # in this process.
    served_models = []
    for w in workers:
        served_models.extend(w.model_names)
    pending = sum(w.get_queue_length() for w in workers)
    return {
        "model_names": served_models,
        "speed": 1,
        "queue_length": pending,
    }
@app.post("/count_token")
async def api_count_token(request: Request):
    # Delegate token counting to the worker serving the requested model.
    payload = await request.json()
    target = worker_map[payload["model"]]
    return target.count_token(payload)
@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    # Return the conversation template of the worker serving the requested model.
    payload = await request.json()
    target = worker_map[payload["model"]]
    return target.get_conv_template()
@app.post("/model_details")
async def api_model_details(request: Request):
    # Expose the context length configured for the requested model.
    payload = await request.json()
    target = worker_map[payload["model"]]
    details = {"context_length": target.context_len}
    return details
def create_huggingface_api_worker():
    """Parse the command line, create one worker per model and register them.

    The model-info JSON file named by ``--model-info-file`` is read; for each
    entry a ``HuggingfaceApiWorker`` is created, appended to the module-level
    ``workers`` list and stored in ``worker_map`` under each of its model
    names. Finally a single ``/register_worker`` request announcing all model
    names is posted to the controller.

    Returns:
        A ``(args, workers)`` tuple: the parsed argument namespace and the
        module-level list of created workers.

    Raises:
        AssertionError: If the controller does not answer the registration
            request with HTTP status 200.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # all model-related parameters are listed in --model-info-file
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )
    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []
    model_names_list = []
    conv_template_list = []

    for entry_name, info in model_info.items():
        model_path_list.append(info["model_path"])
        api_base_list.append(info["api_base"])
        token_list.append(info["token"])

        context_length = info["context_length"]
        # Without an explicit "model_names", serve the model under the last
        # "/"-separated part of its entry key.
        model_names = info.get("model_names", [entry_name.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        conv_template = info.get("conv_template", None)

        context_length_list.append(context_length)
        model_names_list.append(model_names)
        conv_template_list.append(conv_template)

    logger.info(f"Model paths: {model_path_list}")
    logger.info(f"API bases: {api_base_list}")
    # API tokens are credentials: log only how many were loaded, never their values.
    logger.info(f"Tokens: {['<redacted>'] * len(token_list)}")
    logger.info(f"Context lengths: {context_length_list}")
    logger.info(f"Model names: {model_names_list}")
    logger.info(f"Conv templates: {conv_template_list}")

    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        model_names_list,
        conv_template_list,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        worker = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(worker)
        for name in model_names:
            worker_map[name] = worker

    # register all the models
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    # Raised explicitly (same exception type as the former assert statement)
    # so the check still runs when Python is started with -O.
    if r.status_code != 200:
        raise AssertionError(
            f"Registering with the controller at {url} failed with HTTP status {r.status_code}"
        )

    return args, workers
if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()

    # TLS key and certificate paths are read from the environment only when
    # --ssl was given; a missing variable raises KeyError.
    ssl_options = {}
    if args.ssl:
        ssl_options["ssl_keyfile"] = os.environ["SSL_KEYFILE"]
        ssl_options["ssl_certfile"] = os.environ["SSL_CERTFILE"]

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        **ssl_options,
    )