tlvtech commited on
Commit
f8178ae
1 Parent(s): 8404742

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/1034346401.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/sample_demo_1.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Serve
3
- emoji:
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.44.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: serve
3
+ app_file: gradio_web_server_adhoc.py
 
 
4
  sdk: gradio
5
+ sdk_version: 3.50.0
 
 
6
  ---
 
 
cli.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from videollama2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, NUM_FRAMES
5
+ from videollama2.conversation import conv_templates, SeparatorStyle
6
+ from videollama2.model.builder import load_pretrained_model
7
+ from videollama2.utils import disable_torch_init
8
+ from videollama2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, tokenizer_MMODAL_token
9
+
10
+ from PIL import Image
11
+ from decord import VideoReader, cpu
12
+
13
+ import requests
14
+ from io import BytesIO
15
+ from transformers import TextStreamer
16
+
17
+
18
+ def load_image(image_file):
19
+ if image_file.startswith('http://') or image_file.startswith('https://'):
20
+ response = requests.get(image_file)
21
+ image = Image.open(BytesIO(response.content)).convert('RGB')
22
+ else:
23
+ image = Image.open(image_file).convert('RGB')
24
+ return image
25
+
26
+ def load_video(video_file):
27
+ decord_vr = VideoReader(uri=video_file, ctx=cpu(0))
28
+ duration = len(decord_vr)
29
+ frame_id_list = np.linspace(0, duration-1, NUM_FRAMES, dtype=int)
30
+ video = decord_vr.get_batch(frame_id_list)
31
+ return video
32
+
33
+ def load_image_or_video(image_or_video_file):
34
+ if file_path.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
35
+ return load_image(image_file=image_or_video_file)
36
+ elif file_path.endswith(('.mp4', '.avi', '.mov')):
37
+ return load_video(video_file=image_or_video_file)
38
+ else:
39
+ raise Exception(f"File type of {image_or_video_file} not supported!!!")
40
+
41
+
42
+ def main(args):
43
+ # Model
44
+ disable_torch_init()
45
+
46
+ model_name = get_model_name_from_path(args.model_path)
47
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
48
+
49
+ # if "llama-2" in model_name.lower():
50
+ # conv_mode = "llava_llama2"
51
+ # elif "mistral" in model_name.lower():
52
+ # conv_mode = "mistral"
53
+ # elif "v1.6-34b" in model_name.lower():
54
+ # conv_mode = "chatml_direct"
55
+ # elif "v1" in model_name.lower():
56
+ # conv_mode = "llava_v1"
57
+ # else:
58
+ # conv_mode = "llava_v0"
59
+ conv_mode = "llava_v1" # fix conversation mode for now
60
+
61
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
62
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
63
+ else:
64
+ args.conv_mode = conv_mode
65
+
66
+ conv = conv_templates[args.conv_mode].copy()
67
+ roles = conv.roles
68
+
69
+ image = load_image(args.image_file)
70
+ image_size = image.size
71
+ # Similar operation in model_worker.py
72
+ image_tensor = process_images([image], image_processor, model.config)
73
+ if type(image_tensor) is list:
74
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
75
+ else:
76
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
77
+
78
+ while True:
79
+ try:
80
+ inp = input(f"{roles[0]}: ")
81
+ except EOFError:
82
+ inp = ""
83
+ if not inp:
84
+ print("exit...")
85
+ break
86
+
87
+ print(f"{roles[1]}: ", end="")
88
+
89
+ if image is not None:
90
+ # first message
91
+ if model.config.mm_use_im_start_end:
92
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
93
+ else:
94
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
95
+ conv.append_message(conv.roles[0], inp)
96
+ image = None
97
+ else:
98
+ # later messages
99
+ conv.append_message(conv.roles[0], inp)
100
+ conv.append_message(conv.roles[1], None)
101
+ prompt = conv.get_prompt()
102
+
103
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
104
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
105
+ keywords = [stop_str]
106
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
107
+
108
+ with torch.inference_mode():
109
+ output_ids = model.generate(
110
+ input_ids,
111
+ images=image_tensor,
112
+ image_sizes=[image_size],
113
+ do_sample=True if args.temperature > 0 else False,
114
+ temperature=args.temperature,
115
+ max_new_tokens=args.max_new_tokens,
116
+ streamer=streamer,
117
+ use_cache=True)
118
+
119
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
120
+ conv.messages[-1][-1] = outputs
121
+
122
+ if args.debug:
123
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ parser = argparse.ArgumentParser()
128
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
129
+ parser.add_argument("--model-base", type=str, default=None)
130
+ parser.add_argument("--image-file", type=str, required=True)
131
+ parser.add_argument("--device", type=str, default="cuda")
132
+ parser.add_argument("--conv-mode", type=str, default=None)
133
+ parser.add_argument("--temperature", type=float, default=0.2)
134
+ parser.add_argument("--max-new-tokens", type=int, default=512)
135
+ parser.add_argument("--load-8bit", action="store_true")
136
+ parser.add_argument("--load-4bit", action="store_true")
137
+ parser.add_argument("--debug", action="store_true")
138
+ args = parser.parse_args()
139
+ main(args)
controller.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from videollama2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from videollama2.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,), daemon=True)
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stable_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
193
+ def worker_api_generate_stream(self, params):
194
+ worker_addr = self.get_worker_address(params["model"])
195
+ if not worker_addr:
196
+ logger.info(f"no worker: {params['model']}")
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 2,
200
+ }
201
+ yield json.dumps(ret).encode() + b"\0"
202
+
203
+ try:
204
+ response = requests.post(worker_addr + "/worker_generate_stream",
205
+ json=params, stream=True, timeout=5)
206
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
+ if chunk:
208
+ yield chunk + b"\0"
209
+ except requests.exceptions.RequestException as e:
210
+ logger.info(f"worker timeout: {worker_addr}")
211
+ ret = {
212
+ "text": server_error_msg,
213
+ "error_code": 3,
214
+ }
215
+ yield json.dumps(ret).encode() + b"\0"
216
+
217
+
218
+ # Let the controller act as a worker to achieve hierarchical
219
+ # management. This can be used to connect isolated sub networks.
220
+ def worker_api_get_status(self):
221
+ model_names = set()
222
+ speed = 0
223
+ queue_length = 0
224
+
225
+ for w_name in self.worker_info:
226
+ worker_status = self.get_worker_status(w_name)
227
+ if worker_status is not None:
228
+ model_names.update(worker_status["model_names"])
229
+ speed += worker_status["speed"]
230
+ queue_length += worker_status["queue_length"]
231
+
232
+ return {
233
+ "model_names": list(model_names),
234
+ "speed": speed,
235
+ "queue_length": queue_length,
236
+ }
237
+
238
+
239
+ app = FastAPI()
240
+
241
+
242
+ @app.post("/register_worker")
243
+ async def register_worker(request: Request):
244
+ data = await request.json()
245
+ controller.register_worker(
246
+ data["worker_name"], data["check_heart_beat"],
247
+ data.get("worker_status", None))
248
+
249
+
250
+ @app.post("/refresh_all_workers")
251
+ async def refresh_all_workers():
252
+ models = controller.refresh_all_workers()
253
+
254
+
255
+ @app.post("/list_models")
256
+ async def list_models():
257
+ models = controller.list_models()
258
+ return {"models": models}
259
+
260
+
261
+ @app.post("/get_worker_address")
262
+ async def get_worker_address(request: Request):
263
+ data = await request.json()
264
+ addr = controller.get_worker_address(data["model"])
265
+ return {"address": addr}
266
+
267
+
268
+ @app.post("/receive_heart_beat")
269
+ async def receive_heart_beat(request: Request):
270
+ data = await request.json()
271
+ exist = controller.receive_heart_beat(
272
+ data["worker_name"], data["queue_length"])
273
+ return {"exist": exist}
274
+
275
+
276
+ @app.post("/worker_generate_stream")
277
+ async def worker_api_generate_stream(request: Request):
278
+ params = await request.json()
279
+ generator = controller.worker_api_generate_stream(params)
280
+ return StreamingResponse(generator)
281
+
282
+
283
+ @app.post("/worker_get_status")
284
+ async def worker_api_get_status(request: Request):
285
+ return controller.worker_api_get_status()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = argparse.ArgumentParser()
290
+ parser.add_argument("--host", type=str, default="localhost")
291
+ parser.add_argument("--port", type=int, default=21001)
292
+ parser.add_argument("--dispatch-method", type=str, choices=[
293
+ "lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
examples/1034346401.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08b62a634fe49edc0a19fc53f6ea5cfb345d9b2a6a7047811344c16832dc42b2
3
+ size 1678095
examples/desert.jpg ADDED
examples/extreme_ironing.jpg ADDED
examples/sample_demo_1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc6562a172eb9cb3c760a3c9992349c1faa2c793c112b7b9e50bd5cb17c2164d
3
+ size 1549315
examples/sample_demo_3.mp4 ADDED
Binary file (464 kB). View file
 
examples/sample_demo_9.mp4 ADDED
Binary file (632 kB). View file
 
examples/waterview.jpg ADDED
gradio_web_server.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import hashlib
5
+ import requests
6
+ import argparse
7
+ import datetime
8
+
9
+ import numpy as np
10
+ import gradio as gr
11
+ from decord import VideoReader, cpu
12
+
13
+ from videollama2.constants import LOGDIR, NUM_FRAMES
14
+ from videollama2.conversation import (default_conversation, conv_templates,SeparatorStyle)
15
+ from videollama2.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
16
+
17
+
18
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
+
20
+ headers = {"User-Agent": "Videollama2 Client"}
21
+
22
+ no_change_btn = gr.Button.update()
23
+ enable_btn = gr.Button.update(interactive=True)
24
+ disable_btn = gr.Button.update(interactive=False)
25
+
26
+ priority = {
27
+ "vicuna-13b": "aaaaaaa",
28
+ "koala-13b": "aaaaaab",
29
+ }
30
+
31
+
32
+ def get_conv_log_filename():
33
+ t = datetime.datetime.now()
34
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
+ return name
36
+
37
+
38
+ def get_model_list():
39
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
40
+ assert ret.status_code == 200
41
+ ret = requests.post(args.controller_url + "/list_models")
42
+ models = ret.json()["models"]
43
+ models.sort(key=lambda x: priority.get(x, x))
44
+ logger.info(f"Models: {models}")
45
+ return models
46
+
47
+
48
+ get_window_url_params = """
49
+ function() {
50
+ const params = new URLSearchParams(window.location.search);
51
+ url_params = Object.fromEntries(params);
52
+ console.log(url_params);
53
+ return url_params;
54
+ }
55
+ """
56
+
57
+
58
+ def load_demo(url_params, request: gr.Request):
59
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
+
61
+ dropdown_update = gr.Dropdown.update(visible=True)
62
+ if "model" in url_params:
63
+ model = url_params["model"]
64
+ if model in models:
65
+ dropdown_update = gr.Dropdown.update(
66
+ value=model, visible=True)
67
+
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+
72
+ def load_demo_refresh_model_list(request: gr.Request):
73
+ logger.info(f"load_demo. ip: {request.client.host}")
74
+ models = get_model_list()
75
+ state = default_conversation.copy()
76
+ dropdown_update = gr.Dropdown.update(
77
+ choices=models,
78
+ value=models[0] if len(models) > 0 else ""
79
+ )
80
+ return state, dropdown_update
81
+
82
+
83
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
+ with open(get_conv_log_filename(), "a") as fout:
85
+ data = {
86
+ "tstamp": round(time.time(), 4),
87
+ "type": vote_type,
88
+ "model": model_selector,
89
+ "state": state.dict(),
90
+ "ip": request.client.host,
91
+ }
92
+ fout.write(json.dumps(data) + "\n")
93
+
94
+
95
+ def upvote_last_response(state, model_selector, request: gr.Request):
96
+ logger.info(f"upvote. ip: {request.client.host}")
97
+ vote_last_response(state, "upvote", model_selector, request)
98
+ return ("",) + (disable_btn,) * 3
99
+
100
+
101
+ def downvote_last_response(state, model_selector, request: gr.Request):
102
+ logger.info(f"downvote. ip: {request.client.host}")
103
+ vote_last_response(state, "downvote", model_selector, request)
104
+ return ("",) + (disable_btn,) * 3
105
+
106
+
107
+ def flag_last_response(state, model_selector, request: gr.Request):
108
+ logger.info(f"flag. ip: {request.client.host}")
109
+ vote_last_response(state, "flag", model_selector, request)
110
+ return ("",) + (disable_btn,) * 3
111
+
112
+
113
+ def regenerate(state, image_process_mode, request: gr.Request):
114
+ logger.info(f"regenerate. ip: {request.client.host}")
115
+ state.messages[-1][-1] = None
116
+ prev_human_msg = state.messages[-2]
117
+ if type(prev_human_msg[1]) in (tuple, list):
118
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
+ state.skip_next = False
120
+ # (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
121
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
122
+
123
+
124
+ def clear_history(request: gr.Request):
125
+ logger.info(f"clear_history. ip: {request.client.host}")
126
+ state = default_conversation.copy()
127
+ # (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
128
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
129
+
130
+
131
+ def add_text_ori(state, text, image, video, image_process_mode, request: gr.Request):
132
+ # note: imagebox itself is PIL object while videobox is filepath
133
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
134
+ if len(text) <= 0 and image is None:
135
+ state.skip_next = True
136
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
137
+ if args.moderate:
138
+ flagged = violates_moderation(text)
139
+ if flagged:
140
+ state.skip_next = True
141
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
142
+ no_change_btn,) * 5
143
+ assert image is None or video is None, "Please don't feed image and video inputs at the same time!!!"
144
+ text = text[:1536] # Hard cut-off
145
+ if image is not None:
146
+ # here image is the PIL object itself
147
+ text = text[:1200] # Hard cut-off for images
148
+ if '<image>' not in text:
149
+ # text = '<Image><image></Image>' + text
150
+ text = text + '\n<image>'
151
+ text = (text, image, image_process_mode)
152
+ if len(state.get_images(return_pil=True)) > 0:
153
+ state = default_conversation.copy()
154
+ state.modality = "image"
155
+ if video is not None:
156
+ print("Video box:", video)
157
+ # here video is the file path of video
158
+ text = text[:1200] # Hard cut-off for images
159
+ if '<video>' not in text:
160
+ # text = '<Image><image></Image>' + text
161
+ text = text + '\n<video>'
162
+ text = (text, video, image_process_mode)
163
+ if len(state.get_videos(return_pil=True)) > 0:
164
+ state = default_conversation.copy()
165
+ state.modality = "video"
166
+ print("Set modality as video...")
167
+ state.append_message(state.roles[0], text)
168
+ state.append_message(state.roles[1], None)
169
+ state.skip_next = False
170
+ # (state, chatbot, textbox, imagebox, videobox, upvote, downvote, flag, generate, clear)
171
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
172
+
173
+
174
+ def add_text(state, text, image, video, image_process_mode, request: gr.Request):
175
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
176
+
177
+ # if input is new video or image ,reset the state
178
+ if image is not None or video is not None:
179
+ state = default_conversation.copy()
180
+
181
+ if len(text) <= 0 and image is None and video is None:
182
+ state.skip_next = True
183
+ return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
184
+
185
+ if args.moderate:
186
+ flagged = violates_moderation(text)
187
+ if flagged:
188
+ state.skip_next = True
189
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
190
+
191
+ # process the input video
192
+ if video is not None:
193
+ text = text[:1200] #
194
+ if '<video>' not in text:
195
+ text = text + '\n<video>'
196
+ text = (text, video, image_process_mode)
197
+ state.modality = "video"
198
+ # process the input image
199
+ elif image is not None:
200
+ text = text[:1200] #
201
+ if '<image>' not in text:
202
+ text = text + '\n<image>'
203
+ text = (text, image, image_process_mode)
204
+ state.modality = "image"
205
+ elif state.modality == "image" and len(text)>0:
206
+ state.modality = "image_text"
207
+ text = text[:1536] # Hard cut-off
208
+ elif state.modality == "video" and len(text)>0:
209
+ state.modality = "video_text"
210
+ text = text[:1536] # Hard cut-off
211
+
212
+ state.append_message(state.roles[0], text)
213
+ state.append_message(state.roles[1], None)
214
+ state.skip_next = False
215
+
216
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
217
+
218
+
219
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
220
+ logger.info(f"http_bot. ip: {request.client.host}")
221
+ start_tstamp = time.time()
222
+ model_name = model_selector
223
+
224
+ if state.skip_next:
225
+ # This generate call is skipped due to invalid inputs
226
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
227
+ return
228
+
229
+ if len(state.messages) == state.offset + 2:
230
+ # First round of conversation
231
+ if "llava" in model_name.lower():
232
+ if 'llama-2' in model_name.lower():
233
+ template_name = "llava_llama2"
234
+ elif "v1" in model_name.lower():
235
+ if 'mmtag' in model_name.lower():
236
+ template_name = "v1_mmtag"
237
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
238
+ template_name = "v1_mmtag"
239
+ else:
240
+ template_name = "llava_v1"
241
+ else:
242
+ if 'mmtag' in model_name.lower():
243
+ template_name = "v0_mmtag"
244
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
245
+ template_name = "v0_mmtag"
246
+ else:
247
+ template_name = "llava_v0"
248
+ elif "llama-2" in model_name:
249
+ template_name = "llama2"
250
+ else:
251
+ template_name = "vicuna_v1"
252
+ template_name = "llava_v1"
253
+ new_state = conv_templates[template_name].copy()
254
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
255
+ new_state.append_message(new_state.roles[1], None)
256
+ new_state.modality = state.modality
257
+ state = new_state
258
+
259
+ # Query worker address
260
+ controller_url = args.controller_url
261
+ ret = requests.post(controller_url + "/get_worker_address",
262
+ json={"model": model_name})
263
+ worker_addr = ret.json()["address"]
264
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
265
+
266
+ # No available worker
267
+ if worker_addr == "":
268
+ state.messages[-1][-1] = server_error_msg
269
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
270
+ return
271
+
272
+ # Construct prompt
273
+ prompt = state.get_prompt()
274
+ if state.modality == "image" or state.modality == "image_text":
275
+ all_images = state.get_images(return_pil=True) # return PIL.Image object
276
+ elif state.modality == "video" or state.modality == "video_text":
277
+ all_images = state.get_videos(return_pil=True) # return video frames where each frame is a PIL.Image object
278
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
279
+ for idx, (image, hash) in enumerate(zip(all_images, all_image_hash)):
280
+ t = datetime.datetime.now()
281
+ if state.modality == "image" or state.modality == "image_text":
282
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
283
+ elif state.modality == "video" or state.modality == "video_text":
284
+ filename = os.path.join(LOGDIR, "serve_videos", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}_{idx}.jpg")
285
+ if not os.path.isfile(filename):
286
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
287
+ image.save(filename)
288
+
289
+ # Make requests
290
+ pload = {
291
+ "model": model_name,
292
+ "prompt": prompt,
293
+ "temperature": float(temperature),
294
+ "top_p": float(top_p),
295
+ "max_new_tokens": min(int(max_new_tokens), 1536),
296
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE] else state.sep2,
297
+ #"images": f'List of {len(state.get_images())} images: {all_image_hash}',
298
+ "images": f'List of {len(all_image_hash)} images: {all_image_hash}',
299
+ }
300
+ logger.info(f"==== request ====\n{pload}")
301
+
302
+ if state.modality == "image" or state.modality == "image_text":
303
+ pload['images'] = state.get_images()
304
+ elif state.modality == "video" or state.modality == "video_text":
305
+ pload['images'] = state.get_videos()
306
+
307
+ state.messages[-1][-1] = "▌"
308
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
309
+
310
+ try:
311
+ # Stream output
312
+ response = requests.post(worker_addr + "/worker_generate_stream",
313
+ headers=headers, json=pload, stream=True, timeout=10)
314
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
315
+ if chunk:
316
+ data = json.loads(chunk.decode())
317
+ if data["error_code"] == 0:
318
+ output = data["text"][len(prompt):].strip()
319
+ state.messages[-1][-1] = output + "▌"
320
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
321
+ else:
322
+ output = data["text"] + f" (error_code: {data['error_code']})"
323
+ state.messages[-1][-1] = output
324
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
325
+ return
326
+ time.sleep(0.03)
327
+ except requests.exceptions.RequestException as e:
328
+ state.messages[-1][-1] = server_error_msg
329
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
330
+ return
331
+
332
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
333
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
334
+
335
+ finish_tstamp = time.time()
336
+ logger.info(f"{output}")
337
+
338
+ with open(get_conv_log_filename(), "a") as fout:
339
+ data = {
340
+ "tstamp": round(finish_tstamp, 4),
341
+ "type": "chat",
342
+ "model": model_name,
343
+ "start": round(start_tstamp, 4),
344
+ "finish": round(start_tstamp, 4),
345
+ #"state": state.dict(),
346
+ "images": all_image_hash,
347
+ "ip": request.client.host,
348
+ }
349
+ fout.write(json.dumps(data) + "\n")
350
+
351
+ title_markdown = ("""
352
+ # The publicl release of VideoLLaMA2
353
+ """)
354
+
355
+ tos_markdown = ("""
356
+ ### Terms of use
357
+ By using this service, users are required to agree to the following terms:
358
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
359
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
360
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
361
+ """)
362
+
363
+
364
+ learn_more_markdown = ("""
365
+ ### License
366
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
367
+ """)
368
+
369
+ block_css = """
370
+
371
+ #buttons button {
372
+ min-width: min(120px,100%);
373
+ }
374
+
375
+ """
376
+
377
+ def build_demo(embed_mode):
378
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
379
+ with gr.Blocks(title="Video-Llama", theme=gr.themes.Default(), css=block_css) as demo:
380
+ state = gr.State()
381
+
382
+ if not embed_mode:
383
+ gr.Markdown(title_markdown)
384
+
385
+ with gr.Row():
386
+ with gr.Column(scale=3):
387
+ with gr.Row(elem_id="model_selector_row"):
388
+ model_selector = gr.Dropdown(
389
+ choices=models,
390
+ value=models[0] if len(models) > 0 else "",
391
+ interactive=True,
392
+ show_label=False,
393
+ container=False)
394
+
395
+ imagebox = gr.Image(type="pil")
396
+ videobox = gr.Video()
397
+ image_process_mode = gr.Radio(
398
+ ["Crop", "Resize", "Pad", "Default"],
399
+ value="Default",
400
+ label="Preprocess for non-square image", visible=False)
401
+
402
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
403
+ gr.Examples(examples=[
404
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
405
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
406
+ [f"{cur_dir}/examples/desert.jpg", "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?"],
407
+ ], inputs=[imagebox, textbox], label="Image examples")
408
+
409
+ # video example inputs
410
+ gr.Examples(examples=[
411
+ [f"{cur_dir}/examples/sample_demo_1.mp4", "Why is this video funny?"],
412
+ [f"{cur_dir}/examples/sample_demo_3.mp4", "Can you identify any safety hazards in this video?"],
413
+ [f"{cur_dir}/examples/1034346401.mp4", "What is this young woman doing?"]
414
+ ], inputs=[videobox, textbox], label="Video examples")
415
+ #[f"{cur_dir}/examples/sample_demo_9.mp4", "Describe the video in detail and please do not generate repetitive content."]
416
+
417
+ with gr.Accordion("Parameters", open=False) as parameter_row:
418
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
419
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
420
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
421
+
422
+ with gr.Column(scale=8):
423
+ chatbot = gr.Chatbot(elem_id="chatbot", label="Videollama2 Chatbot", height=550)
424
+ with gr.Row():
425
+ with gr.Column(scale=8):
426
+ textbox.render()
427
+ with gr.Column(scale=1, min_width=50):
428
+ submit_btn = gr.Button(value="Send", variant="primary")
429
+ with gr.Row(elem_id="buttons") as button_row:
430
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
431
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
432
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
433
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
434
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
435
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
436
+
437
+ if not embed_mode:
438
+ gr.Markdown(tos_markdown)
439
+ gr.Markdown(learn_more_markdown)
440
+ url_params = gr.JSON(visible=False)
441
+
442
+ # Register listeners
443
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
444
+ upvote_btn.click(upvote_last_response,
445
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
446
+ downvote_btn.click(downvote_last_response,
447
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
448
+ flag_btn.click(flag_last_response,
449
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
450
+ regenerate_btn.click(regenerate, [state, image_process_mode],
451
+ [state, chatbot, textbox, imagebox, videobox] + btn_list).then(
452
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
453
+ [state, chatbot] + btn_list)
454
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, videobox] + btn_list)
455
+
456
+ textbox.submit(add_text, [state, textbox, imagebox, videobox, image_process_mode], [state, chatbot, textbox, imagebox, videobox] + btn_list
457
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
458
+ [state, chatbot] + btn_list)
459
+ submit_btn.click(add_text, [state, textbox, imagebox, videobox, image_process_mode], [state, chatbot, textbox, imagebox, videobox] + btn_list
460
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
461
+ [state, chatbot] + btn_list)
462
+
463
+ if args.model_list_mode == "once":
464
+ demo.load(load_demo, [url_params], [state, model_selector],
465
+ _js=get_window_url_params)
466
+ elif args.model_list_mode == "reload":
467
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector])
468
+ else:
469
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
470
+
471
+ return demo
472
+
473
+
474
+ if __name__ == "__main__":
475
+ parser = argparse.ArgumentParser()
476
+ parser.add_argument("--host", type=str, default="0.0.0.0")
477
+ parser.add_argument("--port", type=int)
478
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
479
+ parser.add_argument("--concurrency-count", type=int, default=10)
480
+ parser.add_argument("--model-list-mode", type=str, default="once",
481
+ choices=["once", "reload"])
482
+ parser.add_argument("--share", action="store_true")
483
+ parser.add_argument("--moderate", action="store_true")
484
+ parser.add_argument("--embed", action="store_true")
485
+ args = parser.parse_args()
486
+ logger.info(f"args: {args}")
487
+
488
+ models = get_model_list()
489
+
490
+ logger.info(args)
491
+ demo = build_demo(args.embed)
492
+ demo.queue(
493
+ concurrency_count=args.concurrency_count,
494
+ api_open=False
495
+ ).launch(
496
+ server_name=args.host,
497
+ server_port=args.port,
498
+ share=args.share
499
+ )
gradio_web_server_adhoc.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+
3
+ import os
4
+ import re
5
+
6
+ import torch
7
+ import gradio as gr
8
+
9
+ import sys
10
+ sys.path.append('./')
11
+ from videollama2 import model_init, mm_infer
12
+ from videollama2.utils import disable_torch_init
13
+
14
+
15
+ title_markdown = ("""
16
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
17
+ <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
18
+ <img src="https://s2.loli.net/2024/06/03/D3NeXHWy5az9tmT.png" alt="VideoLLaMA 2 🔥🚀🔥" style="max-width: 120px; height: auto;">
19
+ </a>
20
+ <div>
21
+ <h1 >VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs</h1>
22
+ <h5 style="margin: 0;">If this demo please you, please give us a star ⭐ on Github or 💖 on this space.</h5>
23
+ </div>
24
+ </div>
25
+
26
+
27
+ <div align="center">
28
+ <div style="display:flex; gap: 0.25rem; margin-top: 10px;" align="center">
29
+ <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2"><img src='https://img.shields.io/badge/Github-VideoLLaMA2-9C276A'></a>
30
+ <a href="https://arxiv.org/pdf/2406.07476.pdf"><img src="https://img.shields.io/badge/Arxiv-2406.07476-AD1C18"></a>
31
+ <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA2/stargazers"><img src="https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA2.svg?style=social"></a>
32
+ </div>
33
+ </div>
34
+ """)
35
+
36
+
37
+ block_css = """
38
+ #buttons button {
39
+ min-width: min(120px,100%);
40
+ color: #9C276A
41
+ }
42
+ """
43
+
44
+
45
+ tos_markdown = ("""
46
+ ### Terms of use
47
+ By using this service, users are required to agree to the following terms:
48
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
49
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
50
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
51
+ """)
52
+
53
+
54
+ learn_more_markdown = ("""
55
+ ### License
56
+ This project is released under the Apache 2.0 license as found in the LICENSE file. The service is a research preview intended for non-commercial use ONLY, subject to the model Licenses of LLaMA and Mistral, Terms of Use of the data generated by OpenAI, and Privacy Practices of ShareGPT. Please get in touch with us if you find any potential violations.
57
+ """)
58
+
59
+
60
+ plum_color = gr.themes.colors.Color(
61
+ name='plum',
62
+ c50='#F8E4EF',
63
+ c100='#E9D0DE',
64
+ c200='#DABCCD',
65
+ c300='#CBA8BC',
66
+ c400='#BC94AB',
67
+ c500='#AD809A',
68
+ c600='#9E6C89',
69
+ c700='#8F5878',
70
+ c800='#804467',
71
+ c900='#713056',
72
+ c950='#662647',
73
+ )
74
+
75
+
76
+ class Chat:
77
+
78
+ def __init__(self, model_path, load_8bit=False, load_4bit=False):
79
+ disable_torch_init()
80
+
81
+ self.model, self.processor, self.tokenizer = model_init(model_path, load_8bit=load_8bit, load_4bit=load_4bit)
82
+
83
+ @spaces.GPU(duration=120)
84
+ @torch.inference_mode()
85
+ def generate(self, data: list, message, temperature, top_p, max_output_tokens):
86
+ # TODO: support multiple turns of conversation.
87
+ assert len(data) == 1
88
+
89
+ tensor, modal = data[0]
90
+ response = mm_infer(tensor, message, self.model, self.tokenizer, modal=modal.strip('<>'),
91
+ do_sample=True if temperature > 0.0 else False,
92
+ temperature=temperature,
93
+ top_p=top_p,
94
+ max_new_tokens=max_output_tokens)
95
+
96
+ return response
97
+
98
+
99
+ @spaces.GPU(duration=120)
100
+ def generate(image, video, message, chatbot, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
101
+ data = []
102
+
103
+ processor = handler.processor
104
+ try:
105
+ if image is not None:
106
+ data.append((processor['image'](image).to(handler.model.device, dtype=dtype), '<image>'))
107
+ elif video is not None:
108
+ data.append((processor['video'](video).to(handler.model.device, dtype=dtype), '<video>'))
109
+ elif image is None and video is None:
110
+ data.append((None, '<text>'))
111
+ else:
112
+ raise NotImplementedError("Not support image and video at the same time")
113
+ except Exception as e:
114
+ traceback.print_exc()
115
+ return gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), message, chatbot
116
+
117
+ assert len(message) % 2 == 0, "The message should be a pair of user and system message."
118
+
119
+ show_images = ""
120
+ if image is not None:
121
+ show_images += f'<img src="./file={image}" style="display: inline-block;width: 250px;max-height: 400px;">'
122
+ if video is not None:
123
+ show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
124
+
125
+ one_turn_chat = [textbox_in, None]
126
+
127
+ # 1. first run case
128
+ if len(chatbot) == 0:
129
+ one_turn_chat[0] += "\n" + show_images
130
+ # 2. not first run case
131
+ else:
132
+ # scanning the last image or video
133
+ length = len(chatbot)
134
+ for i in range(length - 1, -1, -1):
135
+ previous_image = re.findall(r'<img src="./file=(.+?)"', chatbot[i][0])
136
+ previous_video = re.findall(r'<video controls playsinline width="500" style="display: inline-block;" src="./file=(.+?)"', chatbot[i][0])
137
+
138
+ if len(previous_image) > 0:
139
+ previous_image = previous_image[-1]
140
+ # 2.1 new image append or pure text input will start a new conversation
141
+ if (video is not None) or (image is not None and os.path.basename(previous_image) != os.path.basename(image)):
142
+ message.clear()
143
+ one_turn_chat[0] += "\n" + show_images
144
+ break
145
+ elif len(previous_video) > 0:
146
+ previous_video = previous_video[-1]
147
+ # 2.2 new video append or pure text input will start a new conversation
148
+ if image is not None or (video is not None and os.path.basename(previous_video) != os.path.basename(video)):
149
+ message.clear()
150
+ one_turn_chat[0] += "\n" + show_images
151
+ break
152
+
153
+ message.append({'role': 'user', 'content': textbox_in})
154
+ text_en_out = handler.generate(data, message, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens)
155
+ message.append({'role': 'assistant', 'content': text_en_out})
156
+
157
+ one_turn_chat[1] = text_en_out
158
+ chatbot.append(one_turn_chat)
159
+
160
+ return gr.update(value=image, interactive=True), gr.update(value=video, interactive=True), message, chatbot
161
+
162
+
163
+ def regenerate(message, chatbot):
164
+ message.pop(-1), message.pop(-1)
165
+ chatbot.pop(-1)
166
+ return message, chatbot
167
+
168
+
169
+ def clear_history(message, chatbot):
170
+ message.clear(), chatbot.clear()
171
+ return (gr.update(value=None, interactive=True),
172
+ gr.update(value=None, interactive=True),
173
+ message, chatbot,
174
+ gr.update(value=None, interactive=True))
175
+
176
+
177
+ # BUG of Zero Environment
178
+ # 1. The environment is fixed to torch>=2.0,<=2.2, gradio>=4.x.x
179
+ # 2. The operation or tensor which requires cuda are limited in those functions wrapped via spaces.GPU
180
+ # 3. The function can't return tensor or other cuda objects.
181
+
182
+ model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B-16F'
183
+
184
+ handler = Chat(model_path, load_8bit=False, load_4bit=True)
185
+
186
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
187
+
188
+ theme = gr.themes.Default(primary_hue=plum_color)
189
+ # theme.update_color("primary", plum_color.c500)
190
+ theme.set(slider_color="#9C276A")
191
+ theme.set(block_title_text_color="#9C276A")
192
+ theme.set(block_label_text_color="#9C276A")
193
+ theme.set(button_primary_text_color="#9C276A")
194
+ # theme.set(button_secondary_text_color="*neutral_800")
195
+
196
+
197
+ with gr.Blocks(title='VideoLLaMA 2 🔥🚀🔥', theme=theme, css=block_css) as demo:
198
+ gr.Markdown(title_markdown)
199
+ message = gr.State([])
200
+
201
+ with gr.Row():
202
+ with gr.Column(scale=3):
203
+ image = gr.Image(label="Input Image", type="filepath")
204
+ video = gr.Video(label="Input Video")
205
+
206
+ with gr.Accordion("Parameters", open=True) as parameter_row:
207
+ # num_beams = gr.Slider(
208
+ # minimum=1,
209
+ # maximum=10,
210
+ # value=1,
211
+ # step=1,
212
+ # interactive=True,
213
+ # label="beam search numbers",
214
+ # )
215
+
216
+ temperature = gr.Slider(
217
+ minimum=0.1,
218
+ maximum=1.0,
219
+ value=0.2,
220
+ step=0.1,
221
+ interactive=True,
222
+ label="Temperature",
223
+ )
224
+
225
+ top_p = gr.Slider(
226
+ minimum=0.0,
227
+ maximum=1.0,
228
+ value=0.7,
229
+ step=0.1,
230
+ interactive=True,
231
+ label="Top P",
232
+ )
233
+
234
+ max_output_tokens = gr.Slider(
235
+ minimum=64,
236
+ maximum=1024,
237
+ value=512,
238
+ step=64,
239
+ interactive=True,
240
+ label="Max output tokens",
241
+ )
242
+
243
+ with gr.Column(scale=7):
244
+ chatbot = gr.Chatbot(label="VideoLLaMA 2", bubble_full_width=True, height=750)
245
+ with gr.Row():
246
+ with gr.Column(scale=8):
247
+ textbox.render()
248
+ with gr.Column(scale=1, min_width=50):
249
+ submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
250
+ with gr.Row(elem_id="buttons") as button_row:
251
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
252
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
253
+ # flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
254
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
255
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
256
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
257
+
258
+ with gr.Row():
259
+ with gr.Column():
260
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
261
+ gr.Examples(
262
+ examples=[
263
+ [
264
+ f"{cur_dir}/examples/extreme_ironing.jpg",
265
+ "What happens in this image?",
266
+ ],
267
+ [
268
+ f"{cur_dir}/examples/waterview.jpg",
269
+ "What are the things I should be cautious about when I visit here?",
270
+ ],
271
+ [
272
+ f"{cur_dir}/examples/desert.jpg",
273
+ "If there are factual errors in the questions, point it out; if not, proceed answering the question. What’s happening in the desert?",
274
+ ],
275
+ ],
276
+ inputs=[image, textbox],
277
+ )
278
+ with gr.Column():
279
+ gr.Examples(
280
+ examples=[
281
+ [
282
+ f"{cur_dir}/../../assets/cat_and_chicken.mp4",
283
+ "What happens in this video?",
284
+ ],
285
+ [
286
+ f"{cur_dir}/../../assets/sora.mp4",
287
+ "Please describe this video.",
288
+ ],
289
+ [
290
+ f"{cur_dir}/examples/sample_demo_1.mp4",
291
+ "What does the baby do?",
292
+ ],
293
+ ],
294
+ inputs=[video, textbox],
295
+ )
296
+
297
+ gr.Markdown(tos_markdown)
298
+ gr.Markdown(learn_more_markdown)
299
+
300
+ submit_btn.click(
301
+ generate,
302
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
303
+ [image, video, message, chatbot])
304
+
305
+ regenerate_btn.click(
306
+ regenerate,
307
+ [message, chatbot],
308
+ [message, chatbot]).then(
309
+ generate,
310
+ [image, video, message, chatbot, textbox, temperature, top_p, max_output_tokens],
311
+ [image, video, message, chatbot])
312
+
313
+ clear_btn.click(
314
+ clear_history,
315
+ [message, chatbot],
316
+ [image, video, message, chatbot, textbox])
317
+
318
+ demo.launch(share = True)
model_worker.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import os
5
+ import json
6
+ import time
7
+ import uuid
8
+ import asyncio
9
+ import requests
10
+ import argparse
11
+ import threading
12
+ from threading import Thread
13
+ from functools import partial
14
+ from typing import Iterator, List, Optional, Tuple
15
+
16
+ import uvicorn
17
+ from fastapi import FastAPI, Request, BackgroundTasks
18
+ from fastapi.responses import StreamingResponse
19
+
20
+ import torch
21
+ import decord
22
+ import numpy as np
23
+ from PIL import Image
24
+ from decord import VideoReader, cpu
25
+ from transformers import TextIteratorStreamer
26
+
27
+ from videollama2.constants import WORKER_HEART_BEAT_INTERVAL
28
+ from videollama2.utils import (build_logger, server_error_msg, pretty_print_semaphore)
29
+ from videollama2.model.builder import load_pretrained_model
30
+ from videollama2.mm_utils import process_images, process_videos, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria, tokenizer_MMODAL_token
31
+ from videollama2.mm_utils import chunk_list, frame_expansion
32
+ from videollama2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VIDEO_TOKEN, NUM_FRAMES, MMODAL_TOKEN_INDEX
33
+
34
+
35
+ GB = 1 << 30
36
+
37
+ worker_id = str(uuid.uuid4())[:6]
38
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
39
+ global_counter = 0
40
+
41
+ model_semaphore = None
42
+
43
+
44
+ # variable_content = os.getenv('MY_VARIABLE', '')
45
+ # KEYWORDS_LIST = set(variable_content.split('\n'))
46
+ KEYWORDS_LIST = []
47
+ path = 'assets/keywords.txt'
48
+ if os.path.exists(path):
49
+ with open(path, 'r', encoding='utf-8') as file:
50
+ for line in file:
51
+
52
+ KEYWORDS_LIST.append(line.strip())
53
+ else:
54
+ KEYWORDS_LIST = []
55
+
56
+
57
+ KEYWORD_BLOCK_MESSAGE2 = "The output contains political, erotic and other unsafe content that violates local laws. Please re-enter your question."
58
+ KEYWORD_BLOCK_MESSAGE1 = "Your input question contains political, erotic and other unsafe content that violates local laws. Please re-enter your question."
59
+ STREAM_CHECK_MULTIPLE = 20
60
+
61
+
62
+ def heart_beat_worker(controller):
63
+
64
+ while True:
65
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
66
+ controller.send_heart_beat()
67
+
68
+
69
+ def safety_check(text, history=None, ) -> Optional[str]:
70
+
71
+ if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
72
+ print('############')
73
+ return KEYWORD_BLOCK_MESSAGE2
74
+
75
+ return None
76
+
77
+
78
+ def input_safety_check(text) -> Optional[str]:
79
+ if len(KEYWORDS_LIST) > 0 and any(x in text.lower() for x in KEYWORDS_LIST):
80
+ print('######## Input keyword alarm triggered:', text)
81
+ return KEYWORD_BLOCK_MESSAGE1
82
+ return None
83
+
84
+
85
+ class ModelWorker:
86
+
87
+ def __init__(self, controller_addr, worker_addr,
88
+ worker_id, no_register,
89
+ model_path, model_base, model_name,
90
+ load_8bit, load_4bit, device):
91
+ self.controller_addr = controller_addr
92
+ self.worker_addr = worker_addr
93
+ self.worker_id = worker_id
94
+ self.model_path = model_path
95
+ if model_path.endswith("/"):
96
+ model_path = model_path[:-1]
97
+ if model_name is None:
98
+ model_paths = model_path.split("/")
99
+ if model_paths[-1].startswith('checkpoint-'):
100
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
101
+ else:
102
+ self.model_name = model_paths[-1]
103
+ else:
104
+ self.model_name = model_name
105
+
106
+ self.device = device
107
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
108
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
109
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
110
+ self.is_multimodal = 'videollama2' in self.model_name.lower() or 'vlb' in self.model_name.lower()
111
+
112
+ if not no_register:
113
+ self.register_to_controller()
114
+ self.heart_beat_thread = threading.Thread(
115
+ target=heart_beat_worker, args=(self,))
116
+ self.heart_beat_thread.start()
117
+
118
+ def register_to_controller(self):
119
+ logger.info("Register to controller")
120
+
121
+ url = self.controller_addr + "/register_worker"
122
+ data = {
123
+ "worker_name": self.worker_addr,
124
+ "check_heart_beat": True,
125
+ "worker_status": self.get_status()
126
+ }
127
+ r = requests.post(url, json=data)
128
+ assert r.status_code == 200
129
+
130
+ def send_heart_beat(self):
131
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
132
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
133
+ f"global_counter: {global_counter}")
134
+
135
+ url = self.controller_addr + "/receive_heart_beat"
136
+
137
+ while True:
138
+ try:
139
+ ret = requests.post(url, json={
140
+ "worker_name": self.worker_addr,
141
+ "queue_length": self.get_queue_length()}, timeout=5)
142
+ exist = ret.json()["exist"]
143
+ break
144
+ except requests.exceptions.RequestException as e:
145
+ logger.error(f"heart beat error: {e}")
146
+ time.sleep(5)
147
+
148
+ if not exist:
149
+ self.register_to_controller()
150
+
151
+ def get_queue_length(self):
152
+ if model_semaphore is None:
153
+ return 0
154
+ else:
155
+ return args.limit_model_concurrency - model_semaphore._value + (len(
156
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
157
+
158
+ def get_status(self):
159
+ return {
160
+ "model_names": [self.model_name],
161
+ "speed": 1,
162
+ "queue_length": self.get_queue_length(),
163
+ }
164
+
165
+ @torch.inference_mode()
166
+ def generate_stream(self, params):
167
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
168
+
169
+ prompt = params["prompt"]
170
+ ori_prompt = prompt
171
+ images_or_videos = params.get("images", None)
172
+ #print("Input images:", images_or_videos)
173
+ num_image_tokens = 0
174
+ modal_list = []
175
+ if images_or_videos is not None and len(images_or_videos) and self.is_multimodal:
176
+ if len(images_or_videos) > 0:
177
+ if len(images_or_videos) != prompt.count(DEFAULT_IMAGE_TOKEN) and len(images_or_videos) != (prompt.count(DEFAULT_VIDEO_TOKEN)):
178
+ raise ValueError("Number of images/videos does not match number of <image>/<video> tokens in prompt")
179
+
180
+ try:
181
+ print("Load image...")
182
+ images_or_videos = [load_image_from_base64(image) for image in images_or_videos]
183
+ images_or_videos = process_images(images_or_videos, image_processor, model.config)
184
+
185
+ modal_list = ["image"]
186
+ replace_token = DEFAULT_IMAGE_TOKEN
187
+ modal_token_index = MMODAL_TOKEN_INDEX["IMAGE"]
188
+ except:
189
+ print("Load video instead...")
190
+ decord_vr = VideoReader(uri=images_or_videos[0], ctx=cpu(0))
191
+ duration = len(decord_vr)
192
+ if not "use_taug" in self.model_path:
193
+ frame_id_list = np.linspace(0, duration-1, 8, dtype=int)
194
+ video_frames = decord_vr.get_batch(frame_id_list).asnumpy()
195
+ images_or_videos = process_videos(video_frames, image_processor, model.config)
196
+ else:
197
+ print("Temporal augmentation activated!!!")
198
+ frame_id_list = np.linspace(0, duration-1, 8 * 2 * 2, dtype=int)
199
+ video_data = decord_vr.get_batch(frame_id_list)
200
+ video_frames = [Image.fromarray(f) for f in video_data.asnumpy()]
201
+ chunked_video_frames = chunk_list(video_frames, 2*2)
202
+ expanded_video_frames = [frame_expansion(frame_list, 2) for frame_list in chunked_video_frames]
203
+ images_or_videos = process_videos(expanded_video_frames, image_processor, model.config)
204
+
205
+ # frame_id_list = np.linspace(0, duration-1, NUM_FRAMES, dtype=int)
206
+ # images_or_videos = decord_vr.get_batch(frame_id_list).asnumpy()
207
+ # images_or_videos = process_videos(images_or_videos, image_processor, model.config)
208
+ #print("images_or_videos.shape:", images_or_videos.shape)
209
+ modal_list = ["video"]
210
+ replace_token = DEFAULT_VIDEO_TOKEN
211
+ modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]
212
+
213
+ if type(images_or_videos) is list:
214
+ images_or_videos = [image.to(self.model.device, dtype=torch.float16) for image in images_or_videos]
215
+ else:
216
+ images_or_videos = images_or_videos.to(self.model.device, dtype=torch.float16)
217
+ if modal_list[0] == "video":
218
+ print("Video:", images_or_videos.shape)
219
+ images_or_videos = [images_or_videos]
220
+ else:
221
+ print("Image:", images_or_videos.shape)
222
+
223
+
224
+ #image_sizes = [image.size for image in images_or_videos]
225
+
226
+
227
+ # if len(images_or_videos) % NUM_FRAMES == 0:
228
+ # images_or_videos = process_images(images_or_videos, image_processor, model.config)
229
+ # #images_or_videos = [image.to(self.model.device, dtype=torch.float16) for image in images_or_videos]
230
+ # #modal_list = ["image"] * len(images_or_videos)
231
+ # images_or_videos = images_or_videos.to(self.model.device, dtype=torch.float16)
232
+ # modal_list = ["video"]
233
+ # replace_token = DEFAULT_VIDEO_TOKEN
234
+ # else:
235
+
236
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
237
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
238
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
239
+
240
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
241
+ else:
242
+ images = None
243
+ modal_list = []
244
+ image_args = {"images_or_videos": images_or_videos, "modal_list": modal_list}
245
+ else:
246
+ images = None
247
+ image_args = {}
248
+ print("image_args:", image_args)
249
+ temperature = float(params.get("temperature", 1.0))
250
+ top_p = float(params.get("top_p", 1.0))
251
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
252
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
253
+ stop_str = params.get("stop", None)
254
+ do_sample = True if temperature > 0.001 else False
255
+
256
+ #input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
257
+ # tokenizer for our video-llama beta
258
+ input_ids = tokenizer_MMODAL_token(prompt, tokenizer, modal_token_index, return_tensors='pt').unsqueeze(0).to(self.device)
259
+ #print("Current prompt:", prompt)
260
+ #print("input_ids.shape:", input_ids.shape)
261
+ keywords = [stop_str]
262
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
263
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
264
+
265
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
266
+
267
+ if max_new_tokens < 1:
268
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
269
+ return
270
+
271
+ thread = Thread(target=model.generate, kwargs=dict(
272
+ inputs=input_ids,
273
+ do_sample=do_sample,
274
+ temperature=temperature,
275
+ top_p=top_p,
276
+ max_new_tokens=max_new_tokens,
277
+ streamer=streamer,
278
+ stopping_criteria=[stopping_criteria],
279
+ use_cache=True,
280
+ **image_args
281
+ ))
282
+ thread.start()
283
+
284
+ generated_text = ori_prompt
285
+ token_count = 0
286
+ for new_text in streamer:
287
+ generated_text += new_text
288
+ token_count += len(tokenizer.encode(new_text))
289
+ if token_count >= STREAM_CHECK_MULTIPLE:
290
+ safety_message = safety_check(generated_text)
291
+ if safety_message:
292
+ print('####### Keyword alarm triggered:', generated_text)
293
+ yield json.dumps({"text": safety_message , "error_code": 1}).encode() + b"\0"
294
+ return
295
+ token_count = 0 #
296
+
297
+
298
+ if generated_text.endswith(stop_str):
299
+ generated_text = generated_text[:-len(stop_str)]
300
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
301
+
302
+ def generate_stream_gate(self, params):
303
+ try:
304
+ input_text = params.get("prompt", "")
305
+ safety_message = input_safety_check(input_text)
306
+ if safety_message:
307
+ yield json.dumps({"text": safety_message, "error_code": 1}).encode() + b"\0"
308
+ return
309
+
310
+ for x in self.generate_stream(params):
311
+ yield x
312
+ except ValueError as e:
313
+ print("Caught ValueError:", e)
314
+ ret = {
315
+ "text": server_error_msg,
316
+ "error_code": 1,
317
+ }
318
+ yield json.dumps(ret).encode() + b"\0"
319
+ except torch.cuda.CudaError as e:
320
+ print("Caught torch.cuda.CudaError:", e)
321
+ ret = {
322
+ "text": server_error_msg,
323
+ "error_code": 1,
324
+ }
325
+ yield json.dumps(ret).encode() + b"\0"
326
+ except Exception as e:
327
+ print("Caught Unknown Error", e)
328
+ ret = {
329
+ "text": server_error_msg,
330
+ "error_code": 1,
331
+ }
332
+ yield json.dumps(ret).encode() + b"\0"
333
+
334
+
335
+ app = FastAPI()
336
+
337
+
338
+ def release_model_semaphore(fn=None):
339
+ model_semaphore.release()
340
+ if fn is not None:
341
+ fn()
342
+
343
+
344
+ @app.post("/worker_generate_stream")
345
+ async def generate_stream(request: Request):
346
+ global model_semaphore, global_counter
347
+ global_counter += 1
348
+ params = await request.json()
349
+
350
+ if model_semaphore is None:
351
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
352
+ await model_semaphore.acquire()
353
+ worker.send_heart_beat()
354
+ generator = worker.generate_stream_gate(params)
355
+ background_tasks = BackgroundTasks()
356
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
357
+ return StreamingResponse(generator, background=background_tasks)
358
+
359
+
360
+ @app.post("/worker_get_status")
361
+ async def get_status(request: Request):
362
+ return worker.get_status()
363
+
364
+
365
+ if __name__ == "__main__":
366
+ parser = argparse.ArgumentParser()
367
+ parser.add_argument("--host", type=str, default="localhost")
368
+ parser.add_argument("--port", type=int, default=21002)
369
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
370
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
371
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
372
+ parser.add_argument("--model-base", type=str, default=None)
373
+ parser.add_argument("--model-name", type=str)
374
+ parser.add_argument("--device", type=str, default="cuda")
375
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
376
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
377
+ parser.add_argument("--stream-interval", type=int, default=1)
378
+ parser.add_argument("--no-register", action="store_true")
379
+ parser.add_argument("--load-8bit", action="store_true")
380
+ parser.add_argument("--load-4bit", action="store_true")
381
+ args = parser.parse_args()
382
+ logger.info(f"args: {args}")
383
+
384
+ if args.multi_modal:
385
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
386
+
387
+ worker = ModelWorker(args.controller_address,
388
+ args.worker_address,
389
+ worker_id,
390
+ args.no_register,
391
+ args.model_path,
392
+ args.model_base,
393
+ args.model_name,
394
+ args.load_8bit,
395
+ args.load_4bit,
396
+ args.device)
397
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
register_worker.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Manually register workers.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--controller-address", type=str)
15
+ parser.add_argument("--worker-name", type=str)
16
+ parser.add_argument("--check-heart-beat", action="store_true")
17
+ args = parser.parse_args()
18
+
19
+ url = args.controller_address + "/register_worker"
20
+ data = {
21
+ "worker_name": args.worker_name,
22
+ "check_heart_beat": args.check_heart_beat,
23
+ "worker_status": None,
24
+ }
25
+ r = requests.post(url, json=data)
26
+ assert r.status_code == 200
sglang_worker.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ import json
8
+ import time
9
+ import threading
10
+ import uuid
11
+
12
+ from fastapi import FastAPI, Request, BackgroundTasks
13
+ from fastapi.responses import StreamingResponse
14
+ import requests
15
+ import re
16
+ import uvicorn
17
+ from functools import partial
18
+
19
+ from llava.constants import WORKER_HEART_BEAT_INTERVAL
20
+ from llava.utils import (build_logger, server_error_msg,
21
+ pretty_print_semaphore)
22
+ from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
23
+ from llava.constants import DEFAULT_IMAGE_TOKEN
24
+
25
+ import sglang as sgl
26
+ from sglang.backend.runtime_endpoint import RuntimeEndpoint
27
+
28
+
29
+ GB = 1 << 30
30
+
31
+ worker_id = str(uuid.uuid4())[:6]
32
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
33
+ global_counter = 0
34
+
35
+ model_semaphore = None
36
+
37
+
38
+ def heart_beat_worker(controller):
39
+ while True:
40
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
+ controller.send_heart_beat()
42
+
43
+
44
+ @sgl.function
45
+ def pipeline(s, prompt, max_tokens):
46
+ for p in prompt:
47
+ if type(p) is str:
48
+ s += p
49
+ else:
50
+ s += sgl.image(p)
51
+ s += sgl.gen("response", max_tokens=max_tokens)
52
+
53
+
54
+ class ModelWorker:
55
+ def __init__(self, controller_addr, worker_addr, sgl_endpoint,
56
+ worker_id, no_register, model_name):
57
+ self.controller_addr = controller_addr
58
+ self.worker_addr = worker_addr
59
+ self.worker_id = worker_id
60
+
61
+ # Select backend
62
+ backend = RuntimeEndpoint(sgl_endpoint)
63
+ sgl.set_default_backend(backend)
64
+ model_path = backend.model_info["model_path"]
65
+
66
+ if model_path.endswith("/"):
67
+ model_path = model_path[:-1]
68
+ if model_name is None:
69
+ model_paths = model_path.split("/")
70
+ if model_paths[-1].startswith('checkpoint-'):
71
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
72
+ else:
73
+ self.model_name = model_paths[-1]
74
+ else:
75
+ self.model_name = model_name
76
+
77
+ logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
78
+
79
+ if not no_register:
80
+ self.register_to_controller()
81
+ self.heart_beat_thread = threading.Thread(
82
+ target=heart_beat_worker, args=(self,), daemon=True)
83
+ self.heart_beat_thread.start()
84
+
85
+ def register_to_controller(self):
86
+ logger.info("Register to controller")
87
+
88
+ url = self.controller_addr + "/register_worker"
89
+ data = {
90
+ "worker_name": self.worker_addr,
91
+ "check_heart_beat": True,
92
+ "worker_status": self.get_status()
93
+ }
94
+ r = requests.post(url, json=data)
95
+ assert r.status_code == 200
96
+
97
+ def send_heart_beat(self):
98
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
99
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
100
+ f"global_counter: {global_counter}")
101
+
102
+ url = self.controller_addr + "/receive_heart_beat"
103
+
104
+ while True:
105
+ try:
106
+ ret = requests.post(url, json={
107
+ "worker_name": self.worker_addr,
108
+ "queue_length": self.get_queue_length()}, timeout=5)
109
+ exist = ret.json()["exist"]
110
+ break
111
+ except requests.exceptions.RequestException as e:
112
+ logger.error(f"heart beat error: {e}")
113
+ time.sleep(5)
114
+
115
+ if not exist:
116
+ self.register_to_controller()
117
+
118
+ def get_queue_length(self):
119
+ if model_semaphore is None:
120
+ return 0
121
+ else:
122
+ return args.limit_model_concurrency - model_semaphore._value + (len(
123
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
124
+
125
+ def get_status(self):
126
+ return {
127
+ "model_names": [self.model_name],
128
+ "speed": 1,
129
+ "queue_length": self.get_queue_length(),
130
+ }
131
+
132
+ async def generate_stream(self, params):
133
+ ori_prompt = prompt = params["prompt"]
134
+ images = params.get("images", None)
135
+ if images is not None and len(images) > 0:
136
+ if len(images) > 0:
137
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
138
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
139
+
140
+ images = [load_image_from_base64(image) for image in images]
141
+
142
+ # FIXME: for image-start/end token
143
+ # replace_token = DEFAULT_IMAGE_TOKEN
144
+ # if getattr(self.model.config, 'mm_use_im_start_end', False):
145
+ # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
+ # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
+ prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
148
+ prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
149
+ prompt = []
150
+ for i in range(len(prompt_split)):
151
+ prompt.append(prompt_split[i])
152
+ if i < len(images):
153
+ prompt.append(images[i])
154
+ else:
155
+ prompt = [prompt]
156
+
157
+ temperature = float(params.get("temperature", 1.0))
158
+ top_p = float(params.get("top_p", 1.0))
159
+ # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
160
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
161
+ stop_str = params.get("stop", None)
162
+ stop_str = [stop_str] if stop_str is not None else None
163
+
164
+ print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
165
+ state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
166
+
167
+ generated_text = ori_prompt
168
+ async for text_outputs in state.text_async_iter(var_name="response"):
169
+ generated_text += text_outputs
170
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
171
+
172
+ async def generate_stream_gate(self, params):
173
+ try:
174
+ async for x in self.generate_stream(params):
175
+ yield x
176
+ except ValueError as e:
177
+ print("Caught ValueError:", e)
178
+ ret = {
179
+ "text": server_error_msg,
180
+ "error_code": 1,
181
+ }
182
+ yield json.dumps(ret).encode() + b"\0"
183
+ except Exception as e:
184
+ print("Caught Unknown Error", e)
185
+ ret = {
186
+ "text": server_error_msg,
187
+ "error_code": 1,
188
+ }
189
+ yield json.dumps(ret).encode() + b"\0"
190
+
191
+
192
+ app = FastAPI()
193
+
194
+
195
+ def release_model_semaphore(fn=None):
196
+ model_semaphore.release()
197
+ if fn is not None:
198
+ fn()
199
+
200
+
201
+ @app.post("/worker_generate_stream")
202
+ async def generate_stream(request: Request):
203
+ global model_semaphore, global_counter
204
+ global_counter += 1
205
+ params = await request.json()
206
+
207
+ if model_semaphore is None:
208
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
209
+ await model_semaphore.acquire()
210
+ worker.send_heart_beat()
211
+ generator = worker.generate_stream_gate(params)
212
+ background_tasks = BackgroundTasks()
213
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
214
+ return StreamingResponse(generator, background=background_tasks)
215
+
216
+
217
+ @app.post("/worker_get_status")
218
+ async def get_status(request: Request):
219
+ return worker.get_status()
220
+
221
+
222
+ if __name__ == "__main__":
223
+ parser = argparse.ArgumentParser()
224
+ parser.add_argument("--host", type=str, default="localhost")
225
+ parser.add_argument("--port", type=int, default=21002)
226
+ parser.add_argument("--worker-address", type=str,
227
+ default="http://localhost:21002")
228
+ parser.add_argument("--controller-address", type=str,
229
+ default="http://localhost:21001")
230
+ parser.add_argument("--model-name", type=str)
231
+ parser.add_argument("--sgl-endpoint", type=str)
232
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
233
+ parser.add_argument("--stream-interval", type=int, default=1)
234
+ parser.add_argument("--no-register", action="store_true")
235
+ args = parser.parse_args()
236
+ logger.info(f"args: {args}")
237
+
238
+ worker = ModelWorker(args.controller_address,
239
+ args.worker_address,
240
+ args.sgl_endpoint,
241
+ worker_id,
242
+ args.no_register,
243
+ args.model_name)
244
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
test_message.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import requests
5
+
6
+ from llava.conversation import default_conversation
7
+
8
+
9
+ def main():
10
+ if args.worker_address:
11
+ worker_addr = args.worker_address
12
+ else:
13
+ controller_addr = args.controller_address
14
+ ret = requests.post(controller_addr + "/refresh_all_workers")
15
+ ret = requests.post(controller_addr + "/list_models")
16
+ models = ret.json()["models"]
17
+ models.sort()
18
+ print(f"Models: {models}")
19
+
20
+ ret = requests.post(controller_addr + "/get_worker_address",
21
+ json={"model": args.model_name})
22
+ worker_addr = ret.json()["address"]
23
+ print(f"worker_addr: {worker_addr}")
24
+
25
+ if worker_addr == "":
26
+ return
27
+
28
+ conv = default_conversation.copy()
29
+ conv.append_message(conv.roles[0], args.message)
30
+ prompt = conv.get_prompt()
31
+
32
+ headers = {"User-Agent": "LLaVA Client"}
33
+ pload = {
34
+ "model": args.model_name,
35
+ "prompt": prompt,
36
+ "max_new_tokens": args.max_new_tokens,
37
+ "temperature": 0.7,
38
+ "stop": conv.sep,
39
+ }
40
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
+ json=pload, stream=True)
42
+
43
+ print(prompt.replace(conv.sep, "\n"), end="")
44
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
+ if chunk:
46
+ data = json.loads(chunk.decode("utf-8"))
47
+ output = data["text"].split(conv.sep)[-1]
48
+ print(output, end="\r")
49
+ print("")
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
+ parser.add_argument("--worker-address", type=str)
56
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
+ parser.add_argument("--max-new-tokens", type=int, default=32)
58
+ parser.add_argument("--message", type=str, default=
59
+ "Tell me a story with more than 1000 words.")
60
+ args = parser.parse_args()
61
+
62
+ main()