"""
Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"

Workers are listed in the format `model-path`@`host`@`port`.

The key mechanism behind this script is:

1. execute shell commands to launch the controller/worker/openai-api-server;

2. check the logs of the controller/worker/openai-api-server to ensure that each service has launched properly.

Note that a few non-critical `fastchat.serve` command-line options are not supported currently.
"""
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import subprocess
import re
import argparse

LOGDIR = "./logs/"

os.makedirs(LOGDIR, exist_ok=True)

parser = argparse.ArgumentParser()

parser.add_argument(
    "--model-path-address",
    default="THUDM/chatglm2-6b@localhost@20002",
    nargs="+",
    type=str,
    help="model path, host, and port, formatted as model-path@host@port",
)
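# Note: with nargs="+", argparse wraps *passed* values in a list but leaves the
# default above as a plain string; launch_all() handles both cases explicitly.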

parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
controller_args = ["controller-host", "controller-port", "dispatch-method"]
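
# The *_args name lists (controller_args, worker_args, server_args below) select
# which parsed options are forwarded to each fastchat.serve module; string_args()
# strips the "controller-"/"worker-"/"server-" prefixes from host/port keys so
# the resulting flags match what each module expects (e.g. --host, --port).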

parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)

parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.5",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu", "npu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1) |
|
parser.add_argument( |
|
"--max-gpu-memory", |
|
type=str, |
|
help="The maximum memory per gpu. Use a string like '13Gib'", |
|
) |
|
parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") |
|
parser.add_argument( |
|
"--cpu-offloading", |
|
action="store_true", |
|
help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", |
|
) |
|
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load a quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="Number of bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Group size to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional comma-separated list of display names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

worker_args = [
    "worker-host",
    "worker-port",
    "model-path",
    "revision",
    "device",
    "gpus",
    "num-gpus",
    "max-gpu-memory",
    "load-8bit",
    "cpu-offloading",
    "gptq-ckpt",
    "gptq-wbits",
    "gptq-groupsize",
    "gptq-act-order",
    "model-names",
    "limit-worker-concurrency",
    "stream-interval",
    "no-register",
    "controller-address",
]
parser.add_argument("--server-host", type=str, default="localhost", help="host name") |
|
parser.add_argument("--server-port", type=int, default=8001, help="port number") |
|
parser.add_argument( |
|
"--allow-credentials", action="store_true", help="allow credentials" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument( |
|
"--api-keys", |
|
type=lambda s: s.split(","), |
|
help="Optional list of comma separated API keys", |
|
) |
|
server_args = [ |
|
"server-host", |
|
"server-port", |
|
"allow-credentials", |
|
"api-keys", |
|
"controller-address", |
|
] |
|
|
|

args = parser.parse_args()

# Compute the controller address once and stash it under the dashed key
# "controller-address", which string_args() matches against the *_args lists.
args = argparse.Namespace(
    **vars(args),
    **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
)
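# With the defaults above, this yields controller-address="http://localhost:21001",
# which is forwarded to both the model workers and the OpenAI API server.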

if args.gpus:
    if len(args.gpus.split(",")) < args.num_gpus:
        raise ValueError(
            f"--num-gpus ({args.num_gpus}) exceeds the number of GPUs given in --gpus ({args.gpus})!"
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# Template for launching a fastchat.serve module in the background:
# {0} = module name, {1} = CLI arguments, {2} = log directory, {3} = log file name.
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
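# For illustration, the rendered controller launch command with the defaults
# above looks roughly like:
#   nohup python3 -m fastchat.serve.controller --host localhost --port 21001 \
#       --dispatch-method shortest_queue >./logs/controller.log 2>&1 &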

# Template for polling a service's log until its server reports that it is up:
# {0} = log directory, {1} = log file name, {2} = service name to echo.
base_check_sh = """while [ $(grep -c "Uvicorn running on" {0}/{1}.log) -eq 0 ]; do
    sleep 1s;
    echo "waiting for {2} to start"
done
echo '{2} running' """
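# Readiness is detected by grepping for Uvicorn's startup line: the controller,
# model workers, and the OpenAI API server all serve over Uvicorn, which prints
# "Uvicorn running on http://..." once the service is accepting connections.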


def string_args(args, args_list):
    """Serialize the options named in args_list into a CLI argument string."""
    args_str = ""
    for key, value in args._get_kwargs():
        # argparse stores "--foo-bar" as "foo_bar"; restore the dashed form.
        key = key.replace("_", "-")
        if key not in args_list:
            continue

        # Drop the service prefix from host/port keys, e.g. "worker-host" -> "host".
        key = key.split("-")[-1] if re.search("port|host", key) else key
        if not value:
            # Skip unset options and store_true flags left at False.
            continue
        elif isinstance(value, bool):
            # A true flag is emitted with no value.
            args_str += f" --{key} "
        elif isinstance(value, (list, tuple, set)):
            args_str += f" --{key} {' '.join(value)} "
        else:
            args_str += f" --{key} {value} "

    return args_str
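
# For illustration, with the defaults above string_args(args, controller_args)
# returns roughly " --host localhost --port 21001 --dispatch-method shortest_queue ".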


def launch_worker(item):
    # Derive a filesystem-safe log name from the model-path@host@port string.
    log_name = (
        item.split("/")[-1]
        .split("\\")[-1]
        .replace("-", "_")
        .replace("@", "_")
        .replace(".", "_")
    )

    args.model_path, args.worker_host, args.worker_port = item.split("@")
    print("*" * 80)
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)
    worker_sh = base_launch_sh.format(
        "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
    )
    worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)


def launch_all():
    # Launch the controller first and block until its log shows it is up.
    controller_str_args = string_args(args, controller_args)
    controller_sh = base_launch_sh.format(
        "controller", controller_str_args, LOGDIR, "controller"
    )
    controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
    subprocess.run(controller_sh, shell=True, check=True)
    subprocess.run(controller_check_sh, shell=True, check=True)

    # The default --model-path-address is a bare string; passed values arrive as a list.
    if isinstance(args.model_path_address, str):
        launch_worker(args.model_path_address)
    else:
        for idx, item in enumerate(args.model_path_address):
            print(f"loading model {idx}: {item}")
            launch_worker(item)

    # Finally, bring up the OpenAI-compatible API server in front of the controller.
    server_str_args = string_args(args, server_args)
    server_sh = base_launch_sh.format(
        "openai_api_server", server_str_args, LOGDIR, "openai_api_server"
    )
    server_check_sh = base_check_sh.format(
        LOGDIR, "openai_api_server", "openai_api_server"
    )
    subprocess.run(server_sh, shell=True, check=True)
    subprocess.run(server_check_sh, shell=True, check=True)


if __name__ == "__main__":
    launch_all()