xinference / app.py
aresnow
Arena as fisrt tab
a9bd2fc
raw
history blame
17 kB
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import urllib.request
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import gradio as gr
from xinference.locale.utils import Locale
from xinference.model import MODEL_FAMILIES, ModelSpec
from xinference.core.api import SyncSupervisorAPI
if TYPE_CHECKING:
from xinference.types import ChatCompletionChunk, ChatCompletionMessage
MODEL_TO_FAMILIES = dict(
(model_family.model_name, model_family)
for model_family in MODEL_FAMILIES
if model_family.model_name != "baichuan"
)
class GradioApp:
def __init__(
self,
supervisor_address: str,
gladiator_num: int = 2,
max_model_num: int = 2,
use_launched_model: bool = False,
):
self._api = SyncSupervisorAPI(supervisor_address)
self._gladiator_num = gladiator_num
self._max_model_num = max_model_num
self._use_launched_model = use_launched_model
self._locale = Locale()
def _create_model(
self,
model_name: str,
model_size_in_billions: Optional[int] = None,
model_format: Optional[str] = None,
quantization: Optional[str] = None,
):
model_uid = str(uuid.uuid1())
models = self._api.list_models()
if len(models) >= self._max_model_num:
self._api.terminate_model(models[0][0])
return self._api.launch_model(
model_uid, model_name, model_size_in_billions, model_format, quantization
)
async def generate(
self,
model: str,
message: str,
chat: List[List[str]],
max_token: int,
temperature: float,
top_p: float,
window_size: int,
show_finish_reason: bool,
):
if not message:
yield message, chat
else:
try:
model_ref = self._api.get_model(model)
except KeyError:
raise gr.Error(self._locale(f"Please create model first"))
history: "List[ChatCompletionMessage]" = []
for c in chat:
history.append({"role": "user", "content": c[0]})
out = c[1]
finish_reason_idx = out.find(f"[{self._locale('stop reason')}: ")
if finish_reason_idx != -1:
out = out[:finish_reason_idx]
history.append({"role": "assistant", "content": out})
if window_size != 0:
history = history[-(window_size // 2) :]
# chatglm only support even number of conversation history.
if len(history) % 2 != 0:
history = history[1:]
generate_config = dict(
max_tokens=max_token,
temperature=temperature,
top_p=top_p,
stream=True,
)
chat += [[message, ""]]
chat_generator = await model_ref.chat(
message,
chat_history=history,
generate_config=generate_config,
)
chunk: Optional["ChatCompletionChunk"] = None
async for chunk in chat_generator:
assert chunk is not None
delta = chunk["choices"][0]["delta"]
if "content" not in delta:
continue
else:
chat[-1][1] += delta["content"]
yield "", chat
if show_finish_reason and chunk is not None:
chat[-1][
1
] += f"[{self._locale('stop reason')}: {chunk['choices'][0]['finish_reason']}]"
yield "", chat
def _build_chatbot(self, model_uid: str, model_name: str):
with gr.Accordion(self._locale("Parameters"), open=False):
max_token = gr.Slider(
128,
1024,
value=256,
step=1,
label=self._locale("Max tokens"),
info=self._locale("The maximum number of tokens to generate."),
)
temperature = gr.Slider(
0.2,
1,
value=0.8,
step=0.01,
label=self._locale("Temperature"),
info=self._locale("The temperature to use for sampling."),
)
top_p = gr.Slider(
0.2,
1,
value=0.95,
step=0.01,
label=self._locale("Top P"),
info=self._locale("The top-p value to use for sampling."),
)
window_size = gr.Slider(
0,
50,
value=10,
step=1,
label=self._locale("Window size"),
info=self._locale("Window size of chat history."),
)
show_finish_reason = gr.Checkbox(
label=f"{self._locale('Show stop reason')}"
)
chat = gr.Chatbot(label=model_name)
text = gr.Textbox(visible=False)
model_uid = gr.Textbox(model_uid, visible=False)
text.change(
self.generate,
[
model_uid,
text,
chat,
max_token,
temperature,
top_p,
window_size,
show_finish_reason,
],
[text, chat],
)
return (
text,
chat,
max_token,
temperature,
top_p,
show_finish_reason,
window_size,
model_uid,
)
def _build_chat_column(self):
with gr.Column():
with gr.Row():
model_name = gr.Dropdown(
choices=list(MODEL_TO_FAMILIES.keys()),
label=self._locale("model name"),
scale=2,
)
model_format = gr.Dropdown(
choices=[],
interactive=False,
label=self._locale("model format"),
scale=2,
)
model_size_in_billions = gr.Dropdown(
choices=[],
interactive=False,
label=self._locale("model size in billions"),
scale=1,
)
quantization = gr.Dropdown(
choices=[],
interactive=False,
label=self._locale("quantization"),
scale=1,
)
create_model = gr.Button(value=self._locale("create"))
def select_model_name(model_name: str):
if model_name:
model_family = MODEL_TO_FAMILIES[model_name]
formats = [model_family.model_format]
model_sizes_in_billions = [
str(b) for b in model_family.model_sizes_in_billions
]
quantizations = model_family.quantizations
return (
gr.Dropdown.update(
choices=formats,
interactive=True,
value=model_family.model_format,
),
gr.Dropdown.update(
choices=model_sizes_in_billions[:1],
interactive=True,
value=model_sizes_in_billions[0],
),
gr.Dropdown.update(
choices=quantizations,
interactive=True,
value=quantizations[0],
),
)
else:
return (
gr.Dropdown.update(),
gr.Dropdown.update(),
gr.Dropdown.update(),
)
model_name.change(
select_model_name,
inputs=[model_name],
outputs=[model_format, model_size_in_billions, quantization],
)
components = self._build_chatbot("", "")
model_text = components[0]
chat, model_uid = components[1], components[-1]
def select_model(
_model_name: str,
_model_format: str,
_model_size_in_billions: str,
_quantization: str,
progress=gr.Progress(),
):
model_family = MODEL_TO_FAMILIES[_model_name]
cache_path, meta_path = model_family.generate_cache_path(
int(_model_size_in_billions), _quantization
)
if not (os.path.exists(cache_path) and os.path.exists(meta_path)):
if os.path.exists(cache_path):
os.remove(cache_path)
url = model_family.url_generator(
int(_model_size_in_billions), _quantization
)
full_name = (
f"{str(model_family)}-{_model_size_in_billions}b-{_quantization}"
)
try:
urllib.request.urlretrieve(
url,
cache_path,
reporthook=lambda block_num, block_size, total_size: progress(
block_num * block_size / total_size,
desc=self._locale("Downloading"),
),
)
# write a meta file to record if download finished
with open(meta_path, "w") as f:
f.write(full_name)
except:
if os.path.exists(cache_path):
os.remove(cache_path)
model_uid = self._create_model(
_model_name, int(_model_size_in_billions), _model_format, _quantization
)
return gr.Chatbot.update(
label="-".join(
[_model_name, _model_size_in_billions, _model_format, _quantization]
),
value=[],
), gr.Textbox.update(value=model_uid)
def clear_chat(
_model_name: str,
_model_format: str,
_model_size_in_billions: str,
_quantization: str,
):
full_name = "-".join(
[_model_name, _model_size_in_billions, _model_format, _quantization]
)
return str(uuid.uuid4()), gr.Chatbot.update(
label=full_name,
value=[],
)
invisible_text = gr.Textbox(visible=False)
create_model.click(
clear_chat,
inputs=[model_name, model_format, model_size_in_billions, quantization],
outputs=[invisible_text, chat],
)
invisible_text.change(
select_model,
inputs=[model_name, model_format, model_size_in_billions, quantization],
outputs=[chat, model_uid],
postprocess=False,
)
return chat, model_text
def _build_arena(self):
with gr.Box():
with gr.Row():
chat_and_text = [
self._build_chat_column() for _ in range(self._gladiator_num)
]
chats = [c[0] for c in chat_and_text]
texts = [c[1] for c in chat_and_text]
msg = gr.Textbox(label=self._locale("Input"))
def update_message(text_in: str):
return "", text_in, text_in
msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
gr.ClearButton(components=[msg] + chats + texts)
def _build_single(self):
chat, model_text = self._build_chat_column()
msg = gr.Textbox(label=self._locale("Input"))
def update_message(text_in: str):
return "", text_in
msg.submit(update_message, inputs=[msg], outputs=[msg, model_text])
gr.ClearButton(components=[chat, msg, model_text])
def _build_single_with_launched(
self, models: List[Tuple[str, ModelSpec]], default_index: int
):
uid_to_model_spec: Dict[str, ModelSpec] = dict((m[0], m[1]) for m in models)
choices = [
"-".join(
[
s.model_name,
str(s.model_size_in_billions),
s.model_format,
s.quantization,
]
)
for s in uid_to_model_spec.values()
]
choice_to_uid = dict(zip(choices, uid_to_model_spec.keys()))
model_selection = gr.Dropdown(
label=self._locale("select model"),
choices=choices,
value=choices[default_index],
)
components = self._build_chatbot(
models[default_index][0], choices[default_index]
)
model_text = components[0]
model_uid = components[-1]
chat = components[1]
def select_model(model_name):
uid = choice_to_uid[model_name]
return gr.Chatbot.update(label=model_name), uid
model_selection.change(
select_model, inputs=[model_selection], outputs=[chat, model_uid]
)
return chat, model_text
def _build_arena_with_launched(self, models: List[Tuple[str, ModelSpec]]):
chat_and_text = []
with gr.Row():
for i in range(self._gladiator_num):
with gr.Column():
chat_and_text.append(self._build_single_with_launched(models, i))
chats = [c[0] for c in chat_and_text]
texts = [c[1] for c in chat_and_text]
msg = gr.Textbox(label=self._locale("Input"))
def update_message(text_in: str):
return "", text_in, text_in
msg.submit(update_message, inputs=[msg], outputs=[msg] + texts)
gr.ClearButton(components=[msg] + chats + texts)
def build(self):
if self._use_launched_model:
models = self._api.list_models()
with gr.Blocks() as blocks:
if len(models) >= 2:
with gr.Tab(self._locale("Arena")):
self._build_arena_with_launched(models)
with gr.Tab(self._locale("Chat")):
chat, model_text = self._build_single_with_launched(models, 0)
msg = gr.Textbox(label=self._locale("Input"))
def update_message(text_in: str):
return "", text_in
msg.submit(update_message, inputs=[msg], outputs=[msg, model_text])
gr.ClearButton(components=[chat, msg, model_text])
else:
with gr.Blocks() as blocks:
with gr.Tab(self._locale("Chat")):
self._build_single()
with gr.Tab(self._locale("Arena")):
self._build_arena()
blocks.queue(concurrency_count=40)
return blocks
async def launch_xinference():
import xoscar as xo
from xinference.core.service import SupervisorActor
from xinference.core.api import AsyncSupervisorAPI
from xinference.deploy.worker import start_worker_components
pool = await xo.create_actor_pool(address="0.0.0.0", n_process=0)
supervisor_address = pool.external_address
await xo.create_actor(
SupervisorActor, address=supervisor_address, uid=SupervisorActor.uid()
)
await start_worker_components(
address=supervisor_address, supervisor_address=supervisor_address
)
api = AsyncSupervisorAPI(supervisor_address)
supported_models = ["orca", "chatglm2", "chatglm", "vicuna-v1.3"]
for model in supported_models:
await api.launch_model(str(uuid.uuid4()), model)
gradio_block = GradioApp(supervisor_address, use_launched_model=True).build()
gradio_block.launch()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
task = loop.create_task(launch_xinference())
try:
loop.run_until_complete(task)
except KeyboardInterrupt:
task.cancel()
loop.run_until_complete(task)
# avoid displaying exception-unhandled warnings
task.exception()