Spaces:
Running
Running
File size: 2,583 Bytes
3a09006 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
class ChatAPIApp:
    """Own a FastAPI application serving a model list and chat completions.

    The FastAPI instance is exposed as ``self.app``. Every endpoint is
    mounted twice: once at the root path and once under ``/v1``.
    """

    def __init__(self):
        """Build the FastAPI instance and register all routes on it."""
        fastapi_options = {
            # Serve the interactive Swagger UI at the root URL.
            "docs_url": "/",
            "title": "HuggingFace LLM API",
            # Swagger UI option: -1 hides the "Models"/schemas section.
            "swagger_ui_parameters": {"defaultModelsExpandDepth": -1},
            "version": "1.0",
        }
        self.app = FastAPI(**fastapi_options)
        self.setup_routes()

    # GET handler for /models and /v1/models.
    # Returns a list of {"id", "description"} dicts and also stores that
    # list on self.available_models.
    # (Comments, not a docstring: FastAPI would publish a docstring as this
    # route's OpenAPI description.)
    def get_available_models(self):
        mixtral_entry = {
            "id": "mixtral-8x7b",
            "description": "Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
        }
        self.available_models = [mixtral_entry]
        return self.available_models

    # Request body accepted by the chat-completions endpoint.
    # (Comments, not a docstring: pydantic would publish a class docstring
    # in the generated JSON schema.)
    class ChatCompletionsPostItem(BaseModel):
        # Model identifier handed to both MessageStreamer and MessageComposer.
        model: str = Field("mixtral-8x7b", description="(str) `mixtral-8x7b`")
        # Conversation history. Entries are presumably {"role", "content"}
        # dicts like the default below -- confirm against MessageComposer.
        messages: list = Field(
            [{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        # Passed unchanged to streamer.chat(temperature=...).
        temperature: float = Field(0.01, description="(float) Temperature")
        # Passed to streamer.chat() as max_new_tokens.
        max_tokens: int = Field(32000, description="(int) Max tokens")
        # Passed to streamer.chat() as its `stream` flag.
        stream: bool = Field(True, description="(bool) Stream")

    # POST handler for /chat/completions and /v1/chat/completions.
    # Merges the submitted messages into a single prompt
    # (composer.merged_str) and returns the streamer's output wrapped in a
    # server-sent-events response.
    # NOTE(review): the response is always SSE, even when item.stream is
    # False -- confirm streamer.chat() yields sensible output in that case.
    # (Comments, not a docstring: FastAPI would publish a docstring as this
    # route's OpenAPI description.)
    def chat_completions(self, item: ChatCompletionsPostItem):
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        chat_output = streamer.chat(
            prompt=composer.merged_str,
            temperature=item.temperature,
            max_new_tokens=item.max_tokens,
            stream=item.stream,
            yield_output=True,
        )
        return EventSourceResponse(chat_output, media_type="text/event-stream")

    def setup_routes(self):
        """Register every endpoint at the root path and again under ``/v1``."""
        # (register method, path suffix, OpenAPI summary, handler)
        route_table = (
            (
                self.app.get,
                "/models",
                "Get available models",
                self.get_available_models,
            ),
            (
                self.app.post,
                "/chat/completions",
                "Chat completions in conversation session",
                self.chat_completions,
            ),
        )
        for prefix in ("", "/v1"):
            for register, suffix, summary_text, handler in route_table:
                register(prefix + suffix, summary=summary_text)(handler)
# Module-level ASGI application; uvicorn resolves it by the name "app".
app = ChatAPIApp().app

if __name__ == "__main__":
    # uvicorn requires the app as an import string (not the object) when
    # reload is enabled, hence "__main__:app". 0.0.0.0 listens on all
    # network interfaces.
    server_options = {"host": "0.0.0.0", "port": 23333, "reload": True}
    uvicorn.run("__main__:app", **server_options)
|