"""
A model worker that calls the Hugging Face inference endpoint.

Register models in a JSON file with the following format:
{
    "falcon-180b-chat": {
        "model_names": "falcon-180B-chat",
        "api_base": "https://api-inference.huggingface.co/models",
        "model_path": "tiiuae/falcon-180B-chat",
        "token": "hf_XXX",
        "context_length": 2048
    },
    "zephyr-7b-beta": {
        "model_names": "zephyr-7b-beta",
        "model_path": "",
        "api_base": "xxx",
        "token": "hf_XXX",
        "context_length": 4096
    }
}

"model_path", "api_base", "token", and "context_length" are required; other keys
(such as "model_names" and "conv_template") are optional.
"""

import argparse
import asyncio
import json
import uuid
import os
from typing import List, Optional

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import InferenceClient

from src.constants import SERVER_ERROR_MSG, ErrorCode
from src.serve.base_model_worker import BaseModelWorker
from src.utils import build_logger

worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

workers = []
worker_map = {}
app = FastAPI()


def get_gen_kwargs(
    params,
    seed: Optional[int] = None,
):
    stop = params.get("stop", None)
    if isinstance(stop, list):
        stop_sequences = stop
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = []
    gen_kwargs = {
        "do_sample": True,
        "return_full_text": bool(params.get("echo", False)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "top_p": float(params.get("top_p", 1.0)),
        "temperature": float(params.get("temperature", 1.0)),
        "stop_sequences": stop_sequences,
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_k": params.get("top_k", None),
        "seed": seed,
    }
    if gen_kwargs["top_p"] == 1:
        gen_kwargs["top_p"] = 0.9999999
    if gen_kwargs["top_p"] == 0:
        gen_kwargs.pop("top_p")
    if gen_kwargs["temperature"] == 0:
        gen_kwargs.pop("temperature")
        gen_kwargs["do_sample"] = False
    return gen_kwargs
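# Illustrative mapping (not executed at runtime): params such as
#   {"prompt": "Hi", "temperature": 0, "stop": "###", "max_new_tokens": 32}
# yield gen_kwargs with do_sample=False and no "temperature" key (greedy
# decoding), top_p nudged from the 1.0 default to 0.9999999,
# stop_sequences=["###"], and max_new_tokens=32.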


def could_be_stop(text, stop):
    for s in stop:
        if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
            return True
    return False
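# Example: could_be_stop("foo\nUs", ["\nUser:"]) is True because the text ends
# with "\nUs", a prefix of "\nUser:", so the caller holds the output back until
# the ambiguity resolves; could_be_stop("foo", ["\nUser:"]) is False.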


class HuggingfaceApiWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

        if not no_register:
            self.init_heart_beat()

    def count_token(self, params):
        # This worker does not tokenize locally, so it always reports a count of 0.
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret

    def generate_stream_gate(self, params):
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        stop = gen_kwargs["stop_sequences"]
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
            stop = list(set(stop))
            gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            if self.model_path == "":
                url = f"{self.api_base}"
            else:
                url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                s = next((x for x in stop if text.endswith(x)), None)
                if s is not None:
                    text = text[: -len(s)]
                    reason = "stop"
                    break
                if could_be_stop(text, stop):
                    # Hold the partial output back until we know whether it is
                    # the beginning of a stop sequence.
                    continue
                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                if reason not in ["stop", "length"]:
                    reason = None
                ret = {
                    "text": text,
                    "error_code": 0,
                    "finish_reason": reason,
                }
                # Each streamed chunk is a JSON object terminated by a NUL byte.
                yield json.dumps(ret).encode() + b"\0"

            # Always emit a final chunk so the truncated text and finish reason
            # reach the client even when the loop breaks on a stop sequence.
            ret = {
                "text": text,
                "error_code": 0,
                "finish_reason": reason,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        # Drain the stream and return only the final chunk, stripping the
        # trailing NUL delimiter.
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())

    def get_embeddings(self, params):
        raise NotImplementedError()


def release_worker_semaphore(worker):
    worker.semaphore.release()


def acquire_worker_semaphore(worker):
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(worker):
    background_tasks = BackgroundTasks()
    background_tasks.add_task(lambda: release_worker_semaphore(worker))
    return background_tasks
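# For streaming responses the semaphore is released through BackgroundTasks,
# which run only after the response has finished sending, so the concurrency
# limit covers the whole generation rather than just the request handler.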


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks(worker)
    return StreamingResponse(generator, background=background_tasks)
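# Client-side sketch (illustrative, not part of this worker): the streamed body
# is a sequence of JSON objects separated by NUL bytes, so a caller could read
# it roughly like this (the worker address and model name are assumptions):
#
#   import json, requests
#   resp = requests.post(
#       "http://localhost:21002/worker_generate_stream",
#       json={"model": "zephyr-7b-beta", "prompt": "Hello"},
#       stream=True,
#   )
#   for chunk in resp.iter_lines(delimiter=b"\0"):
#       if chunk:
#           data = json.loads(chunk.decode())
#           print(data["text"], data.get("finish_reason"))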


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    output = worker.generate_gate(params)
    release_worker_semaphore(worker)
    return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    embedding = worker.get_embeddings(params)
    release_worker_semaphore(worker)
    return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return {
        "model_names": [m for w in workers for m in w.model_names],
        "speed": 1,
        "queue_length": sum([w.get_queue_length() for w in workers]),
    }


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return {"context_length": worker.context_len}


def create_huggingface_api_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )

    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )

    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []
    model_names_list = []
    conv_template_list = []

    for m in model_info:
        model_path_list.append(model_info[m]["model_path"])
        api_base_list.append(model_info[m]["api_base"])
        token_list.append(model_info[m]["token"])

        context_length = model_info[m]["context_length"]
        model_names = model_info[m].get("model_names", [m.split("/")[-1]])
        if isinstance(model_names, str):
            model_names = [model_names]
        conv_template = model_info[m].get("conv_template", None)

        context_length_list.append(context_length)
        model_names_list.append(model_names)
        conv_template_list.append(conv_template)

    logger.info(f"Model paths: {model_path_list}")
    logger.info(f"API bases: {api_base_list}")
    logger.info(f"Tokens: {token_list}")
    logger.info(f"Context lengths: {context_length_list}")
    logger.info(f"Model names: {model_names_list}")
    logger.info(f"Conv templates: {conv_template_list}")

    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        model_names_list,
        conv_template_list,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        m = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(m)
        for name in model_names:
            worker_map[name] = m

    # Register all models under this single worker address with the controller.
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200

    return args, workers


if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")