Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- README.md +3 -9
- cli.py +139 -0
- controller.py +298 -0
- examples/1034346401.mp4 +3 -0
- examples/desert.jpg +0 -0
- examples/extreme_ironing.jpg +0 -0
- examples/sample_demo_1.mp4 +3 -0
- examples/sample_demo_3.mp4 +0 -0
- examples/sample_demo_9.mp4 +0 -0
- examples/waterview.jpg +0 -0
- gradio_web_server.py +499 -0
- gradio_web_server_adhoc.py +318 -0
- model_worker.py +397 -0
- register_worker.py +26 -0
- sglang_worker.py +244 -0
- test_message.py +62 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/1034346401.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/sample_demo_1.mp4 filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: gray
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: serve
|
3 |
+
app_file: gradio_web_server_adhoc.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 3.50.0
|
|
|
|
|
6 |
---
|
|
|
|
cli.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from videollama2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, NUM_FRAMES
|
5 |
+
from videollama2.conversation import conv_templates, SeparatorStyle
|
6 |
+
from videollama2.model.builder import load_pretrained_model
|
7 |
+
from videollama2.utils import disable_torch_init
|
8 |
+
from videollama2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, tokenizer_MMODAL_token
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
|
13 |
+
import requests
|
14 |
+
from io import BytesIO
|
15 |
+
from transformers import TextStreamer
|
16 |
+
|
17 |
+
|
18 |
+
def load_image(image_file):
|
19 |
+
if image_file.startswith('http://') or image_file.startswith('https://'):
|
20 |
+
response = requests.get(image_file)
|
21 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
22 |
+
else:
|
23 |
+
image = Image.open(image_file).convert('RGB')
|
24 |
+
return image
|
25 |
+
|
26 |
+
def load_video(video_file):
|
27 |
+
decord_vr = VideoReader(uri=video_file, ctx=cpu(0))
|
28 |
+
duration = len(decord_vr)
|
29 |
+
frame_id_list = np.linspace(0, duration-1, NUM_FRAMES, dtype=int)
|
30 |
+
video = decord_vr.get_batch(frame_id_list)
|
31 |
+
return video
|
32 |
+
|
33 |
+
def load_image_or_video(image_or_video_file):
|
34 |
+
if file_path.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
|
35 |
+
return load_image(image_file=image_or_video_file)
|
36 |
+
elif file_path.endswith(('.mp4', '.avi', '.mov')):
|
37 |
+
return load_video(video_file=image_or_video_file)
|
38 |
+
else:
|
39 |
+
raise Exception(f"File type of {image_or_video_file} not supported!!!")
|
40 |
+
|
41 |
+
|
42 |
+
def main(args):
|
43 |
+
# Model
|
44 |
+
disable_torch_init()
|
45 |
+
|
46 |
+
model_name = get_model_name_from_path(args.model_path)
|
47 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
|
48 |
+
|
49 |
+
# if "llama-2" in model_name.lower():
|
50 |
+
# conv_mode = "llava_llama2"
|
51 |
+
# elif "mistral" in model_name.lower():
|
52 |
+
# conv_mode = "mistral"
|
53 |
+
# elif "v1.6-34b" in model_name.lower():
|
54 |
+
# conv_mode = "chatml_direct"
|
55 |
+
# elif "v1" in model_name.lower():
|
56 |
+
# conv_mode = "llava_v1"
|
57 |
+
# else:
|
58 |
+
# conv_mode = "llava_v0"
|
59 |
+
conv_mode = "llava_v1" # fix conversation mode for now
|
60 |
+
|
61 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
62 |
+
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
|
63 |
+
else:
|
64 |
+
args.conv_mode = conv_mode
|
65 |
+
|
66 |
+
conv = conv_templates[args.conv_mode].copy()
|
67 |
+
roles = conv.roles
|
68 |
+
|
69 |
+
image = load_image(args.image_file)
|
70 |
+
image_size = image.size
|
71 |
+
# Similar operation in model_worker.py
|
72 |
+
image_tensor = process_images([image], image_processor, model.config)
|
73 |
+
if type(image_tensor) is list:
|
74 |
+
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
|
75 |
+
else:
|
76 |
+
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
77 |
+
|
78 |
+
while True:
|
79 |
+
try:
|
80 |
+
inp = input(f"{roles[0]}: ")
|
81 |
+
except EOFError:
|
82 |
+
inp = ""
|
83 |
+
if not inp:
|
84 |
+
print("exit...")
|
85 |
+
break
|
86 |
+
|
87 |
+
print(f"{roles[1]}: ", end="")
|
88 |
+
|
89 |
+
if image is not None:
|
90 |
+
# first message
|
91 |
+
if model.config.mm_use_im_start_end:
|
92 |
+
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
|
93 |
+
else:
|
94 |
+
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
|
95 |
+
conv.append_message(conv.roles[0], inp)
|
96 |
+
image = None
|
97 |
+
else:
|
98 |
+
# later messages
|
99 |
+
conv.append_message(conv.roles[0], inp)
|
100 |
+
conv.append_message(conv.roles[1], None)
|
101 |
+
prompt = conv.get_prompt()
|
102 |
+
|
103 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
104 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
105 |
+
keywords = [stop_str]
|
106 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
107 |
+
|
108 |
+
with torch.inference_mode():
|
109 |
+
output_ids = model.generate(
|
110 |
+
input_ids,
|
111 |
+
images=image_tensor,
|
112 |
+
image_sizes=[image_size],
|
113 |
+
do_sample=True if args.temperature > 0 else False,
|
114 |
+
temperature=args.temperature,
|
115 |
+
max_new_tokens=args.max_new_tokens,
|
116 |
+
streamer=streamer,
|
117 |
+
use_cache=True)
|
118 |
+
|
119 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
120 |
+
conv.messages[-1][-1] = outputs
|
121 |
+
|
122 |
+
if args.debug:
|
123 |
+
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
parser = argparse.ArgumentParser()
|
128 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
129 |
+
parser.add_argument("--model-base", type=str, default=None)
|
130 |
+
parser.add_argument("--image-file", type=str, required=True)
|
131 |
+
parser.add_argument("--device", type=str, default="cuda")
|
132 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
133 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
134 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
135 |
+
parser.add_argument("--load-8bit", action="store_true")
|
136 |
+
parser.add_argument("--load-4bit", action="store_true")
|
137 |
+
parser.add_argument("--debug", action="store_true")
|
138 |
+
args = parser.parse_args()
|
139 |
+
main(args)
|
controller.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import time
|
12 |
+
from typing import List, Union
|
13 |
+
import threading
|
14 |
+
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.responses import StreamingResponse
|
17 |
+
import numpy as np
|
18 |
+
import requests
|
19 |
+
import uvicorn
|
20 |
+
|
21 |
+
from videollama2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
+
from videollama2.utils import build_logger, server_error_msg
|
23 |
+
|
24 |
+
|
25 |
+
logger = build_logger("controller", "controller.log")
|
26 |
+
|
27 |
+
|
28 |
+
class DispatchMethod(Enum):
|
29 |
+
LOTTERY = auto()
|
30 |
+
SHORTEST_QUEUE = auto()
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_str(cls, name):
|
34 |
+
if name == "lottery":
|
35 |
+
return cls.LOTTERY
|
36 |
+
elif name == "shortest_queue":
|
37 |
+
return cls.SHORTEST_QUEUE
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Invalid dispatch method")
|
40 |
+
|
41 |
+
|
42 |
+
@dataclasses.dataclass
|
43 |
+
class WorkerInfo:
|
44 |
+
model_names: List[str]
|
45 |
+
speed: int
|
46 |
+
queue_length: int
|
47 |
+
check_heart_beat: bool
|
48 |
+
last_heart_beat: str
|
49 |
+
|
50 |
+
|
51 |
+
def heart_beat_controller(controller):
|
52 |
+
while True:
|
53 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
54 |
+
controller.remove_stable_workers_by_expiration()
|
55 |
+
|
56 |
+
|
57 |
+
class Controller:
|
58 |
+
def __init__(self, dispatch_method: str):
|
59 |
+
# Dict[str -> WorkerInfo]
|
60 |
+
self.worker_info = {}
|
61 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
62 |
+
|
63 |
+
self.heart_beat_thread = threading.Thread(
|
64 |
+
target=heart_beat_controller, args=(self,), daemon=True)
|
65 |
+
self.heart_beat_thread.start()
|
66 |
+
|
67 |
+
logger.info("Init controller")
|
68 |
+
|
69 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
70 |
+
worker_status: dict):
|
71 |
+
if worker_name not in self.worker_info:
|
72 |
+
logger.info(f"Register a new worker: {worker_name}")
|
73 |
+
else:
|
74 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
75 |
+
|
76 |
+
if not worker_status:
|
77 |
+
worker_status = self.get_worker_status(worker_name)
|
78 |
+
if not worker_status:
|
79 |
+
return False
|
80 |
+
|
81 |
+
self.worker_info[worker_name] = WorkerInfo(
|
82 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
83 |
+
check_heart_beat, time.time())
|
84 |
+
|
85 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
86 |
+
return True
|
87 |
+
|
88 |
+
def get_worker_status(self, worker_name: str):
|
89 |
+
try:
|
90 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
91 |
+
except requests.exceptions.RequestException as e:
|
92 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
93 |
+
return None
|
94 |
+
|
95 |
+
if r.status_code != 200:
|
96 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
97 |
+
return None
|
98 |
+
|
99 |
+
return r.json()
|
100 |
+
|
101 |
+
def remove_worker(self, worker_name: str):
|
102 |
+
del self.worker_info[worker_name]
|
103 |
+
|
104 |
+
def refresh_all_workers(self):
|
105 |
+
old_info = dict(self.worker_info)
|
106 |
+
self.worker_info = {}
|
107 |
+
|
108 |
+
for w_name, w_info in old_info.items():
|
109 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
110 |
+
logger.info(f"Remove stale worker: {w_name}")
|
111 |
+
|
112 |
+
def list_models(self):
|
113 |
+
model_names = set()
|
114 |
+
|
115 |
+
for w_name, w_info in self.worker_info.items():
|
116 |
+
model_names.update(w_info.model_names)
|
117 |
+
|
118 |
+
return list(model_names)
|
119 |
+
|
120 |
+
def get_worker_address(self, model_name: str):
|
121 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
122 |
+
worker_names = []
|
123 |
+
worker_speeds = []
|
124 |
+
for w_name, w_info in self.worker_info.items():
|
125 |
+
if model_name in w_info.model_names:
|
126 |
+
worker_names.append(w_name)
|
127 |
+
worker_speeds.append(w_info.speed)
|
128 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
129 |
+
norm = np.sum(worker_speeds)
|
130 |
+
if norm < 1e-4:
|
131 |
+
return ""
|
132 |
+
worker_speeds = worker_speeds / norm
|
133 |
+
if True: # Directly return address
|
134 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
135 |
+
p=worker_speeds)
|
136 |
+
worker_name = worker_names[pt]
|
137 |
+
return worker_name
|
138 |
+
|
139 |
+
# Check status before returning
|
140 |
+
while True:
|
141 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
142 |
+
p=worker_speeds)
|
143 |
+
worker_name = worker_names[pt]
|
144 |
+
|
145 |
+
if self.get_worker_status(worker_name):
|
146 |
+
break
|
147 |
+
else:
|
148 |
+
self.remove_worker(worker_name)
|
149 |
+
worker_speeds[pt] = 0
|
150 |
+
norm = np.sum(worker_speeds)
|
151 |
+
if norm < 1e-4:
|
152 |
+
return ""
|
153 |
+
worker_speeds = worker_speeds / norm
|
154 |
+
continue
|
155 |
+
return worker_name
|
156 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
157 |
+
worker_names = []
|
158 |
+
worker_qlen = []
|
159 |
+
for w_name, w_info in self.worker_info.items():
|
160 |
+
if model_name in w_info.model_names:
|
161 |
+
worker_names.append(w_name)
|
162 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
163 |
+
if len(worker_names) == 0:
|
164 |
+
return ""
|
165 |
+
min_index = np.argmin(worker_qlen)
|
166 |
+
w_name = worker_names[min_index]
|
167 |
+
self.worker_info[w_name].queue_length += 1
|
168 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
169 |
+
return w_name
|
170 |
+
else:
|
171 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
172 |
+
|
173 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
174 |
+
if worker_name not in self.worker_info:
|
175 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
176 |
+
return False
|
177 |
+
|
178 |
+
self.worker_info[worker_name].queue_length = queue_length
|
179 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
180 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
181 |
+
return True
|
182 |
+
|
183 |
+
def remove_stable_workers_by_expiration(self):
|
184 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
185 |
+
to_delete = []
|
186 |
+
for worker_name, w_info in self.worker_info.items():
|
187 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
188 |
+
to_delete.append(worker_name)
|
189 |
+
|
190 |
+
for worker_name in to_delete:
|
191 |
+
self.remove_worker(worker_name)
|
192 |
+
|
193 |
+
def worker_api_generate_stream(self, params):
|
194 |
+
worker_addr = self.get_worker_address(params["model"])
|
195 |
+
if not worker_addr:
|
196 |
+
logger.info(f"no worker: {params['model']}")
|
197 |
+
ret = {
|
198 |
+
"text": server_error_msg,
|
199 |
+
"error_code": 2,
|
200 |
+
}
|
201 |
+
yield json.dumps(ret).encode() + b"\0"
|
202 |
+
|
203 |
+
try:
|
204 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
205 |
+
json=params, stream=True, timeout=5)
|
206 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
207 |
+
if chunk:
|
208 |
+
yield chunk + b"\0"
|
209 |
+
except requests.exceptions.RequestException as e:
|
210 |
+
logger.info(f"worker timeout: {worker_addr}")
|
211 |
+
ret = {
|
212 |
+
"text": server_error_msg,
|
213 |
+
"error_code": 3,
|
214 |
+
}
|
215 |
+
yield json.dumps(ret).encode() + b"\0"
|
216 |
+
|
217 |
+
|
218 |
+
# Let the controller act as a worker to achieve hierarchical
|
219 |
+
# management. This can be used to connect isolated sub networks.
|
220 |
+
def worker_api_get_status(self):
|
221 |
+
model_names = set()
|
222 |
+
speed = 0
|
223 |
+
queue_length = 0
|
224 |
+
|
225 |
+
for w_name in self.worker_info:
|
226 |
+
worker_status = self.get_worker_status(w_name)
|
227 |
+
if worker_status is not None:
|
228 |
+
model_names.update(worker_status["model_names"])
|
229 |
+
speed += worker_status["speed"]
|
230 |
+
queue_length += worker_status["queue_length"]
|
231 |
+
|
232 |
+
return {
|
233 |
+
"model_names": list(model_names),
|
234 |
+
"speed": speed,
|
235 |
+
"queue_length": queue_length,
|
236 |
+
}
|
237 |
+
|
238 |
+
|
239 |
+
app = FastAPI()
|
240 |
+
|
241 |
+
|
242 |
+
@app.post("/register_worker")
|
243 |
+
async def register_worker(request: Request):
|
244 |
+
data = await request.json()
|
245 |
+
controller.register_worker(
|
246 |
+
data["worker_name"], data["check_heart_beat"],
|
247 |
+
data.get("worker_status", None))
|
248 |
+
|
249 |
+
|
250 |
+
@app.post("/refresh_all_workers")
|
251 |
+
async def refresh_all_workers():
|
252 |
+
models = controller.refresh_all_workers()
|
253 |
+
|
254 |
+
|
255 |
+
@app.post("/list_models")
|
256 |
+
async def list_models():
|
257 |
+
models = controller.list_models()
|
258 |
+
return {"models": models}
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/get_worker_address")
|
262 |
+
async def get_worker_address(request: Request):
|
263 |
+
data = await request.json()
|
264 |
+
addr = controller.get_worker_address(data["model"])
|
265 |
+
return {"address": addr}
|
266 |
+
|
267 |
+
|
268 |
+
@app.post("/receive_heart_beat")
|
269 |
+
async def receive_heart_beat(request: Request):
|
270 |
+
data = await request.json()
|
271 |
+
exist = controller.receive_heart_beat(
|
272 |
+
data["worker_name"], data["queue_length"])
|
273 |
+
return {"exist": exist}
|
274 |
+
|
275 |
+
|
276 |
+
@app.post("/worker_generate_stream")
|
277 |
+
async def worker_api_generate_stream(request: Request):
|
278 |
+
params = await request.json()
|
279 |
+
generator = controller.worker_api_generate_stream(params)
|
280 |
+
return StreamingResponse(generator)
|
281 |
+
|
282 |
+
|
283 |
+
@app.post("/worker_get_status")
|
284 |
+
async def worker_api_get_status(request: Request):
|
285 |
+
return controller.worker_api_get_status()
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
parser = argparse.ArgumentParser()
|
290 |
+
parser.add_argument("--host", type=str, default="localhost")
|
291 |
+
parser.add_argument("--port", type=int, default=21001)
|
292 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
293 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
294 |
+
args = parser.parse_args()
|
295 |
+
logger.info(f"args: {args}")
|
296 |
+
|
297 |
+
controller = Controller(args.dispatch_method)
|
298 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
examples/1034346401.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:08b62a634fe49edc0a19fc53f6ea5cfb345d9b2a6a7047811344c16832dc42b2
|
3 |
+
size 1678095
|
examples/desert.jpg
ADDED
examples/extreme_ironing.jpg
ADDED
examples/sample_demo_1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc6562a172eb9cb3c760a3c9992349c1faa2c793c112b7b9e50bd5cb17c2164d
|
3 |
+
size 1549315
|
examples/sample_demo_3.mp4
ADDED
Binary file (464 kB). View file
|
|
examples/sample_demo_9.mp4
ADDED
Binary file (632 kB). View file
|
|
examples/waterview.jpg
ADDED
gradio_web_server.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
import hashlib
|
5 |
+
import requests
|
6 |
+
import argparse
|
7 |
+
import datetime
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import gradio as gr
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
|
13 |
+
from videollama2.constants import LOGDIR, NUM_FRAMES
|
14 |
+
from videollama2.conversation import (default_conversation, conv_templates,SeparatorStyle)
|
15 |
+
from videollama2.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
|
16 |
+
|
17 |
+
|
18 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
19 |
+
|
20 |
+
headers = {"User-Agent": "Videollama2 Client"}
|
21 |
+
|
22 |
+
no_change_btn = gr.Button.update()
|
23 |
+
enable_btn = gr.Button.update(interactive=True)
|
24 |
+
disable_btn = gr.Button.update(interactive=False)
|
25 |
+
|
26 |
+
priority = {
|
27 |
+
"vicuna-13b": "aaaaaaa",
|
28 |
+
"koala-13b": "aaaaaab",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
def get_conv_log_filename():
|
33 |
+
t = datetime.datetime.now()
|
34 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
35 |
+
return name
|
36 |
+
|
37 |
+
|
38 |
+
def get_model_list():
|
39 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
40 |
+
assert ret.status_code == 200
|
41 |
+
ret = requests.post(args.controller_url + "/list_models")
|
42 |
+
models = ret.json()["models"]
|
43 |
+
models.sort(key=lambda x: priority.get(x, x))
|
44 |
+
logger.info(f"Models: {models}")
|
45 |
+
return models
|
46 |
+
|
47 |
+
|
48 |
+
get_window_url_params = """
|
49 |
+
function() {
|
50 |
+
const params = new URLSearchParams(window.location.search);
|
51 |
+
url_params = Object.fromEntries(params);
|
52 |
+
console.log(url_params);
|
53 |
+
return url_params;
|
54 |
+
}
|
55 |
+
"""
|
56 |
+
|
57 |
+
|
58 |
+
def load_demo(url_params, request: gr.Request):
|
59 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
60 |
+
|
61 |
+
dropdown_update = gr.Dropdown.update(visible=True)
|
62 |
+
if "model" in url_params:
|
63 |
+
model = url_params["model"]
|
64 |
+
if model in models:
|
65 |
+
dropdown_update = gr.Dropdown.update(
|
66 |
+
value=model, visible=True)
|
67 |
+
|
68 |
+
state = default_conversation.copy()
|
69 |
+
return state, dropdown_update
|
70 |
+
|
71 |
+
|
72 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
73 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
74 |
+
models = get_model_list()
|
75 |
+
state = default_conversation.copy()
|
76 |
+
dropdown_update = gr.Dropdown.update(
|
77 |
+
choices=models,
|
78 |
+
value=models[0] if len(models) > 0 else ""
|
79 |
+
)
|
80 |
+
return state, dropdown_update
|
81 |
+
|
82 |
+
|
83 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
84 |
+
with open(get_conv_log_filename(), "a") as fout:
|
85 |
+
data = {
|
86 |
+
"tstamp": round(time.time(), 4),
|
87 |
+
"type": vote_type,
|
88 |
+
"model": model_selector,
|
89 |
+
"state": state.dict(),
|
90 |
+
"ip": request.client.host,
|
91 |
+
}
|
92 |
+
fout.write(json.dumps(data) + "\n")
|
93 |
+
|
94 |
+
|
95 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
96 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
97 |
+
vote_last_response(state, "upvote", model_selector, request)
|
98 |
+
return ("",) + (disable_btn,) * 3
|
99 |
+
|
100 |
+
|
101 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
102 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
103 |
+
vote_last_response(state, "downvote", model_selector, request)
|
104 |
+
return ("",) + (disable_btn,) * 3
|
105 |
+
|
106 |
+
|
107 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
108 |
+
logger.info(f"flag. ip: {request.client.host}")
|
109 |
+
vote_last_response(state, "flag", model_selector, request)
|
110 |
+
return ("",) + (disable_btn,) * 3
|
111 |
+
|
112 |
+
|
113 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
114 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
115 |
+
state.messages[-1][-1] = None
|
116 |
+
prev_human_msg = state.messages[-2]
|
117 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
118 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
119 |
+
state.skip_next = False
|
120 |
+
# (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
|
121 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
122 |
+
|
123 |
+
|
124 |
+
def clear_history(request: gr.Request):
|
125 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
126 |
+
state = default_conversation.copy()
|
127 |
+
# (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
|
128 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
129 |
+
|
130 |
+
|
131 |
+
def add_text_ori(state, text, image, video, image_process_mode, request: gr.Request):
|
132 |
+
# note: imagebox itself is PIL object while videobox is filepath
|
133 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
134 |
+
if len(text) <= 0 and image is None:
|
135 |
+
state.skip_next = True
|
136 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
137 |
+
if args.moderate:
|
138 |
+
flagged = violates_moderation(text)
|
139 |
+
if flagged:
|
140 |
+
state.skip_next = True
|
141 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
|
142 |
+
no_change_btn,) * 5
|
143 |
+
assert image is None or video is None, "Please don't feed image and video inputs at the same time!!!"
|
144 |
+
text = text[:1536] # Hard cut-off
|
145 |
+
if image is not None:
|
146 |
+
# here image is the PIL object itself
|
147 |
+
text = text[:1200] # Hard cut-off for images
|
148 |
+
if '<image>' not in text:
|
149 |
+
# text = '<Image><image></Image>' + text
|
150 |
+
text = text + '\n<image>'
|
151 |
+
text = (text, image, image_process_mode)
|
152 |
+
if len(state.get_images(return_pil=True)) > 0:
|
153 |
+
state = default_conversation.copy()
|
154 |
+
state.modality = "image"
|
155 |
+
if video is not None:
|
156 |
+
print("Video box:", video)
|
157 |
+
# here video is the file path of video
|
158 |
+
text = text[:1200] # Hard cut-off for images
|
159 |
+
if '<video>' not in text:
|
160 |
+
# text = '<Image><image></Image>' + text
|
161 |
+
text = text + '\n<video>'
|
162 |
+
text = (text, video, image_process_mode)
|
163 |
+
if len(state.get_videos(return_pil=True)) > 0:
|
164 |
+
state = default_conversation.copy()
|
165 |
+
state.modality = "video"
|
166 |
+
print("Set modality as video...")
|
167 |
+
state.append_message(state.roles[0], text)
|
168 |
+
state.append_message(state.roles[1], None)
|
169 |
+
state.skip_next = False
|
170 |
+
# (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
|
171 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
172 |
+
|
173 |
+
|
174 |
+
def add_text(state, text, image, video, image_process_mode, request: gr.Request):
|
175 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
176 |
+
|
177 |
+
# if input is new video or image ,reset the state
|
178 |
+
if image is not None or video is not None:
|
179 |
+
state = default_conversation.copy()
|
180 |
+
|
181 |
+
if len(text) <= 0 and image is None and video is None:
|
182 |
+
state.skip_next = True
|
183 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
|
184 |
+
|
185 |
+
if args.moderate:
|
186 |
+
flagged = violates_moderation(text)
|
187 |
+
if flagged:
|
188 |
+
state.skip_next = True
|
189 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
|
190 |
+
|
191 |
+
# process the input video
|
192 |
+
if video is not None:
|
193 |
+
text = text[:1200] #
|
194 |
+
if '<video>' not in text:
|
195 |
+
text = text + '\n<video>'
|
196 |
+
text = (text, video, image_process_mode)
|
197 |
+
state.modality = "video"
|
198 |
+
# process the input image
|
199 |
+
elif image is not None:
|
200 |
+
text = text[:1200] #
|
201 |
+
if '<image>' not in text:
|
202 |
+
text = text + '\n<image>'
|
203 |
+
text = (text, image, image_process_mode)
|
204 |
+
state.modality = "image"
|
205 |
+
elif state.modality == "image" and len(text)>0:
|
206 |
+
state.modality = "image_text"
|
207 |
+
text = text[:1536] # Hard cut-off
|
208 |
+
elif state.modality == "video" and len(text)>0:
|
209 |
+
state.modality = "video_text"
|
210 |
+
text = text[:1536] # Hard cut-off
|
211 |
+
|
212 |
+
state.append_message(state.roles[0], text)
|
213 |
+
state.append_message(state.roles[1], None)
|
214 |
+
state.skip_next = False
|
215 |
+
|
216 |
+
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
|
217 |
+
|
218 |
+
|
219 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
|
220 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
221 |
+
start_tstamp = time.time()
|
222 |
+
model_name = model_selector
|
223 |
+
|
224 |
+
if state.skip_next:
|
225 |
+
# This generate call is skipped due to invalid inputs
|
226 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
227 |
+
return
|
228 |
+
|
229 |
+
if len(state.messages) == state.offset + 2:
|
230 |
+
# First round of conversation
|
231 |
+
if "llava" in model_name.lower():
|
232 |
+
if 'llama-2' in model_name.lower():
|
233 |
+
template_name = "llava_llama2"
|
234 |
+
elif "v1" in model_name.lower():
|
235 |
+
if 'mmtag' in model_name.lower():
|
236 |
+
template_name = "v1_mmtag"
|
237 |
+
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
238 |
+
template_name = "v1_mmtag"
|
239 |
+
else:
|
240 |
+
template_name = "llava_v1"
|
241 |
+
else:
|
242 |
+
if 'mmtag' in model_name.lower():
|
243 |
+
template_name = "v0_mmtag"
|
244 |
+
elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
|
245 |
+
template_name = "v0_mmtag"
|
246 |
+
else:
|
247 |
+
template_name = "llava_v0"
|
248 |
+
elif "llama-2" in model_name:
|
249 |
+
template_name = "llama2"
|
250 |
+
else:
|
251 |
+
template_name = "vicuna_v1"
|
252 |
+
template_name = "llava_v1"
|
253 |
+
new_state = conv_templates[template_name].copy()
|
254 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
255 |
+
new_state.append_message(new_state.roles[1], None)
|
256 |
+
new_state.modality = state.modality
|
257 |
+
state = new_state
|
258 |
+
|
259 |
+
# Query worker address
|
260 |
+
controller_url = args.controller_url
|
261 |
+
ret = requests.post(controller_url + "/get_worker_address",
|
262 |
+
json={"model": model_name})
|
263 |
+
worker_addr = ret.json()["address"]
|
264 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
265 |
+
|
266 |
+
# No available worker
|
267 |
+
if worker_addr == "":
|
268 |
+
state.messages[-1][-1] = server_error_msg
|
269 |
+
yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
270 |
+
return
|
271 |
+
|
272 |
+
# Construct prompt
|
273 |
+
prompt = state.get_prompt()
|
274 |
+
if state.modality == "image" or state.modality == "image_text":
|
275 |
+
all_images = state.get_images(return_pil=True) # return PIL.Image object
|
276 |
+
elif state.modality == "video" or state.modality == "video_text":
|
277 |
+
all_images = state.get_videos(return_pil=True) # return video frames where each frame is a PIL.Image object
|
278 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
279 |
+
for idx, (image, hash) in enumerate(zip(all_images, all_image_hash)):
|
280 |
+
t = datetime.datetime.now()
|
281 |
+
if state.modality == "image" or state.modality == "image_text":
|
282 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
283 |
+
elif state.modality == "video" or state.modality == "video_text":
|
284 |
+
filename = os.path.join(LOGDIR, "serve_videos", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}_{idx}.jpg")
|
285 |
+
if not os.path.isfile(filename):
|
286 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
287 |
+
image.save(filename)
|
288 |
+
|
289 |
+
# Make requests
|
290 |
+
pload = {
|
291 |
+
"model": model_name,
|
292 |
+
"prompt": prompt,
|
293 |
+
"temperature": float(temperature),
|
294 |
+
"top_p": float(top_p),
|
295 |
+
"max_new_tokens": min(int(max_new_tokens), 1536),
|
296 |
+
"stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE] else state.sep2,
|
297 |
+
#"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
298 |
+
"images": f'List of {len(all_image_hash)} images: {all_image_hash}',
|
299 |
+
}
|
300 |
+
logger.info(f"==== request ====\n{pload}")
|
301 |
+
|
302 |
+
if state.modality == "image" or state.modality == "image_text":
|
303 |
+
pload['images'] = state.get_images()
|
304 |
+
elif state.modality == "video" or state.modality == "video_text":
|
305 |
+
pload['images'] = state.get_videos()
|
306 |
+
|
307 |
+
state.messages[-1][-1] = "▌"
|
308 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
309 |
+
|
310 |
+
try:
|
311 |
+
# Stream output
|
312 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
313 |
+
headers=headers, json=pload, stream=True, timeout=10)
|
314 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
315 |
+
if chunk:
|
316 |
+
data = json.loads(chunk.decode())
|
317 |
+
if data["error_code"] == 0:
|
318 |
+
output = data["text"][len(prompt):].strip()
|
319 |
+
state.messages[-1][-1] = output + "▌"
|
320 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
321 |
+
else:
|
322 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
323 |
+
state.messages[-1][-1] = output
|
324 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
325 |
+
return
|
326 |
+
time.sleep(0.03)
|
327 |
+
except requests.exceptions.RequestException as e:
|
328 |
+
state.messages[-1][-1] = server_error_msg
|
329 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
330 |
+
return
|
331 |
+
|
332 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
333 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
334 |
+
|
335 |
+
finish_tstamp = time.time()
|
336 |
+
logger.info(f"{output}")
|
337 |
+
|
338 |
+
with open(get_conv_log_filename(), "a") as fout:
|
339 |
+
data = {
|
340 |
+
"tstamp": round(finish_tstamp, 4),
|
341 |
+
"type": "chat",
|
342 |
+
"model": model_name,
|
343 |
+
"start": round(start_tstamp, 4),
|
344 |
+
"finish": round(start_tstamp, 4),
|
345 |
+
#"state": state.dict(),
|
346 |
+
"images": all_image_hash,
|
347 |
+
"ip": request.client.host,
|
348 |
+
}
|
349 |
+
fout.write(json.dumps(data) + "\n")
|
350 |
+
|
351 |
+
title_markdown = ("""
|
352 |
+
# The publicl release of VideoLLaMA2
|
353 |
+
""")
|
354 |
+
|
355 |
+
tos_markdown = ("""
|
356 |
+
### Terms of use
|
357 |
+
By using this service, users are required to agree to the following terms:
|
358 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
359 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
360 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
361 |
+
""")
|
362 |
+
|
363 |
+
|
364 |
+
learn_more_markdown = ("""
|
365 |
+
### License
|
366 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
367 |
+
""")
|
368 |
+
|
369 |
+
block_css = """
|
370 |
+
|
371 |
+
#buttons button {
|
372 |
+
min-width: min(120px,100%);
|
373 |
+
}
|
374 |
+
|
375 |
+
"""
|
376 |
+
|
377 |
+
def build_demo(embed_mode):
|
378 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
379 |
+
with gr.Blocks(title="Video-Llama", theme=gr.themes.Default(), css=block_css) as demo:
|
380 |
+
state = gr.State()
|
381 |
+
|
382 |
+
if not embed_mode:
|
383 |
+
gr.Markdown(title_markdown)
|
384 |
+
|
385 |
+
with gr.Row():
|
386 |
+
with gr.Column(scale=3):
|
387 |
+
with gr.Row(elem_id="model_selector_row"):
|
388 |
+
model_selector = gr.Dropdown(
|
389 |
+
choices=models,
|
390 |
+
value=models[0] if len(models) > 0 else "",
|
391 |
+
interactive=True,
|
392 |
+
show_label=False,
|
393 |
+
container=False)
|
394 |
+
|
395 |
+
imagebox = gr.Image(type="pil")
|
396 |
+
videobox = gr.Video()
|
397 |
+
image_process_mode = gr.Radio(
|
398 |
+
["Crop", "Resize", "Pad", "Default"],
|
399 |
+
value="Default",
|
400 |
+
label="Preprocess for non-square image", visible=False)
|
401 |
+
|
402 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
403 |
+
gr.Examples(examples=[
|
404 |
+
[f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
|
405 |
+
[f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
|
406 |
+
[f"{cur_dir}/examples/desert.jpg", "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?"],
|
407 |
+
], inputs=[imagebox, textbox], label="Image examples")
|
408 |
+
|
409 |
+
# video example inputs
|
410 |
+
gr.Examples(examples=[
|
411 |
+
[f"{cur_dir}/examples/sample_demo_1.mp4", "Why is this video funny?"],
|
412 |
+
[f"{cur_dir}/examples/sample_demo_3.mp4", "Can you identify any safety hazards in this video?"],
|
413 |
+
[f"{cur_dir}/examples/1034346401.mp4", "What is this young woman doing?"]
|
414 |
+
], inputs=[videobox, textbox], label="Video examples")
|
415 |
+
#[f"{cur_dir}/examples/sample_demo_9.mp4", "Describe the video in detail and please do not generate repetitive content."]
|
416 |
+
|
417 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
418 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
|
419 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
420 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
421 |
+
|
422 |
+
with gr.Column(scale=8):
|
423 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Videollama2 Chatbot", height=550)
|
424 |
+
with gr.Row():
|
425 |
+
with gr.Column(scale=8):
|
426 |
+
textbox.render()
|
427 |
+
with gr.Column(scale=1, min_width=50):
|
428 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
429 |
+
with gr.Row(elem_id="buttons") as button_row:
|
430 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
431 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
432 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
433 |
+
#stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
434 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
435 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
436 |
+
|
437 |
+
if not embed_mode:
|
438 |
+
gr.Markdown(tos_markdown)
|
439 |
+
gr.Markdown(learn_more_markdown)
|
440 |
+
url_params = gr.JSON(visible=False)
|
441 |
+
|
442 |
+
# Register listeners
|
443 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
444 |
+
upvote_btn.click(upvote_last_response,
|
445 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
446 |
+
downvote_btn.click(downvote_last_response,
|
447 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
448 |
+
flag_btn.click(flag_last_response,
|
449 |
+
[state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
|
450 |
+
regenerate_btn.click(regenerate, [state, image_process_mode],
|
451 |
+
[state, chatbot, textbox, imagebox, videobox] + btn_list).then(
|
452 |
+
http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
453 |
+
[state, chatbot] + btn_list)
|
454 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, videobox] + btn_list)
|
455 |
+
|
456 |
+
textbox.submit(add_text, [state, textbox, imagebox, videobox, image_process_mode], [state, chatbot, textbox, imagebox, videobox] + btn_list
|
457 |
+
).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
458 |
+
[state, chatbot] + btn_list)
|
459 |
+
submit_btn.click(add_text, [state, textbox, imagebox, videobox, image_process_mode], [state, chatbot, textbox, imagebox, videobox] + btn_list
|
460 |
+
).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
|
461 |
+
[state, chatbot] + btn_list)
|
462 |
+
|
463 |
+
if args.model_list_mode == "once":
|
464 |
+
demo.load(load_demo, [url_params], [state, model_selector],
|
465 |
+
_js=get_window_url_params)
|
466 |
+
elif args.model_list_mode == "reload":
|
467 |
+
demo.load(load_demo_refresh_model_list, None, [state, model_selector])
|
468 |
+
else:
|
469 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
470 |
+
|
471 |
+
return demo
|
472 |
+
|
473 |
+
|
474 |
+
if __name__ == "__main__":
|
475 |
+
parser = argparse.ArgumentParser()
|
476 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
477 |
+
parser.add_argument("--port", type=int)
|
478 |
+
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
479 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
480 |
+
parser.add_argument("--model-list-mode", type=str, default="once",
|
481 |
+
choices=["once", "reload"])
|
482 |
+
parser.add_argument("--share", action="store_true")
|
483 |
+
parser.add_argument("--moderate", action="store_true")
|
484 |
+
parser.add_argument("--embed", action="store_true")
|
485 |
+
args = parser.parse_args()
|
486 |
+
logger.info(f"args: {args}")
|
487 |
+
|
488 |
+
models = get_model_list()
|
489 |
+
|
490 |
+
logger.info(args)
|
491 |
+
demo = build_demo(args.embed)
|
492 |
+
demo.queue(
|
493 |
+
concurrency_count=args.concurrency_count,
|
494 |
+
api_open=False
|
495 |
+
).launch(
|
496 |
+
server_name=args.host,
|
497 |
+
server_port=args.port,
|
498 |
+
share=args.share
|
499 |
+
)
|
gradio_web_server_adhoc.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
import sys
|
10 |
+
sys.path.append('./')
|
11 |
+
from videollama2 import model_init, mm_infer
|
12 |
+
from videollama2.utils import disable_torch_init
|
13 |
+
|
14 |
+
|
15 |
+
title_markdown = ("""
|
16 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
17 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
18 |
+
<img src="https://s2.loli.net/2024/06/03/D3NeXHWy5az9tmT.png" alt="VideoLLaMA 2 🔥🚀🔥" style="max-width: 120px; height: auto;">
|
19 |
+
</a>
|
20 |
+
<div>
|
21 |
+
<h1 >VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs</h1>
|
22 |
+
<h5 style="margin: 0;">If this demo please you, please give us a star ⭐ on Github or 💖 on this space.</h5>
|
23 |
+
</div>
|
24 |
+
</div>
|
25 |
+
|
26 |
+
|
27 |
+
<div align="center">
|
28 |
+
<div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
|
29 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2"><img src='https://img.shields.io/badge/Github-VideoLLaMA2-9C276A'></a>
|
30 |
+
<a href="https://arxiv.org/pdf/2406.07476.pdf"><img src="https://img.shields.io/badge/Arxiv-2406.07476-AD1C18"></a>
|
31 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2/stargazers"><img src="https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA2.svg?style=social"></a>
|
32 |
+
</div>
|
33 |
+
</div>
|
34 |
+
""")
|
35 |
+
|
36 |
+
|
37 |
+
block_css = """
|
38 |
+
#buttons button {
|
39 |
+
min-width: min(120px,100%);
|
40 |
+
color: #9C276A
|
41 |
+
}
|
42 |
+
"""
|
43 |
+
|
44 |
+
|
45 |
+
tos_markdown = ("""
|
46 |
+
### Terms of use
|
47 |
+
By using this service, users are required to agree to the following terms:
|
48 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
49 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
50 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
51 |
+
""")
|
52 |
+
|
53 |
+
|
54 |
+
learn_more_markdown = ("""
|
55 |
+
### License
|
56 |
+
This project is released under the Apache 2.0 license as found in the LICENSE file. The service is a research preview intended for non-commercial use ONLY, subject to the model Licenses of LLaMA and Mistral, Terms of Use of the data generated by OpenAI, and Privacy Practices of ShareGPT. Please get in touch with us if you find any potential violations.
|
57 |
+
""")
|
58 |
+
|
59 |
+
|
60 |
+
plum_color = gr.themes.colors.Color(
|
61 |
+
name='plum',
|
62 |
+
c50='#F8E4EF',
|
63 |
+
c100='#E9D0DE',
|
64 |
+
c200='#DABCCD',
|
65 |
+
c300='#CBA8BC',
|
66 |
+
c400='#BC94AB',
|
67 |
+
c500='#AD809A',
|
68 |
+
c600='#9E6C89',
|
69 |
+
c700='#8F5878',
|
70 |
+
c800='#804467',
|
71 |
+
c900='#713056',
|
72 |
+
c950='#662647',
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
class Chat:
|
77 |
+
|
78 |
+
def __init__(self, model_path, load_8bit=False, load_4bit=False):
|
79 |
+
disable_torch_init()
|
80 |
+
|
81 |
+
self.model, self.processor, self.tokenizer = model_init(model_path, load_8bit=load_8bit, load_4bit=load_4bit)
|
82 |
+
|
83 |
+
@spaces.GPU(duration=120)
|
84 |
+
@torch.inference_mode()
|
85 |
+
def generate(self, data: list, message, temperature, top_p, max_output_tokens):
|
86 |
+
# TODO: support multiple turns of conversation.
|
87 |
+
assert len(data) == 1
|
88 |
+
|
89 |
+
tensor, modal = data[0]
|
90 |
+
response = mm_infer(tensor, message, self.model, self.tokenizer, modal=modal.strip('<>'),
|
91 |
+
do_sample=True if temperature > 0.0 else False,
|
92 |
+
temperature=temperature,
|
93 |
+
top_p=top_p,
|
94 |
+
max_new_tokens=max_output_tokens)
|
95 |
+
|
96 |
+
return response
|
97 |
+
|
98 |
+
|
99 |
+
@spaces.GPU(duration=120)
|
100 |
+
def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
|
101 |
+
data = []
|
102 |
+
|
103 |
+
processor = handler.processor
|
104 |
+
try:
|
105 |
+
if image is not None:
|
106 |
+
data.append((processor['image'](image).to(handler.model.device, dtype=dtype), '<image>'))
|
107 |
+
elif video is not None:
|
108 |
+
data.append((processor['video'](video).to(handler.model.device, dtype=dtype), '<video>'))
|
109 |
+
elif image is None and video is None:
|
110 |
+
data.append((None, '<text>'))
|
111 |
+
else:
|
112 |
+
raise NotImplementedError("Not support image and video at the same time")
|
113 |
+
except Exception as e:
|
114 |
+
traceback.print_exc()
|
115 |
+
return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot
|
116 |
+
|
117 |
+
assert len(message) % 2 == 0, "The message should be a pair of user and system message."
|
118 |
+
|
119 |
+
show_images = ""
|
120 |
+
if image is not None:
|
121 |
+
show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
|
122 |
+
if video is not None:
|
123 |
+
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
|
124 |
+
|
125 |
+
one_turn_chat = [textbox_in, None]
|
126 |
+
|
127 |
+
# 1. first run case
|
128 |
+
if len(chatbot) == 0:
|
129 |
+
one_turn_chat[0] += "\n" + show_images
|
130 |
+
# 2. not first run case
|
131 |
+
else:
|
132 |
+
# scanning the last image or video
|
133 |
+
length = len(chatbot)
|
134 |
+
for i in range(length - 1, -1, -1):
|
135 |
+
previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
|
136 |
+
previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
|
137 |
+
|
138 |
+
if len(previous_image) > 0:
|
139 |
+
previous_image = previous_image[-1]
|
140 |
+
# 2.1 new image append or pure text input will start a new conversation
|
141 |
+
if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
|
142 |
+
message.clear()
|
143 |
+
one_turn_chat[0] += "\n" + show_images
|
144 |
+
break
|
145 |
+
elif len(previous_video) > 0:
|
146 |
+
previous_video = previous_video[-1]
|
147 |
+
# 2.2 new video append or pure text input will start a new conversation
|
148 |
+
if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
|
149 |
+
message.clear()
|
150 |
+
one_turn_chat[0] += "\n" + show_images
|
151 |
+
break
|
152 |
+
|
153 |
+
message.append({'role': 'user', 'content': textbox_in})
|
154 |
+
text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
|
155 |
+
message.append({'role': 'assistant', 'content': text_en_out})
|
156 |
+
|
157 |
+
one_turn_chat[1] = text_en_out
|
158 |
+
chatbot.append(one_turn_chat)
|
159 |
+
|
160 |
+
return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot
|
161 |
+
|
162 |
+
|
163 |
+
def regenerate(message, chatbot):
|
164 |
+
message.pop(-1), message.pop(-1)
|
165 |
+
chatbot.pop(-1)
|
166 |
+
return message, chatbot
|
167 |
+
|
168 |
+
|
169 |
+
def clear_history(message, chatbot):
|
170 |
+
message.clear(), chatbot.clear()
|
171 |
+
return (gr.update(value=None, interactive=True),
|
172 |
+
gr.update(value=None, interactive=True),
|
173 |
+
message, chatbot,
|
174 |
+
gr.update(value=None, interactive=True))
|
175 |
+
|
176 |
+
|
177 |
+
# BUG of Zero Environment
|
178 |
+
# 1. The environment is fixed to torch>=2.0,<=2.2, gradio>=4.x.x
|
179 |
+
# 2. The operation or tensor which requires cuda are limited in those functions wrapped via spaces.GPU
|
180 |
+
# 3. The function can't return tensor or other cuda objects.
|
181 |
+
|
182 |
+
model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B-16F'
|
183 |
+
|
184 |
+
handler = Chat(model_path, load_8bit=False, load_4bit=True)
|
185 |
+
|
186 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
187 |
+
|
188 |
+
theme = gr.themes.Default(primary_hue=plum_color)
|
189 |
+
# theme.update_color("primary", plum_color.c500)
|
190 |
+
theme.set(slider_color="#9C276A")
|
191 |
+
theme.set(block_title_text_color="#9C276A")
|
192 |
+
theme.set(block_label_text_color="#9C276A")
|
193 |
+
theme.set(button_primary_text_color="#9C276A")
|
194 |
+
# theme.set(button_secondary_text_color="*neutral_800")
|
195 |
+
|
196 |
+
|
197 |
+
with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as demo:
|
198 |
+
gr.Markdown(title_markdown)
|
199 |
+
message = gr.State([])
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column(scale=3):
|
203 |
+
image = gr.Image(label="Input Image", type="filepath")
|
204 |
+
video = gr.Video(label="Input Video")
|
205 |
+
|
206 |
+
with gr.Accordion("Parameters", open=True) as parameter_row:
|
207 |
+
# num_beams = gr.Slider(
|
208 |
+
# minimum=1,
|
209 |
+
# maximum=10,
|
210 |
+
# value=1,
|
211 |
+
# step=1,
|
212 |
+
# interactive=True,
|
213 |
+
# label="beam search numbers",
|
214 |
+
# )
|
215 |
+
|
216 |
+
temperature = gr.Slider(
|
217 |
+
minimum=0.1,
|
218 |
+
maximum=1.0,
|
219 |
+
value=0.2,
|
220 |
+
step=0.1,
|
221 |
+
interactive=True,
|
222 |
+
label="Temperature",
|
223 |
+
)
|
224 |
+
|
225 |
+
top_p = gr.Slider(
|
226 |
+
minimum=0.0,
|
227 |
+
maximum=1.0,
|
228 |
+
value=0.7,
|
229 |
+
step=0.1,
|
230 |
+
interactive=True,
|
231 |
+
label="Top P",
|
232 |
+
)
|
233 |
+
|
234 |
+
max_output_tokens = gr.Slider(
|
235 |
+
minimum=64,
|
236 |
+
maximum=1024,
|
237 |
+
value=512,
|
238 |
+
step=64,
|
239 |
+
interactive=True,
|
240 |
+
label="Max output tokens",
|
241 |
+
)
|
242 |
+
|
243 |
+
with gr.Column(scale=7):
|
244 |
+
chatbot = gr.Chatbot(label="VideoLLaMA 2", bubble_full_width=True, height=750)
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column(scale=8):
|
247 |
+
textbox.render()
|
248 |
+
with gr.Column(scale=1, min_width=50):
|
249 |
+
submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
|
250 |
+
with gr.Row(elem_id="buttons") as button_row:
|
251 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
252 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
253 |
+
# flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
|
254 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
255 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
256 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
257 |
+
|
258 |
+
with gr.Row():
|
259 |
+
with gr.Column():
|
260 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
261 |
+
gr.Examples(
|
262 |
+
examples=[
|
263 |
+
[
|
264 |
+
f"{cur_dir}/examples/extreme_ironing.jpg",
|
265 |
+
"What happens in this image?",
|
266 |
+
],
|
267 |
+
[
|
268 |
+
f"{cur_dir}/examples/waterview.jpg",
|
269 |
+
"What are the things I should be cautious about when I visit here?",
|
270 |
+
],
|
271 |
+
[
|
272 |
+
f"{cur_dir}/examples/desert.jpg",
|
273 |
+
"If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?",
|
274 |
+
],
|
275 |
+
],
|
276 |
+
inputs=[image, textbox],
|
277 |
+
)
|
278 |
+
with gr.Column():
|
279 |
+
gr.Examples(
|
280 |
+
examples=[
|
281 |
+
[
|
282 |
+
f"{cur_dir}/../../assets/cat_and_chicken.mp4",
|
283 |
+
"What happens in this video?",
|
284 |
+
],
|
285 |
+
[
|
286 |
+
f"{cur_dir}/../../assets/sora.mp4",
|
287 |
+
"Please describe this video.",
|
288 |
+
],
|
289 |
+
[
|
290 |
+
f"{cur_dir}/examples/sample_demo_1.mp4",
|
291 |
+
"What does the baby do?",
|
292 |
+
],
|
293 |
+
],
|
294 |
+
inputs=[video, textbox],
|
295 |
+
)
|
296 |
+
|
297 |
+
gr.Markdown(tos_markdown)
|
298 |
+
gr.Markdown(learn_more_markdown)
|
299 |
+
|
300 |
+
submit_btn.click(
|
301 |
+
generate,
|
302 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
303 |
+
[image, video, message, chatbot])
|
304 |
+
|
305 |
+
regenerate_btn.click(
|
306 |
+
regenerate,
|
307 |
+
[message, chatbot],
|
308 |
+
[message, chatbot]).then(
|
309 |
+
generate,
|
310 |
+
[image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
|
311 |
+
[image, video, message, chatbot])
|
312 |
+
|
313 |
+
clear_btn.click(
|
314 |
+
clear_history,
|
315 |
+
[message, chatbot],
|
316 |
+
[image, video, message, chatbot, textbox])
|
317 |
+
|
318 |
+
demo.launch(share = True)
|
model_worker.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker executes the model.
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
import uuid
|
8 |
+
import asyncio
|
9 |
+
import requests
|
10 |
+
import argparse
|
11 |
+
import threading
|
12 |
+
from threading import Thread
|
13 |
+
from functools import partial
|
14 |
+
from typing import Iterator, List, Optional, Tuple
|
15 |
+
|
16 |
+
import uvicorn
|
17 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
18 |
+
from fastapi.responses import StreamingResponse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import decord
|
22 |
+
import numpy as np
|
23 |
+
from PIL import Image
|
24 |
+
from decord import VideoReader, cpu
|
25 |
+
from transformers import TextIteratorStreamer
|
26 |
+
|
27 |
+
from videollama2.constants import WORKER_HEART_BEAT_INTERVAL
|
28 |
+
from videollama2.utils import (build_logger, server_error_msg, pretty_print_semaphore)
|
29 |
+
from videollama2.model.builder import load_pretrained_model
|
30 |
+
from videollama2.mm_utils import process_images, process_videos, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria, tokenizer_MMODAL_token
|
31 |
+
from videollama2.mm_utils import chunk_list, frame_expansion
|
32 |
+
from videollama2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, NUM_FRAMES, MMODAL_TOKEN_INDEX
|
33 |
+
|
34 |
+
|
35 |
+
GB = 1 << 30
|
36 |
+
|
37 |
+
worker_id = str(uuid.uuid4())[:6]
|
38 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
39 |
+
global_counter = 0
|
40 |
+
|
41 |
+
model_semaphore = None
|
42 |
+
|
43 |
+
|
44 |
+
# variable_content = os.getenv('MY_VARIABLE', '')
|
45 |
+
# KEYWORDS_LIST = set(variable_content.split('\n'))
|
46 |
+
KEYWORDS_LIST = []
|
47 |
+
path = 'assets/keywords.txt'
|
48 |
+
if os.path.exists(path):
|
49 |
+
with open(path, 'r', encoding='utf-8') as file:
|
50 |
+
for line in file:
|
51 |
+
|
52 |
+
KEYWORDS_LIST.append(line.strip())
|
53 |
+
else:
|
54 |
+
KEYWORDS_LIST = []
|
55 |
+
|
56 |
+
|
57 |
+
KEYWORD_BLOCK_MESSAGE2 = "The output contains political, erotic and other unsafe content that violates local laws. Please re-enter your question."
|
58 |
+
KEYWORD_BLOCK_MESSAGE1 = "Your input question contains political, erotic and other unsafe content that violates local laws. Please re-enter your question."
|
59 |
+
STREAM_CHECK_MULTIPLE = 20
|
60 |
+
|
61 |
+
|
62 |
+
def heart_beat_worker(controller):
|
63 |
+
|
64 |
+
while True:
|
65 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
66 |
+
controller.send_heart_beat()
|
67 |
+
|
68 |
+
|
69 |
+
def safety_check(text, history=None, ) -> Optional[str]:
|
70 |
+
|
71 |
+
if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
|
72 |
+
print('############')
|
73 |
+
return KEYWORD_BLOCK_MESSAGE2
|
74 |
+
|
75 |
+
return None
|
76 |
+
|
77 |
+
|
78 |
+
def input_safety_check(text) -> Optional[str]:
|
79 |
+
if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
|
80 |
+
print('######## Input keyword alarm triggered:', text)
|
81 |
+
return KEYWORD_BLOCK_MESSAGE1
|
82 |
+
return None
|
83 |
+
|
84 |
+
|
85 |
+
class ModelWorker:
|
86 |
+
|
87 |
+
def __init__(self, controller_addr, worker_addr,
|
88 |
+
worker_id, no_register,
|
89 |
+
model_path, model_base, model_name,
|
90 |
+
load_8bit, load_4bit, device):
|
91 |
+
self.controller_addr = controller_addr
|
92 |
+
self.worker_addr = worker_addr
|
93 |
+
self.worker_id = worker_id
|
94 |
+
self.model_path = model_path
|
95 |
+
if model_path.endswith("/"):
|
96 |
+
model_path = model_path[:-1]
|
97 |
+
if model_name is None:
|
98 |
+
model_paths = model_path.split("/")
|
99 |
+
if model_paths[-1].startswith('checkpoint-'):
|
100 |
+
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
101 |
+
else:
|
102 |
+
self.model_name = model_paths[-1]
|
103 |
+
else:
|
104 |
+
self.model_name = model_name
|
105 |
+
|
106 |
+
self.device = device
|
107 |
+
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
108 |
+
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
109 |
+
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
110 |
+
self.is_multimodal = 'videollama2' in self.model_name.lower() or 'vlb' in self.model_name.lower()
|
111 |
+
|
112 |
+
if not no_register:
|
113 |
+
self.register_to_controller()
|
114 |
+
self.heart_beat_thread = threading.Thread(
|
115 |
+
target=heart_beat_worker, args=(self,))
|
116 |
+
self.heart_beat_thread.start()
|
117 |
+
|
118 |
+
def register_to_controller(self):
|
119 |
+
logger.info("Register to controller")
|
120 |
+
|
121 |
+
url = self.controller_addr + "/register_worker"
|
122 |
+
data = {
|
123 |
+
"worker_name": self.worker_addr,
|
124 |
+
"check_heart_beat": True,
|
125 |
+
"worker_status": self.get_status()
|
126 |
+
}
|
127 |
+
r = requests.post(url, json=data)
|
128 |
+
assert r.status_code == 200
|
129 |
+
|
130 |
+
def send_heart_beat(self):
|
131 |
+
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
132 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
133 |
+
f"global_counter: {global_counter}")
|
134 |
+
|
135 |
+
url = self.controller_addr + "/receive_heart_beat"
|
136 |
+
|
137 |
+
while True:
|
138 |
+
try:
|
139 |
+
ret = requests.post(url, json={
|
140 |
+
"worker_name": self.worker_addr,
|
141 |
+
"queue_length": self.get_queue_length()}, timeout=5)
|
142 |
+
exist = ret.json()["exist"]
|
143 |
+
break
|
144 |
+
except requests.exceptions.RequestException as e:
|
145 |
+
logger.error(f"heart beat error: {e}")
|
146 |
+
time.sleep(5)
|
147 |
+
|
148 |
+
if not exist:
|
149 |
+
self.register_to_controller()
|
150 |
+
|
151 |
+
def get_queue_length(self):
|
152 |
+
if model_semaphore is None:
|
153 |
+
return 0
|
154 |
+
else:
|
155 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
156 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
157 |
+
|
158 |
+
def get_status(self):
|
159 |
+
return {
|
160 |
+
"model_names": [self.model_name],
|
161 |
+
"speed": 1,
|
162 |
+
"queue_length": self.get_queue_length(),
|
163 |
+
}
|
164 |
+
|
165 |
+
@torch.inference_mode()
|
166 |
+
def generate_stream(self, params):
|
167 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
168 |
+
|
169 |
+
prompt = params["prompt"]
|
170 |
+
ori_prompt = prompt
|
171 |
+
images_or_videos = params.get("images", None)
|
172 |
+
#print("Input images:", images_or_videos)
|
173 |
+
num_image_tokens = 0
|
174 |
+
modal_list = []
|
175 |
+
if images_or_videos is not None and len(images_or_videos) and self.is_multimodal:
|
176 |
+
if len(images_or_videos) > 0:
|
177 |
+
if len(images_or_videos) != prompt.count(DEFAULT_IMAGE_TOKEN) and len(images_or_videos) != (prompt.count(DEFAULT_VIDEO_TOKEN)):
|
178 |
+
raise ValueError("Number of images/videos does not match number of <image>/<video> tokens in prompt")
|
179 |
+
|
180 |
+
try:
|
181 |
+
print("Load image...")
|
182 |
+
images_or_videos = [load_image_from_base64(image) for image in images_or_videos]
|
183 |
+
images_or_videos = process_images(images_or_videos, image_processor, model.config)
|
184 |
+
|
185 |
+
modal_list = ["image"]
|
186 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
187 |
+
modal_token_index = MMODAL_TOKEN_INDEX["IMAGE"]
|
188 |
+
except:
|
189 |
+
print("Load video instead...")
|
190 |
+
decord_vr = VideoReader(uri=images_or_videos[0], ctx=cpu(0))
|
191 |
+
duration = len(decord_vr)
|
192 |
+
if not "use_taug" in self.model_path:
|
193 |
+
frame_id_list = np.linspace(0, duration-1, 8, dtype=int)
|
194 |
+
video_frames = decord_vr.get_batch(frame_id_list).asnumpy()
|
195 |
+
images_or_videos = process_videos(video_frames, image_processor, model.config)
|
196 |
+
else:
|
197 |
+
print("Temporal augmentation activated!!!")
|
198 |
+
frame_id_list = np.linspace(0, duration-1, 8 * 2 * 2, dtype=int)
|
199 |
+
video_data = decord_vr.get_batch(frame_id_list)
|
200 |
+
video_frames = [Image.fromarray(f) for f in video_data.asnumpy()]
|
201 |
+
chunked_video_frames = chunk_list(video_frames, 2*2)
|
202 |
+
expanded_video_frames = [frame_expansion(frame_list, 2) for frame_list in chunked_video_frames]
|
203 |
+
images_or_videos = process_videos(expanded_video_frames, image_processor, model.config)
|
204 |
+
|
205 |
+
# frame_id_list = np.linspace(0, duration-1, NUM_FRAMES, dtype=int)
|
206 |
+
# images_or_videos = decord_vr.get_batch(frame_id_list).asnumpy()
|
207 |
+
# images_or_videos = process_videos(images_or_videos, image_processor, model.config)
|
208 |
+
#print("images_or_videos.shape:", images_or_videos.shape)
|
209 |
+
modal_list = ["video"]
|
210 |
+
replace_token = DEFAULT_VIDEO_TOKEN
|
211 |
+
modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]
|
212 |
+
|
213 |
+
if type(images_or_videos) is list:
|
214 |
+
images_or_videos = [image.to(self.model.device, dtype=torch.float16) for image in images_or_videos]
|
215 |
+
else:
|
216 |
+
images_or_videos = images_or_videos.to(self.model.device, dtype=torch.float16)
|
217 |
+
if modal_list[0] == "video":
|
218 |
+
print("Video:", images_or_videos.shape)
|
219 |
+
images_or_videos = [images_or_videos]
|
220 |
+
else:
|
221 |
+
print("Image:", images_or_videos.shape)
|
222 |
+
|
223 |
+
|
224 |
+
#image_sizes = [image.size for image in images_or_videos]
|
225 |
+
|
226 |
+
|
227 |
+
# if len(images_or_videos) % NUM_FRAMES == 0:
|
228 |
+
# images_or_videos = process_images(images_or_videos, image_processor, model.config)
|
229 |
+
# #images_or_videos = [image.to(self.model.device, dtype=torch.float16) for image in images_or_videos]
|
230 |
+
# #modal_list = ["image"] * len(images_or_videos)
|
231 |
+
# images_or_videos = images_or_videos.to(self.model.device, dtype=torch.float16)
|
232 |
+
# modal_list = ["video"]
|
233 |
+
# replace_token = DEFAULT_VIDEO_TOKEN
|
234 |
+
# else:
|
235 |
+
|
236 |
+
if getattr(self.model.config, 'mm_use_im_start_end', False):
|
237 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
238 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
239 |
+
|
240 |
+
num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
|
241 |
+
else:
|
242 |
+
images = None
|
243 |
+
modal_list = []
|
244 |
+
image_args = {"images_or_videos": images_or_videos, "modal_list": modal_list}
|
245 |
+
else:
|
246 |
+
images = None
|
247 |
+
image_args = {}
|
248 |
+
print("image_args:", image_args)
|
249 |
+
temperature = float(params.get("temperature", 1.0))
|
250 |
+
top_p = float(params.get("top_p", 1.0))
|
251 |
+
max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
252 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
253 |
+
stop_str = params.get("stop", None)
|
254 |
+
do_sample = True if temperature > 0.001 else False
|
255 |
+
|
256 |
+
#input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
257 |
+
# tokenizer for our video-llama beta
|
258 |
+
input_ids = tokenizer_MMODAL_token(prompt, tokenizer, modal_token_index, return_tensors='pt').unsqueeze(0).to(self.device)
|
259 |
+
#print("Current prompt:", prompt)
|
260 |
+
#print("input_ids.shape:", input_ids.shape)
|
261 |
+
keywords = [stop_str]
|
262 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
263 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
264 |
+
|
265 |
+
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
266 |
+
|
267 |
+
if max_new_tokens < 1:
|
268 |
+
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
|
269 |
+
return
|
270 |
+
|
271 |
+
thread = Thread(target=model.generate, kwargs=dict(
|
272 |
+
inputs=input_ids,
|
273 |
+
do_sample=do_sample,
|
274 |
+
temperature=temperature,
|
275 |
+
top_p=top_p,
|
276 |
+
max_new_tokens=max_new_tokens,
|
277 |
+
streamer=streamer,
|
278 |
+
stopping_criteria=[stopping_criteria],
|
279 |
+
use_cache=True,
|
280 |
+
**image_args
|
281 |
+
))
|
282 |
+
thread.start()
|
283 |
+
|
284 |
+
generated_text = ori_prompt
|
285 |
+
token_count = 0
|
286 |
+
for new_text in streamer:
|
287 |
+
generated_text += new_text
|
288 |
+
token_count += len(tokenizer.encode(new_text))
|
289 |
+
if token_count >= STREAM_CHECK_MULTIPLE:
|
290 |
+
safety_message = safety_check(generated_text)
|
291 |
+
if safety_message:
|
292 |
+
print('####### Keyword alarm triggered:', generated_text)
|
293 |
+
yield json.dumps({"text": safety_message , "error_code": 1}).encode() + b"\0"
|
294 |
+
return
|
295 |
+
token_count = 0 #
|
296 |
+
|
297 |
+
|
298 |
+
if generated_text.endswith(stop_str):
|
299 |
+
generated_text = generated_text[:-len(stop_str)]
|
300 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
301 |
+
|
302 |
+
def generate_stream_gate(self, params):
|
303 |
+
try:
|
304 |
+
input_text = params.get("prompt", "")
|
305 |
+
safety_message = input_safety_check(input_text)
|
306 |
+
if safety_message:
|
307 |
+
yield json.dumps({"text": safety_message, "error_code": 1}).encode() + b"\0"
|
308 |
+
return
|
309 |
+
|
310 |
+
for x in self.generate_stream(params):
|
311 |
+
yield x
|
312 |
+
except ValueError as e:
|
313 |
+
print("Caught ValueError:", e)
|
314 |
+
ret = {
|
315 |
+
"text": server_error_msg,
|
316 |
+
"error_code": 1,
|
317 |
+
}
|
318 |
+
yield json.dumps(ret).encode() + b"\0"
|
319 |
+
except torch.cuda.CudaError as e:
|
320 |
+
print("Caught torch.cuda.CudaError:", e)
|
321 |
+
ret = {
|
322 |
+
"text": server_error_msg,
|
323 |
+
"error_code": 1,
|
324 |
+
}
|
325 |
+
yield json.dumps(ret).encode() + b"\0"
|
326 |
+
except Exception as e:
|
327 |
+
print("Caught Unknown Error", e)
|
328 |
+
ret = {
|
329 |
+
"text": server_error_msg,
|
330 |
+
"error_code": 1,
|
331 |
+
}
|
332 |
+
yield json.dumps(ret).encode() + b"\0"
|
333 |
+
|
334 |
+
|
335 |
+
app = FastAPI()
|
336 |
+
|
337 |
+
|
338 |
+
def release_model_semaphore(fn=None):
|
339 |
+
model_semaphore.release()
|
340 |
+
if fn is not None:
|
341 |
+
fn()
|
342 |
+
|
343 |
+
|
344 |
+
@app.post("/worker_generate_stream")
|
345 |
+
async def generate_stream(request: Request):
|
346 |
+
global model_semaphore, global_counter
|
347 |
+
global_counter += 1
|
348 |
+
params = await request.json()
|
349 |
+
|
350 |
+
if model_semaphore is None:
|
351 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
352 |
+
await model_semaphore.acquire()
|
353 |
+
worker.send_heart_beat()
|
354 |
+
generator = worker.generate_stream_gate(params)
|
355 |
+
background_tasks = BackgroundTasks()
|
356 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
357 |
+
return StreamingResponse(generator, background=background_tasks)
|
358 |
+
|
359 |
+
|
360 |
+
@app.post("/worker_get_status")
|
361 |
+
async def get_status(request: Request):
|
362 |
+
return worker.get_status()
|
363 |
+
|
364 |
+
|
365 |
+
if __name__ == "__main__":
|
366 |
+
parser = argparse.ArgumentParser()
|
367 |
+
parser.add_argument("--host", type=str, default="localhost")
|
368 |
+
parser.add_argument("--port", type=int, default=21002)
|
369 |
+
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
370 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
371 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
372 |
+
parser.add_argument("--model-base", type=str, default=None)
|
373 |
+
parser.add_argument("--model-name", type=str)
|
374 |
+
parser.add_argument("--device", type=str, default="cuda")
|
375 |
+
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
|
376 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
377 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
378 |
+
parser.add_argument("--no-register", action="store_true")
|
379 |
+
parser.add_argument("--load-8bit", action="store_true")
|
380 |
+
parser.add_argument("--load-4bit", action="store_true")
|
381 |
+
args = parser.parse_args()
|
382 |
+
logger.info(f"args: {args}")
|
383 |
+
|
384 |
+
if args.multi_modal:
|
385 |
+
logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
|
386 |
+
|
387 |
+
worker = ModelWorker(args.controller_address,
|
388 |
+
args.worker_address,
|
389 |
+
worker_id,
|
390 |
+
args.no_register,
|
391 |
+
args.model_path,
|
392 |
+
args.model_base,
|
393 |
+
args.model_name,
|
394 |
+
args.load_8bit,
|
395 |
+
args.load_4bit,
|
396 |
+
args.device)
|
397 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
register_worker.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Manually register workers.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import requests
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--controller-address", type=str)
|
15 |
+
parser.add_argument("--worker-name", type=str)
|
16 |
+
parser.add_argument("--check-heart-beat", action="store_true")
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
url = args.controller_address + "/register_worker"
|
20 |
+
data = {
|
21 |
+
"worker_name": args.worker_name,
|
22 |
+
"check_heart_beat": args.check_heart_beat,
|
23 |
+
"worker_status": None,
|
24 |
+
}
|
25 |
+
r = requests.post(url, json=data)
|
26 |
+
assert r.status_code == 200
|
sglang_worker.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker executes the model.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import asyncio
|
6 |
+
from concurrent.futures import ThreadPoolExecutor
|
7 |
+
import json
|
8 |
+
import time
|
9 |
+
import threading
|
10 |
+
import uuid
|
11 |
+
|
12 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
13 |
+
from fastapi.responses import StreamingResponse
|
14 |
+
import requests
|
15 |
+
import re
|
16 |
+
import uvicorn
|
17 |
+
from functools import partial
|
18 |
+
|
19 |
+
from llava.constants import WORKER_HEART_BEAT_INTERVAL
|
20 |
+
from llava.utils import (build_logger, server_error_msg,
|
21 |
+
pretty_print_semaphore)
|
22 |
+
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
|
23 |
+
from llava.constants import DEFAULT_IMAGE_TOKEN
|
24 |
+
|
25 |
+
import sglang as sgl
|
26 |
+
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
27 |
+
|
28 |
+
|
29 |
+
GB = 1 << 30
|
30 |
+
|
31 |
+
worker_id = str(uuid.uuid4())[:6]
|
32 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
33 |
+
global_counter = 0
|
34 |
+
|
35 |
+
model_semaphore = None
|
36 |
+
|
37 |
+
|
38 |
+
def heart_beat_worker(controller):
|
39 |
+
while True:
|
40 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
41 |
+
controller.send_heart_beat()
|
42 |
+
|
43 |
+
|
44 |
+
@sgl.function
|
45 |
+
def pipeline(s, prompt, max_tokens):
|
46 |
+
for p in prompt:
|
47 |
+
if type(p) is str:
|
48 |
+
s += p
|
49 |
+
else:
|
50 |
+
s += sgl.image(p)
|
51 |
+
s += sgl.gen("response", max_tokens=max_tokens)
|
52 |
+
|
53 |
+
|
54 |
+
class ModelWorker:
|
55 |
+
def __init__(self, controller_addr, worker_addr, sgl_endpoint,
|
56 |
+
worker_id, no_register, model_name):
|
57 |
+
self.controller_addr = controller_addr
|
58 |
+
self.worker_addr = worker_addr
|
59 |
+
self.worker_id = worker_id
|
60 |
+
|
61 |
+
# Select backend
|
62 |
+
backend = RuntimeEndpoint(sgl_endpoint)
|
63 |
+
sgl.set_default_backend(backend)
|
64 |
+
model_path = backend.model_info["model_path"]
|
65 |
+
|
66 |
+
if model_path.endswith("/"):
|
67 |
+
model_path = model_path[:-1]
|
68 |
+
if model_name is None:
|
69 |
+
model_paths = model_path.split("/")
|
70 |
+
if model_paths[-1].startswith('checkpoint-'):
|
71 |
+
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
72 |
+
else:
|
73 |
+
self.model_name = model_paths[-1]
|
74 |
+
else:
|
75 |
+
self.model_name = model_name
|
76 |
+
|
77 |
+
logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
|
78 |
+
|
79 |
+
if not no_register:
|
80 |
+
self.register_to_controller()
|
81 |
+
self.heart_beat_thread = threading.Thread(
|
82 |
+
target=heart_beat_worker, args=(self,), daemon=True)
|
83 |
+
self.heart_beat_thread.start()
|
84 |
+
|
85 |
+
def register_to_controller(self):
|
86 |
+
logger.info("Register to controller")
|
87 |
+
|
88 |
+
url = self.controller_addr + "/register_worker"
|
89 |
+
data = {
|
90 |
+
"worker_name": self.worker_addr,
|
91 |
+
"check_heart_beat": True,
|
92 |
+
"worker_status": self.get_status()
|
93 |
+
}
|
94 |
+
r = requests.post(url, json=data)
|
95 |
+
assert r.status_code == 200
|
96 |
+
|
97 |
+
def send_heart_beat(self):
|
98 |
+
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
99 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
100 |
+
f"global_counter: {global_counter}")
|
101 |
+
|
102 |
+
url = self.controller_addr + "/receive_heart_beat"
|
103 |
+
|
104 |
+
while True:
|
105 |
+
try:
|
106 |
+
ret = requests.post(url, json={
|
107 |
+
"worker_name": self.worker_addr,
|
108 |
+
"queue_length": self.get_queue_length()}, timeout=5)
|
109 |
+
exist = ret.json()["exist"]
|
110 |
+
break
|
111 |
+
except requests.exceptions.RequestException as e:
|
112 |
+
logger.error(f"heart beat error: {e}")
|
113 |
+
time.sleep(5)
|
114 |
+
|
115 |
+
if not exist:
|
116 |
+
self.register_to_controller()
|
117 |
+
|
118 |
+
def get_queue_length(self):
|
119 |
+
if model_semaphore is None:
|
120 |
+
return 0
|
121 |
+
else:
|
122 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
123 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
124 |
+
|
125 |
+
def get_status(self):
|
126 |
+
return {
|
127 |
+
"model_names": [self.model_name],
|
128 |
+
"speed": 1,
|
129 |
+
"queue_length": self.get_queue_length(),
|
130 |
+
}
|
131 |
+
|
132 |
+
async def generate_stream(self, params):
|
133 |
+
ori_prompt = prompt = params["prompt"]
|
134 |
+
images = params.get("images", None)
|
135 |
+
if images is not None and len(images) > 0:
|
136 |
+
if len(images) > 0:
|
137 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
138 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
139 |
+
|
140 |
+
images = [load_image_from_base64(image) for image in images]
|
141 |
+
|
142 |
+
# FIXME: for image-start/end token
|
143 |
+
# replace_token = DEFAULT_IMAGE_TOKEN
|
144 |
+
# if getattr(self.model.config, 'mm_use_im_start_end', False):
|
145 |
+
# replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
146 |
+
# prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
147 |
+
prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
|
148 |
+
prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
|
149 |
+
prompt = []
|
150 |
+
for i in range(len(prompt_split)):
|
151 |
+
prompt.append(prompt_split[i])
|
152 |
+
if i < len(images):
|
153 |
+
prompt.append(images[i])
|
154 |
+
else:
|
155 |
+
prompt = [prompt]
|
156 |
+
|
157 |
+
temperature = float(params.get("temperature", 1.0))
|
158 |
+
top_p = float(params.get("top_p", 1.0))
|
159 |
+
# max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
160 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
161 |
+
stop_str = params.get("stop", None)
|
162 |
+
stop_str = [stop_str] if stop_str is not None else None
|
163 |
+
|
164 |
+
print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
|
165 |
+
state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
|
166 |
+
|
167 |
+
generated_text = ori_prompt
|
168 |
+
async for text_outputs in state.text_async_iter(var_name="response"):
|
169 |
+
generated_text += text_outputs
|
170 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
171 |
+
|
172 |
+
async def generate_stream_gate(self, params):
|
173 |
+
try:
|
174 |
+
async for x in self.generate_stream(params):
|
175 |
+
yield x
|
176 |
+
except ValueError as e:
|
177 |
+
print("Caught ValueError:", e)
|
178 |
+
ret = {
|
179 |
+
"text": server_error_msg,
|
180 |
+
"error_code": 1,
|
181 |
+
}
|
182 |
+
yield json.dumps(ret).encode() + b"\0"
|
183 |
+
except Exception as e:
|
184 |
+
print("Caught Unknown Error", e)
|
185 |
+
ret = {
|
186 |
+
"text": server_error_msg,
|
187 |
+
"error_code": 1,
|
188 |
+
}
|
189 |
+
yield json.dumps(ret).encode() + b"\0"
|
190 |
+
|
191 |
+
|
192 |
+
app = FastAPI()
|
193 |
+
|
194 |
+
|
195 |
+
def release_model_semaphore(fn=None):
|
196 |
+
model_semaphore.release()
|
197 |
+
if fn is not None:
|
198 |
+
fn()
|
199 |
+
|
200 |
+
|
201 |
+
@app.post("/worker_generate_stream")
|
202 |
+
async def generate_stream(request: Request):
|
203 |
+
global model_semaphore, global_counter
|
204 |
+
global_counter += 1
|
205 |
+
params = await request.json()
|
206 |
+
|
207 |
+
if model_semaphore is None:
|
208 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
209 |
+
await model_semaphore.acquire()
|
210 |
+
worker.send_heart_beat()
|
211 |
+
generator = worker.generate_stream_gate(params)
|
212 |
+
background_tasks = BackgroundTasks()
|
213 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
214 |
+
return StreamingResponse(generator, background=background_tasks)
|
215 |
+
|
216 |
+
|
217 |
+
@app.post("/worker_get_status")
|
218 |
+
async def get_status(request: Request):
|
219 |
+
return worker.get_status()
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
parser = argparse.ArgumentParser()
|
224 |
+
parser.add_argument("--host", type=str, default="localhost")
|
225 |
+
parser.add_argument("--port", type=int, default=21002)
|
226 |
+
parser.add_argument("--worker-address", type=str,
|
227 |
+
default="http://localhost:21002")
|
228 |
+
parser.add_argument("--controller-address", type=str,
|
229 |
+
default="http://localhost:21001")
|
230 |
+
parser.add_argument("--model-name", type=str)
|
231 |
+
parser.add_argument("--sgl-endpoint", type=str)
|
232 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
233 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
234 |
+
parser.add_argument("--no-register", action="store_true")
|
235 |
+
args = parser.parse_args()
|
236 |
+
logger.info(f"args: {args}")
|
237 |
+
|
238 |
+
worker = ModelWorker(args.controller_address,
|
239 |
+
args.worker_address,
|
240 |
+
args.sgl_endpoint,
|
241 |
+
worker_id,
|
242 |
+
args.no_register,
|
243 |
+
args.model_name)
|
244 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
test_message.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from llava.conversation import default_conversation
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
if args.worker_address:
|
11 |
+
worker_addr = args.worker_address
|
12 |
+
else:
|
13 |
+
controller_addr = args.controller_address
|
14 |
+
ret = requests.post(controller_addr + "/refresh_all_workers")
|
15 |
+
ret = requests.post(controller_addr + "/list_models")
|
16 |
+
models = ret.json()["models"]
|
17 |
+
models.sort()
|
18 |
+
print(f"Models: {models}")
|
19 |
+
|
20 |
+
ret = requests.post(controller_addr + "/get_worker_address",
|
21 |
+
json={"model": args.model_name})
|
22 |
+
worker_addr = ret.json()["address"]
|
23 |
+
print(f"worker_addr: {worker_addr}")
|
24 |
+
|
25 |
+
if worker_addr == "":
|
26 |
+
return
|
27 |
+
|
28 |
+
conv = default_conversation.copy()
|
29 |
+
conv.append_message(conv.roles[0], args.message)
|
30 |
+
prompt = conv.get_prompt()
|
31 |
+
|
32 |
+
headers = {"User-Agent": "LLaVA Client"}
|
33 |
+
pload = {
|
34 |
+
"model": args.model_name,
|
35 |
+
"prompt": prompt,
|
36 |
+
"max_new_tokens": args.max_new_tokens,
|
37 |
+
"temperature": 0.7,
|
38 |
+
"stop": conv.sep,
|
39 |
+
}
|
40 |
+
response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
|
41 |
+
json=pload, stream=True)
|
42 |
+
|
43 |
+
print(prompt.replace(conv.sep, "\n"), end="")
|
44 |
+
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
|
45 |
+
if chunk:
|
46 |
+
data = json.loads(chunk.decode("utf-8"))
|
47 |
+
output = data["text"].split(conv.sep)[-1]
|
48 |
+
print(output, end="\r")
|
49 |
+
print("")
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
parser = argparse.ArgumentParser()
|
54 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
55 |
+
parser.add_argument("--worker-address", type=str)
|
56 |
+
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
|
57 |
+
parser.add_argument("--max-new-tokens", type=int, default=32)
|
58 |
+
parser.add_argument("--message", type=str, default=
|
59 |
+
"Tell me a story with more than 1000 words.")
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
main()
|