Spaces:
Paused
Paused
Upload folder using huggingface_hub
Browse files- .gitattributes +5 -0
- README.md +3 -9
- __init__.py +0 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/gradio_utils.cpython-310.pyc +0 -0
- __pycache__/gradio_web_server.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- asset/Model.png +0 -0
- cli.py +142 -0
- controller.py +298 -0
- examples/desert.jpg +0 -0
- examples/extreme_ironing.jpg +0 -0
- examples/sample_demo_1.mp4 +3 -0
- examples/sample_demo_13.mp4 +3 -0
- examples/sample_demo_22.mp4 +3 -0
- examples/sample_demo_3.mp4 +0 -0
- examples/sample_demo_8.mp4 +3 -0
- examples/sample_demo_9.mp4 +0 -0
- examples/sample_img_13.png +0 -0
- examples/sample_img_22.png +0 -0
- examples/sample_img_8.png +3 -0
- examples/waterview.jpg +0 -0
- gradio_utils.py +155 -0
- gradio_web_server copy.py +227 -0
- gradio_web_server.py +234 -0
- model_worker.py +285 -0
- processing_utils.py +99 -0
- register_worker.py +26 -0
- test_message.py +62 -0
- utils.py +16 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/sample_demo_1.mp4 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
examples/sample_demo_13.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
examples/sample_demo_22.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
examples/sample_demo_8.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
examples/sample_img_8.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: vlm-rlaif-demo
|
3 |
+
app_file: gradio_web_server.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 3.35.2
|
|
|
|
|
6 |
---
|
|
|
|
__init__.py
ADDED
File without changes
|
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes). View file
|
|
__pycache__/gradio_utils.cpython-310.pyc
ADDED
Binary file (5.63 kB). View file
|
|
__pycache__/gradio_web_server.cpython-310.pyc
ADDED
Binary file (5.91 kB). View file
|
|
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (603 Bytes). View file
|
|
asset/Model.png
ADDED
cli.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import sys
|
7 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
|
8 |
+
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
|
9 |
+
DEFAULT_VIDEO_TOKEN
|
10 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
11 |
+
from llava.model.builder import load_pretrained_model
|
12 |
+
from llava.utils import disable_torch_init
|
13 |
+
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
14 |
+
from serve.utils import load_image, image_ext, video_ext
|
15 |
+
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
import requests
|
19 |
+
from PIL import Image
|
20 |
+
from io import BytesIO
|
21 |
+
from transformers import TextStreamer
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def main(args):
|
26 |
+
# Model
|
27 |
+
disable_torch_init()
|
28 |
+
|
29 |
+
model_name = get_model_name_from_path(args.model_path)
|
30 |
+
tokenizer, model, processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
|
31 |
+
args.load_8bit, args.load_4bit,
|
32 |
+
device=args.device, cache_dir=args.cache_dir)
|
33 |
+
image_processor, video_processor = processor['image'], processor['video']
|
34 |
+
if 'llama-2' in model_name.lower():
|
35 |
+
conv_mode = "llava_llama_2"
|
36 |
+
elif "v1" in model_name.lower():
|
37 |
+
conv_mode = "llava_v1"
|
38 |
+
elif "mpt" in model_name.lower():
|
39 |
+
conv_mode = "mpt"
|
40 |
+
else:
|
41 |
+
conv_mode = "llava_v0"
|
42 |
+
|
43 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
44 |
+
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
|
45 |
+
else:
|
46 |
+
args.conv_mode = conv_mode
|
47 |
+
|
48 |
+
conv = conv_templates[args.conv_mode].copy()
|
49 |
+
if "mpt" in model_name.lower():
|
50 |
+
roles = ('user', 'assistant')
|
51 |
+
else:
|
52 |
+
roles = conv.roles
|
53 |
+
|
54 |
+
tensor = []
|
55 |
+
special_token = []
|
56 |
+
args.file = args.file if isinstance(args.file, list) else [args.file]
|
57 |
+
for file in args.file:
|
58 |
+
if os.path.splitext(file)[-1].lower() in video_ext: # video extension
|
59 |
+
video_tensor = video_processor(file, return_tensors='pt')['pixel_values'][0].to(model.device, dtype=torch.float16)
|
60 |
+
special_token += [DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames
|
61 |
+
elif os.path.splitext(os.listdir(file)[0]).lower() in image_ext: # frames folder
|
62 |
+
vidframes_list = sorted(glob(file + '/*'))
|
63 |
+
images = load_frames(vidframes_list, model.get_video_tower().config.num_frames)
|
64 |
+
# Similar operation in model_worker.py
|
65 |
+
video_tensor = process_images(images, image_processor, args)
|
66 |
+
video_tensor = video_tensor.to(model.device, dtype=torch.float16)
|
67 |
+
video_tensor = video_tensor.unsqueeze(0)
|
68 |
+
special_token += [DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames
|
69 |
+
else:
|
70 |
+
raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(file)[-1].lower()}')
|
71 |
+
print(video_tensor.shape)
|
72 |
+
tensor.append(video_tensor)
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
while True:
|
78 |
+
try:
|
79 |
+
inp = input(f"{roles[0]}: ")
|
80 |
+
except EOFError:
|
81 |
+
inp = ""
|
82 |
+
if not inp:
|
83 |
+
print("exit...")
|
84 |
+
break
|
85 |
+
|
86 |
+
print(f"{roles[1]}: ", end="")
|
87 |
+
|
88 |
+
if file is not None:
|
89 |
+
# first message
|
90 |
+
if getattr(model.config, "mm_use_im_start_end", False):
|
91 |
+
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
|
92 |
+
# inp = ''.join([DEFAULT_IM_START_TOKEN + i + DEFAULT_IM_END_TOKEN for i in special_token]) + '\n' + inp
|
93 |
+
else:
|
94 |
+
inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
|
95 |
+
# inp = ''.join(special_token) + '\n' + inp
|
96 |
+
conv.append_message(conv.roles[0], inp)
|
97 |
+
file = None
|
98 |
+
else:
|
99 |
+
# later messages
|
100 |
+
conv.append_message(conv.roles[0], inp)
|
101 |
+
conv.append_message(conv.roles[1], None)
|
102 |
+
prompt = conv.get_prompt()
|
103 |
+
|
104 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
105 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
106 |
+
keywords = [stop_str]
|
107 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
108 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
109 |
+
|
110 |
+
with torch.inference_mode():
|
111 |
+
output_ids = model.generate(
|
112 |
+
input_ids,
|
113 |
+
images=tensor, # video as fake images
|
114 |
+
do_sample=True if args.temperature > 0 else False,
|
115 |
+
temperature=args.temperature,
|
116 |
+
max_new_tokens=args.max_new_tokens,
|
117 |
+
streamer=streamer,
|
118 |
+
use_cache=True,
|
119 |
+
stopping_criteria=[stopping_criteria])
|
120 |
+
|
121 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
122 |
+
conv.messages[-1][-1] = outputs
|
123 |
+
|
124 |
+
if args.debug:
|
125 |
+
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
|
129 |
+
parser = argparse.ArgumentParser()
|
130 |
+
parser.add_argument("--model-path", type=str, default="LanguageBind/Video-LLaVA-7B")
|
131 |
+
parser.add_argument("--model-base", type=str, default=None)
|
132 |
+
parser.add_argument("--cache-dir", type=str, default=None)
|
133 |
+
parser.add_argument("--file", nargs='+', type=str, required=True)
|
134 |
+
parser.add_argument("--device", type=str, default="cuda")
|
135 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
136 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
137 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
138 |
+
parser.add_argument("--load-8bit", action="store_true")
|
139 |
+
parser.add_argument("--load-4bit", action="store_true")
|
140 |
+
parser.add_argument("--debug", action="store_true")
|
141 |
+
args = parser.parse_args()
|
142 |
+
main(args)
|
controller.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import time
|
12 |
+
from typing import List, Union
|
13 |
+
import threading
|
14 |
+
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.responses import StreamingResponse
|
17 |
+
import numpy as np
|
18 |
+
import requests
|
19 |
+
import uvicorn
|
20 |
+
|
21 |
+
from videollava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
+
from videollava.utils import build_logger, server_error_msg
|
23 |
+
|
24 |
+
|
25 |
+
logger = build_logger("controller", "controller.log")
|
26 |
+
|
27 |
+
|
28 |
+
class DispatchMethod(Enum):
|
29 |
+
LOTTERY = auto()
|
30 |
+
SHORTEST_QUEUE = auto()
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_str(cls, name):
|
34 |
+
if name == "lottery":
|
35 |
+
return cls.LOTTERY
|
36 |
+
elif name == "shortest_queue":
|
37 |
+
return cls.SHORTEST_QUEUE
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Invalid dispatch method")
|
40 |
+
|
41 |
+
|
42 |
+
@dataclasses.dataclass
|
43 |
+
class WorkerInfo:
|
44 |
+
model_names: List[str]
|
45 |
+
speed: int
|
46 |
+
queue_length: int
|
47 |
+
check_heart_beat: bool
|
48 |
+
last_heart_beat: str
|
49 |
+
|
50 |
+
|
51 |
+
def heart_beat_controller(controller):
|
52 |
+
while True:
|
53 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
54 |
+
controller.remove_stable_workers_by_expiration()
|
55 |
+
|
56 |
+
|
57 |
+
class Controller:
|
58 |
+
def __init__(self, dispatch_method: str):
|
59 |
+
# Dict[str -> WorkerInfo]
|
60 |
+
self.worker_info = {}
|
61 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
62 |
+
|
63 |
+
self.heart_beat_thread = threading.Thread(
|
64 |
+
target=heart_beat_controller, args=(self,))
|
65 |
+
self.heart_beat_thread.start()
|
66 |
+
|
67 |
+
logger.info("Init controller")
|
68 |
+
|
69 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
70 |
+
worker_status: dict):
|
71 |
+
if worker_name not in self.worker_info:
|
72 |
+
logger.info(f"Register a new worker: {worker_name}")
|
73 |
+
else:
|
74 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
75 |
+
|
76 |
+
if not worker_status:
|
77 |
+
worker_status = self.get_worker_status(worker_name)
|
78 |
+
if not worker_status:
|
79 |
+
return False
|
80 |
+
|
81 |
+
self.worker_info[worker_name] = WorkerInfo(
|
82 |
+
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
83 |
+
check_heart_beat, time.time())
|
84 |
+
|
85 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
86 |
+
return True
|
87 |
+
|
88 |
+
def get_worker_status(self, worker_name: str):
|
89 |
+
try:
|
90 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
91 |
+
except requests.exceptions.RequestException as e:
|
92 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
93 |
+
return None
|
94 |
+
|
95 |
+
if r.status_code != 200:
|
96 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
97 |
+
return None
|
98 |
+
|
99 |
+
return r.json()
|
100 |
+
|
101 |
+
def remove_worker(self, worker_name: str):
|
102 |
+
del self.worker_info[worker_name]
|
103 |
+
|
104 |
+
def refresh_all_workers(self):
|
105 |
+
old_info = dict(self.worker_info)
|
106 |
+
self.worker_info = {}
|
107 |
+
|
108 |
+
for w_name, w_info in old_info.items():
|
109 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
110 |
+
logger.info(f"Remove stale worker: {w_name}")
|
111 |
+
|
112 |
+
def list_models(self):
|
113 |
+
model_names = set()
|
114 |
+
|
115 |
+
for w_name, w_info in self.worker_info.items():
|
116 |
+
model_names.update(w_info.model_names)
|
117 |
+
|
118 |
+
return list(model_names)
|
119 |
+
|
120 |
+
def get_worker_address(self, model_name: str):
|
121 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
122 |
+
worker_names = []
|
123 |
+
worker_speeds = []
|
124 |
+
for w_name, w_info in self.worker_info.items():
|
125 |
+
if model_name in w_info.model_names:
|
126 |
+
worker_names.append(w_name)
|
127 |
+
worker_speeds.append(w_info.speed)
|
128 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
129 |
+
norm = np.sum(worker_speeds)
|
130 |
+
if norm < 1e-4:
|
131 |
+
return ""
|
132 |
+
worker_speeds = worker_speeds / norm
|
133 |
+
if True: # Directly return address
|
134 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
135 |
+
p=worker_speeds)
|
136 |
+
worker_name = worker_names[pt]
|
137 |
+
return worker_name
|
138 |
+
|
139 |
+
# Check status before returning
|
140 |
+
while True:
|
141 |
+
pt = np.random.choice(np.arange(len(worker_names)),
|
142 |
+
p=worker_speeds)
|
143 |
+
worker_name = worker_names[pt]
|
144 |
+
|
145 |
+
if self.get_worker_status(worker_name):
|
146 |
+
break
|
147 |
+
else:
|
148 |
+
self.remove_worker(worker_name)
|
149 |
+
worker_speeds[pt] = 0
|
150 |
+
norm = np.sum(worker_speeds)
|
151 |
+
if norm < 1e-4:
|
152 |
+
return ""
|
153 |
+
worker_speeds = worker_speeds / norm
|
154 |
+
continue
|
155 |
+
return worker_name
|
156 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
157 |
+
worker_names = []
|
158 |
+
worker_qlen = []
|
159 |
+
for w_name, w_info in self.worker_info.items():
|
160 |
+
if model_name in w_info.model_names:
|
161 |
+
worker_names.append(w_name)
|
162 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
163 |
+
if len(worker_names) == 0:
|
164 |
+
return ""
|
165 |
+
min_index = np.argmin(worker_qlen)
|
166 |
+
w_name = worker_names[min_index]
|
167 |
+
self.worker_info[w_name].queue_length += 1
|
168 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
169 |
+
return w_name
|
170 |
+
else:
|
171 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
172 |
+
|
173 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
174 |
+
if worker_name not in self.worker_info:
|
175 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
176 |
+
return False
|
177 |
+
|
178 |
+
self.worker_info[worker_name].queue_length = queue_length
|
179 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
180 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
181 |
+
return True
|
182 |
+
|
183 |
+
def remove_stable_workers_by_expiration(self):
|
184 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
185 |
+
to_delete = []
|
186 |
+
for worker_name, w_info in self.worker_info.items():
|
187 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
188 |
+
to_delete.append(worker_name)
|
189 |
+
|
190 |
+
for worker_name in to_delete:
|
191 |
+
self.remove_worker(worker_name)
|
192 |
+
|
193 |
+
def worker_api_generate_stream(self, params):
|
194 |
+
worker_addr = self.get_worker_address(params["model"])
|
195 |
+
if not worker_addr:
|
196 |
+
logger.info(f"no worker: {params['model']}")
|
197 |
+
ret = {
|
198 |
+
"text": server_error_msg,
|
199 |
+
"error_code": 2,
|
200 |
+
}
|
201 |
+
yield json.dumps(ret).encode() + b"\0"
|
202 |
+
|
203 |
+
try:
|
204 |
+
response = requests.post(worker_addr + "/worker_generate_stream",
|
205 |
+
json=params, stream=True, timeout=5)
|
206 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
207 |
+
if chunk:
|
208 |
+
yield chunk + b"\0"
|
209 |
+
except requests.exceptions.RequestException as e:
|
210 |
+
logger.info(f"worker timeout: {worker_addr}")
|
211 |
+
ret = {
|
212 |
+
"text": server_error_msg,
|
213 |
+
"error_code": 3,
|
214 |
+
}
|
215 |
+
yield json.dumps(ret).encode() + b"\0"
|
216 |
+
|
217 |
+
|
218 |
+
# Let the controller act as a worker to achieve hierarchical
|
219 |
+
# management. This can be used to connect isolated sub networks.
|
220 |
+
def worker_api_get_status(self):
|
221 |
+
model_names = set()
|
222 |
+
speed = 0
|
223 |
+
queue_length = 0
|
224 |
+
|
225 |
+
for w_name in self.worker_info:
|
226 |
+
worker_status = self.get_worker_status(w_name)
|
227 |
+
if worker_status is not None:
|
228 |
+
model_names.update(worker_status["model_names"])
|
229 |
+
speed += worker_status["speed"]
|
230 |
+
queue_length += worker_status["queue_length"]
|
231 |
+
|
232 |
+
return {
|
233 |
+
"model_names": list(model_names),
|
234 |
+
"speed": speed,
|
235 |
+
"queue_length": queue_length,
|
236 |
+
}
|
237 |
+
|
238 |
+
|
239 |
+
app = FastAPI()
|
240 |
+
|
241 |
+
|
242 |
+
@app.post("/register_worker")
|
243 |
+
async def register_worker(request: Request):
|
244 |
+
data = await request.json()
|
245 |
+
controller.register_worker(
|
246 |
+
data["worker_name"], data["check_heart_beat"],
|
247 |
+
data.get("worker_status", None))
|
248 |
+
|
249 |
+
|
250 |
+
@app.post("/refresh_all_workers")
|
251 |
+
async def refresh_all_workers():
|
252 |
+
models = controller.refresh_all_workers()
|
253 |
+
|
254 |
+
|
255 |
+
@app.post("/list_models")
|
256 |
+
async def list_models():
|
257 |
+
models = controller.list_models()
|
258 |
+
return {"models": models}
|
259 |
+
|
260 |
+
|
261 |
+
@app.post("/get_worker_address")
|
262 |
+
async def get_worker_address(request: Request):
|
263 |
+
data = await request.json()
|
264 |
+
addr = controller.get_worker_address(data["model"])
|
265 |
+
return {"address": addr}
|
266 |
+
|
267 |
+
|
268 |
+
@app.post("/receive_heart_beat")
|
269 |
+
async def receive_heart_beat(request: Request):
|
270 |
+
data = await request.json()
|
271 |
+
exist = controller.receive_heart_beat(
|
272 |
+
data["worker_name"], data["queue_length"])
|
273 |
+
return {"exist": exist}
|
274 |
+
|
275 |
+
|
276 |
+
@app.post("/worker_generate_stream")
|
277 |
+
async def worker_api_generate_stream(request: Request):
|
278 |
+
params = await request.json()
|
279 |
+
generator = controller.worker_api_generate_stream(params)
|
280 |
+
return StreamingResponse(generator)
|
281 |
+
|
282 |
+
|
283 |
+
@app.post("/worker_get_status")
|
284 |
+
async def worker_api_get_status(request: Request):
|
285 |
+
return controller.worker_api_get_status()
|
286 |
+
|
287 |
+
|
288 |
+
if __name__ == "__main__":
|
289 |
+
parser = argparse.ArgumentParser()
|
290 |
+
parser.add_argument("--host", type=str, default="localhost")
|
291 |
+
parser.add_argument("--port", type=int, default=21001)
|
292 |
+
parser.add_argument("--dispatch-method", type=str, choices=[
|
293 |
+
"lottery", "shortest_queue"], default="shortest_queue")
|
294 |
+
args = parser.parse_args()
|
295 |
+
logger.info(f"args: {args}")
|
296 |
+
|
297 |
+
controller = Controller(args.dispatch_method)
|
298 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
examples/desert.jpg
ADDED
examples/extreme_ironing.jpg
ADDED
examples/sample_demo_1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc6562a172eb9cb3c760a3c9992349c1faa2c793c112b7b9e50bd5cb17c2164d
|
3 |
+
size 1549315
|
examples/sample_demo_13.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13384915331bf749fa31e2f4cbbd85ca90439b81b2390b4b512bd24b0dbd8bae
|
3 |
+
size 19356822
|
examples/sample_demo_22.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dcde24b3e67ff23aafd4b69854dbc7e2485eae65999c86c1beb9160d53fa2a11
|
3 |
+
size 1505931
|
examples/sample_demo_3.mp4
ADDED
Binary file (464 kB). View file
|
|
examples/sample_demo_8.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:618bb02562c769303b797ae3c29a66e15dcc0134d673747e8cf90582369c59a2
|
3 |
+
size 29771700
|
examples/sample_demo_9.mp4
ADDED
Binary file (632 kB). View file
|
|
examples/sample_img_13.png
ADDED
examples/sample_img_22.png
ADDED
examples/sample_img_8.png
ADDED
Git LFS Details
|
examples/waterview.jpg
ADDED
gradio_utils.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import TextStreamer
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
|
7 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
8 |
+
from llava.conversation import conv_templates, SeparatorStyle
|
9 |
+
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
|
10 |
+
from llava.model.builder import load_pretrained_model
|
11 |
+
from llava.utils import disable_torch_init
|
12 |
+
import shutil
|
13 |
+
|
14 |
+
# <a href="https://github.com/SNUMPR/vlm-rlaif.git" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
|
15 |
+
# <img src="https://z1.ax1x.com/2023/11/07/pil4sqH.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
|
16 |
+
# </a>
|
17 |
+
|
18 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
title_markdown = ("""
|
20 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
21 |
+
<img src="/dataset/dcahn/yura/vlm-rlaif/asset/Model.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
|
22 |
+
<img src="file:/dataset/dcahn/yura/vlm-rlaif/asset/Model.png" alt="VLM-RLAIF" style="max-width: 120px; height: auto;">
|
23 |
+
<div>
|
24 |
+
<h1 >VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral) </h1>
|
25 |
+
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
|
26 |
+
</div>
|
27 |
+
</div>
|
28 |
+
|
29 |
+
|
30 |
+
<div align="center">
|
31 |
+
<div style="display:flex; gap: 0.25rem;" align="center">
|
32 |
+
<a href='https://github.com/SNUMPR/vlm-rlaif'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
|
33 |
+
<a href="https://arxiv.org/abs/2402.03746"><img src="https://img.shields.io/badge/Paper-arxiv-green"></a>
|
34 |
+
</div>
|
35 |
+
</div>
|
36 |
+
""")
|
37 |
+
# <a href='https://github.com/PKU-YuanGroup/Video-LLaVA/stargazers'><img src='https://img.shields.io/github/stars/PKU-YuanGroup/Video-LLaVA.svg?style=social'></a> # arXiv 버튼 옆에 추가?
|
38 |
+
|
39 |
+
block_css = """
|
40 |
+
#buttons button {
|
41 |
+
min-width: min(120px,100%);
|
42 |
+
}
|
43 |
+
"""
|
44 |
+
|
45 |
+
tos_markdown = ("""
|
46 |
+
### Terms of use
|
47 |
+
By using this service, users are required to agree to the following terms:
|
48 |
+
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
49 |
+
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
50 |
+
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
51 |
+
""")
|
52 |
+
|
53 |
+
learn_more_markdown = ("""
|
54 |
+
### License
|
55 |
+
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
|
56 |
+
""")
|
57 |
+
|
58 |
+
|
59 |
+
class Chat:
|
60 |
+
def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None):
|
61 |
+
# model_base = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_SFT'
|
62 |
+
# model_base='/dataset/yura/vlm-rlaif/pretrained/llava-v1.5-7b-lora_w_lora_16_sftv2_short1632_and_then_long_rank32_alpha32_lr1e4_allmodels/SFT_merged'
|
63 |
+
# model_path = '/dataset/yura/vlm-rlaif/pretrained/LLaVA_Video-RL-Fact-RLHF-7b_SFTv2_RM_13b_v1_40k-v1.5-336-lora-padding/checkpoint-180/adapter_model/lora_policy'
|
64 |
+
|
65 |
+
disable_torch_init()
|
66 |
+
model_name = get_model_name_from_path(model_path)
|
67 |
+
# self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name,
|
68 |
+
# load_8bit, load_4bit,
|
69 |
+
# device=device, cache_dir=cache_dir)
|
70 |
+
is_rlhf_checkpoint = 'rlhf' in model_path.lower()
|
71 |
+
print("MODEL_PATH", model_path)
|
72 |
+
print("RLHF Checkpoint: ", is_rlhf_checkpoint)
|
73 |
+
if not model_base or model_base == "none": model_base = None
|
74 |
+
if is_rlhf_checkpoint:
|
75 |
+
model_name = model_path
|
76 |
+
print("Config?", os.path.exists(os.path.join(model_path, "config.json")))
|
77 |
+
if not os.path.exists(os.path.join(model_path, "config.json")):
|
78 |
+
print("Copying")
|
79 |
+
shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json")) # Copy SFT model's config -> to RLHF folder
|
80 |
+
print("Listed", os.listdir(model_path))
|
81 |
+
print("Copying done")
|
82 |
+
# return(model_name)
|
83 |
+
# return
|
84 |
+
# self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device=device)
|
85 |
+
self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, False, False, device=device)
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
self.image_processor = image_processor
|
90 |
+
# self.image_processor = processor['image']
|
91 |
+
# self.video_processor = processor['video']
|
92 |
+
self.conv_mode = conv_mode
|
93 |
+
self.conv = conv_templates[conv_mode].copy()
|
94 |
+
self.device = self.model.device
|
95 |
+
print(self.model)
|
96 |
+
|
97 |
+
def get_prompt(self, qs, state):
|
98 |
+
state.append_message(state.roles[0], qs)
|
99 |
+
state.append_message(state.roles[1], None)
|
100 |
+
return state
|
101 |
+
|
102 |
+
def _get_latest_prompt(self, state):
|
103 |
+
new_state = state.copy()
|
104 |
+
new_state.messages = state.messages[-2:]
|
105 |
+
return new_state
|
106 |
+
|
107 |
+
@torch.inference_mode()
|
108 |
+
# def generate(self, images_tensor: list, prompt: str, first_run: bool, state):
|
109 |
+
def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state):
|
110 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
111 |
+
|
112 |
+
state = self.get_prompt(prompt, state)
|
113 |
+
# prompt = state.get_prompt()
|
114 |
+
latest_state = self._get_latest_prompt(state)
|
115 |
+
prompt = latest_state.get_prompt()
|
116 |
+
|
117 |
+
# print('\n\n\n')
|
118 |
+
# print(prompt)
|
119 |
+
|
120 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
121 |
+
|
122 |
+
temperature = 0.2
|
123 |
+
|
124 |
+
max_new_tokens = 1024
|
125 |
+
|
126 |
+
stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
|
127 |
+
keywords = [stop_str]
|
128 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
129 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
130 |
+
print(prompt, input_ids.shape, images_tensor.shape)
|
131 |
+
# print(images_tensor)
|
132 |
+
with torch.inference_mode():
|
133 |
+
output_ids = model.generate(
|
134 |
+
input_ids,
|
135 |
+
images=images_tensor,
|
136 |
+
do_sample=True,
|
137 |
+
temperature=temperature,
|
138 |
+
max_new_tokens=max_new_tokens,
|
139 |
+
streamer=streamer,
|
140 |
+
use_cache=True,
|
141 |
+
stopping_criteria=[stopping_criteria])
|
142 |
+
|
143 |
+
input_token_len = input_ids.shape[1]
|
144 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
145 |
+
if n_diff_input_output > 0:
|
146 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
147 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
148 |
+
outputs = outputs.strip()
|
149 |
+
outputs = outputs.replace("QA_GT_caption_based_noisy", "")
|
150 |
+
if outputs.endswith(stop_str):
|
151 |
+
outputs = outputs[:-len(stop_str)]
|
152 |
+
outputs = outputs.strip()
|
153 |
+
|
154 |
+
print('response', outputs)
|
155 |
+
return outputs, state
|
gradio_web_server copy.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from fastapi import FastAPI
|
7 |
+
import os
|
8 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
9 |
+
from PIL import Image
|
10 |
+
import tempfile
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
from transformers import TextStreamer
|
13 |
+
import argparse
|
14 |
+
|
15 |
+
import sys
|
16 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
|
17 |
+
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
18 |
+
from llava.conversation import conv_templates, SeparatorStyle, Conversation
|
19 |
+
from llava.mm_utils import process_images
|
20 |
+
|
21 |
+
from Evaluation.infer_utils import load_video_into_frames
|
22 |
+
from serve.utils import load_image, image_ext, video_ext
|
23 |
+
from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def save_image_to_local(image):
|
28 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
|
29 |
+
image = Image.open(image)
|
30 |
+
image.save(filename)
|
31 |
+
# print(filename)
|
32 |
+
return filename
|
33 |
+
|
34 |
+
|
35 |
+
def save_video_to_local(video_path):
|
36 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
|
37 |
+
shutil.copyfile(video_path, filename)
|
38 |
+
return filename
|
39 |
+
|
40 |
+
|
41 |
+
def generate(image1, video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
|
42 |
+
# ======= manually clear the conversation
|
43 |
+
# state = conv_templates[conv_mode].copy()
|
44 |
+
# state_ = conv_templates[conv_mode].copy()
|
45 |
+
# # =======
|
46 |
+
flag = 1
|
47 |
+
if not textbox_in:
|
48 |
+
if len(state_.messages) > 0:
|
49 |
+
textbox_in = state_.messages[-1][1]
|
50 |
+
state_.messages.pop(-1)
|
51 |
+
flag = 0
|
52 |
+
else:
|
53 |
+
return "Please enter instruction"
|
54 |
+
print("Video", video) # 잘 들어감
|
55 |
+
print("Images_tensor", images_tensor) # None
|
56 |
+
print("Textbox_IN", textbox_in) # 잘 들어감
|
57 |
+
print("State", state) # None
|
58 |
+
print("State_", state_) # None
|
59 |
+
# print(len(state_.messages))
|
60 |
+
|
61 |
+
video = video if video else "none"
|
62 |
+
|
63 |
+
if type(state) is not Conversation:
|
64 |
+
state = conv_templates[conv_mode].copy()
|
65 |
+
state_ = conv_templates[conv_mode].copy()
|
66 |
+
images_tensor = []
|
67 |
+
|
68 |
+
first_run = False if len(state.messages) > 0 else True
|
69 |
+
|
70 |
+
text_en_in = textbox_in.replace("picture", "image")
|
71 |
+
|
72 |
+
image_processor = handler.image_processor
|
73 |
+
assert os.path.exists(video)
|
74 |
+
if os.path.splitext(video)[-1].lower() in video_ext: # video extension
|
75 |
+
video_decode_backend = 'opencv'
|
76 |
+
elif os.path.splitext(os.listdir(video)[0]).lower() in image_ext: # frames folder
|
77 |
+
video_decode_backend = 'frames'
|
78 |
+
else:
|
79 |
+
raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(video)[-1].lower()}')
|
80 |
+
|
81 |
+
frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
|
82 |
+
tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
|
83 |
+
# tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
|
84 |
+
# print(tensor.shape)
|
85 |
+
tensor = tensor.to(handler.model.device, dtype=dtype)
|
86 |
+
# images_tensor.append(tensor)
|
87 |
+
images_tensor = tensor
|
88 |
+
|
89 |
+
if handler.model.config.mm_use_im_start_end:
|
90 |
+
text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
|
91 |
+
else:
|
92 |
+
text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
|
93 |
+
text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
|
94 |
+
state_.messages[-1] = (state_.roles[1], text_en_out)
|
95 |
+
|
96 |
+
text_en_out = text_en_out.split('#')[0]
|
97 |
+
textbox_out = text_en_out
|
98 |
+
|
99 |
+
show_images = ""
|
100 |
+
if os.path.exists(video):
|
101 |
+
filename = save_video_to_local(video)
|
102 |
+
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
|
103 |
+
if flag:
|
104 |
+
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
|
105 |
+
state.append_message(state.roles[1], textbox_out)
|
106 |
+
|
107 |
+
return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=image1 if os.path.exists(video) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
|
108 |
+
|
109 |
+
|
110 |
+
def regenerate(state, state_):
|
111 |
+
state.messages.pop(-1)
|
112 |
+
state_.messages.pop(-1)
|
113 |
+
if len(state.messages) > 0:
|
114 |
+
return state, state_, state.to_gradio_chatbot(), False
|
115 |
+
return (state, state_, state.to_gradio_chatbot(), True)
|
116 |
+
|
117 |
+
|
118 |
+
def clear_history(state, state_):
|
119 |
+
state = conv_templates[conv_mode].copy()
|
120 |
+
state_ = conv_templates[conv_mode].copy()
|
121 |
+
return (gr.update(value=None, interactive=True),
|
122 |
+
gr.update(value=None, interactive=True), \
|
123 |
+
gr.update(value=None, interactive=True), \
|
124 |
+
True, state, state_, state.to_gradio_chatbot(), [])
|
125 |
+
|
126 |
+
|
127 |
+
# ==== CHANGE HERE ====
|
128 |
+
# conv_mode = "llava_v1"
|
129 |
+
# model_path = 'LanguageBind/Video-LLaVA-7B'
|
130 |
+
# FIXME!!!
|
131 |
+
|
132 |
+
conv_mode = "llava_v0"
|
133 |
+
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
|
134 |
+
# model_path = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_VLM_RLAIF_merged'
|
135 |
+
cache_dir = './cache_dir'
|
136 |
+
device = 'cuda'
|
137 |
+
# device = 'cpu'
|
138 |
+
load_8bit = True
|
139 |
+
load_4bit = False
|
140 |
+
dtype = torch.float16
|
141 |
+
# =============
|
142 |
+
|
143 |
+
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir)
|
144 |
+
# handler.model.to(dtype=dtype)
|
145 |
+
if not os.path.exists("temp"):
|
146 |
+
os.makedirs("temp")
|
147 |
+
|
148 |
+
app = FastAPI()
|
149 |
+
|
150 |
+
|
151 |
+
textbox = gr.Textbox(
|
152 |
+
show_label=False, placeholder="Enter text and press ENTER", container=False
|
153 |
+
)
|
154 |
+
with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
|
155 |
+
gr.Markdown(title_markdown)
|
156 |
+
state = gr.State()
|
157 |
+
state_ = gr.State()
|
158 |
+
first_run = gr.State()
|
159 |
+
images_tensor = gr.State()
|
160 |
+
|
161 |
+
image1 = gr.Image(label="Input Image", type="filepath")
|
162 |
+
with gr.Row():
|
163 |
+
with gr.Column(scale=3):
|
164 |
+
video = gr.Video(label="Input Video")
|
165 |
+
|
166 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
167 |
+
gr.Examples(
|
168 |
+
examples=[
|
169 |
+
[
|
170 |
+
f"{cur_dir}/examples/sample_demo_1.mp4",
|
171 |
+
"Why is this video funny?",
|
172 |
+
],
|
173 |
+
[
|
174 |
+
f"{cur_dir}/examples/sample_demo_3.mp4",
|
175 |
+
"Can you identify any safety hazards in this video?"
|
176 |
+
],
|
177 |
+
[
|
178 |
+
f"{cur_dir}/examples/sample_demo_9.mp4",
|
179 |
+
"Describe the video.",
|
180 |
+
],
|
181 |
+
[
|
182 |
+
f"{cur_dir}/examples/sample_demo_22.mp4",
|
183 |
+
"Describe the activity in the video.",
|
184 |
+
],
|
185 |
+
],
|
186 |
+
inputs=[video, textbox],
|
187 |
+
)
|
188 |
+
|
189 |
+
with gr.Column(scale=7):
|
190 |
+
chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column(scale=8):
|
193 |
+
textbox.render()
|
194 |
+
with gr.Column(scale=1, min_width=50):
|
195 |
+
submit_btn = gr.Button(
|
196 |
+
value="Send", variant="primary", interactive=True
|
197 |
+
)
|
198 |
+
with gr.Row(elem_id="buttons") as button_row:
|
199 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
200 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
201 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
|
202 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
203 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
204 |
+
# clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
205 |
+
|
206 |
+
gr.Markdown(tos_markdown)
|
207 |
+
gr.Markdown(learn_more_markdown)
|
208 |
+
|
209 |
+
submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor],
|
210 |
+
[state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
|
211 |
+
# submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
|
212 |
+
# [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
213 |
+
|
214 |
+
regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
|
215 |
+
generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
|
216 |
+
# generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
217 |
+
|
218 |
+
# clear_btn.click(clear_history, [state, state_],
|
219 |
+
# [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
|
220 |
+
# [video, textbox, first_run, state, state_, chatbot, images_tensor])
|
221 |
+
|
222 |
+
# app = gr.mount_gradio_app(app, demo, path="/")
|
223 |
+
# demo.launch(share=True)
|
224 |
+
demo.launch()
|
225 |
+
|
226 |
+
# uvicorn videollava.serve.gradio_web_server:app
|
227 |
+
# python -m videollava.serve.gradio_web_server
|
gradio_web_server.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from fastapi import FastAPI
|
7 |
+
import os
|
8 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
9 |
+
from PIL import Image
|
10 |
+
import tempfile
|
11 |
+
from decord import VideoReader, cpu
|
12 |
+
from transformers import TextStreamer
|
13 |
+
import argparse
|
14 |
+
|
15 |
+
import sys
|
16 |
+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
|
17 |
+
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
18 |
+
from llava.conversation import conv_templates, SeparatorStyle, Conversation
|
19 |
+
from llava.mm_utils import process_images
|
20 |
+
|
21 |
+
from Evaluation.infer_utils import load_video_into_frames
|
22 |
+
from serve.utils import load_image, image_ext, video_ext
|
23 |
+
from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
def save_image_to_local(image):
|
28 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
|
29 |
+
image = Image.open(image)
|
30 |
+
image.save(filename)
|
31 |
+
# print(filename)
|
32 |
+
return filename
|
33 |
+
|
34 |
+
|
35 |
+
def save_video_to_local(video_path):
|
36 |
+
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
|
37 |
+
shutil.copyfile(video_path, filename)
|
38 |
+
return filename
|
39 |
+
|
40 |
+
|
41 |
+
def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
|
42 |
+
# ======= manually clear the conversation
|
43 |
+
# state = conv_templates[conv_mode].copy()
|
44 |
+
# state_ = conv_templates[conv_mode].copy()
|
45 |
+
# # =======
|
46 |
+
flag = 1
|
47 |
+
if not textbox_in:
|
48 |
+
if len(state_.messages) > 0:
|
49 |
+
textbox_in = state_.messages[-1][1]
|
50 |
+
state_.messages.pop(-1)
|
51 |
+
flag = 0
|
52 |
+
else:
|
53 |
+
return "Please enter instruction"
|
54 |
+
# else:
|
55 |
+
# if state is not None and state_ is not None:
|
56 |
+
# # reset conversations
|
57 |
+
# state.messages = []
|
58 |
+
# state_.messages = []
|
59 |
+
|
60 |
+
print("Video", video) # 잘 들어감
|
61 |
+
print("Images_tensor", images_tensor) # None
|
62 |
+
print("Textbox_IN", textbox_in) # 잘 들어감
|
63 |
+
print("State", state) # None
|
64 |
+
print("State_", state_) # None
|
65 |
+
# print(len(state_.messages))
|
66 |
+
|
67 |
+
video = video if video else "none"
|
68 |
+
|
69 |
+
if type(state) is not Conversation:
|
70 |
+
state = conv_templates[conv_mode].copy()
|
71 |
+
state_ = conv_templates[conv_mode].copy()
|
72 |
+
images_tensor = []
|
73 |
+
|
74 |
+
first_run = False if len(state.messages) > 0 else True
|
75 |
+
|
76 |
+
text_en_in = textbox_in.replace("picture", "image")
|
77 |
+
|
78 |
+
image_processor = handler.image_processor
|
79 |
+
assert os.path.exists(video)
|
80 |
+
if os.path.splitext(video)[-1].lower() in video_ext: # video extension
|
81 |
+
video_decode_backend = 'opencv'
|
82 |
+
elif os.path.splitext(os.listdir(video)[0]).lower() in image_ext: # frames folder
|
83 |
+
video_decode_backend = 'frames'
|
84 |
+
else:
|
85 |
+
raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(video)[-1].lower()}')
|
86 |
+
|
87 |
+
frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
|
88 |
+
tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
|
89 |
+
# tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
|
90 |
+
# print(tensor.shape)
|
91 |
+
tensor = tensor.to(handler.model.device, dtype=dtype)
|
92 |
+
# images_tensor.append(tensor)
|
93 |
+
images_tensor = tensor
|
94 |
+
|
95 |
+
if handler.model.config.mm_use_im_start_end:
|
96 |
+
text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
|
97 |
+
else:
|
98 |
+
text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
|
99 |
+
text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
|
100 |
+
state_.messages[-1] = (state_.roles[1], text_en_out)
|
101 |
+
|
102 |
+
text_en_out = text_en_out.split('#')[0]
|
103 |
+
textbox_out = text_en_out
|
104 |
+
|
105 |
+
show_images = ""
|
106 |
+
if os.path.exists(video):
|
107 |
+
filename = save_video_to_local(video)
|
108 |
+
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
|
109 |
+
if flag:
|
110 |
+
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
|
111 |
+
state.append_message(state.roles[1], textbox_out)
|
112 |
+
|
113 |
+
return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, \
|
114 |
+
gr.update(value=video if os.path.exists(video) else None, interactive=True))
|
115 |
+
|
116 |
+
|
117 |
+
def regenerate(state, state_):
|
118 |
+
state.messages.pop(-1)
|
119 |
+
state_.messages.pop(-1)
|
120 |
+
if len(state.messages) > 0:
|
121 |
+
return state, state_, state.to_gradio_chatbot(), False
|
122 |
+
return (state, state_, state.to_gradio_chatbot(), True)
|
123 |
+
|
124 |
+
|
125 |
+
def clear_history(state, state_):
|
126 |
+
state = conv_templates[conv_mode].copy()
|
127 |
+
state_ = conv_templates[conv_mode].copy()
|
128 |
+
return (gr.update(value=None, interactive=True),
|
129 |
+
gr.update(value=None, interactive=True), \
|
130 |
+
gr.update(value=None, interactive=True), \
|
131 |
+
True, state, state_, state.to_gradio_chatbot(), [])
|
132 |
+
|
133 |
+
|
134 |
+
# ==== CHANGE HERE ====
|
135 |
+
# conv_mode = "llava_v1"
|
136 |
+
# model_path = 'LanguageBind/Video-LLaVA-7B'
|
137 |
+
# FIXME!!!
|
138 |
+
|
139 |
+
conv_mode = "llava_v0"
|
140 |
+
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
|
141 |
+
# model_path = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_VLM_RLAIF_merged'
|
142 |
+
cache_dir = './cache_dir'
|
143 |
+
device = 'cuda'
|
144 |
+
# device = 'cpu'
|
145 |
+
load_8bit = True
|
146 |
+
load_4bit = False
|
147 |
+
dtype = torch.float16
|
148 |
+
# =============
|
149 |
+
|
150 |
+
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_8bit, device=device, cache_dir=cache_dir)
|
151 |
+
# handler.model.to(dtype=dtype)
|
152 |
+
if not os.path.exists("temp"):
|
153 |
+
os.makedirs("temp")
|
154 |
+
|
155 |
+
app = FastAPI()
|
156 |
+
|
157 |
+
|
158 |
+
textbox = gr.Textbox(
|
159 |
+
show_label=False, placeholder="Enter text and press ENTER", container=False
|
160 |
+
)
|
161 |
+
with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
|
162 |
+
gr.Markdown(title_markdown)
|
163 |
+
state = gr.State()
|
164 |
+
state_ = gr.State()
|
165 |
+
first_run = gr.State()
|
166 |
+
images_tensor = gr.State()
|
167 |
+
|
168 |
+
# image1 = gr.Image(label="Input Image", type="filepath")
|
169 |
+
with gr.Row():
|
170 |
+
with gr.Column(scale=3):
|
171 |
+
video = gr.Video(label="Input Video")
|
172 |
+
|
173 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
174 |
+
gr.Examples(
|
175 |
+
examples=[
|
176 |
+
[
|
177 |
+
f"{cur_dir}/examples/sample_demo_1.mp4",
|
178 |
+
"Why is this video funny?",
|
179 |
+
],
|
180 |
+
[
|
181 |
+
f"{cur_dir}/examples/sample_demo_3.mp4",
|
182 |
+
"Can you identify any safety hazards in this video?"
|
183 |
+
],
|
184 |
+
[
|
185 |
+
f"{cur_dir}/examples/sample_demo_9.mp4",
|
186 |
+
"Describe the video.",
|
187 |
+
],
|
188 |
+
[
|
189 |
+
f"{cur_dir}/examples/sample_demo_22.mp4",
|
190 |
+
"Describe the activity in the video.",
|
191 |
+
],
|
192 |
+
],
|
193 |
+
inputs=[video, textbox],
|
194 |
+
)
|
195 |
+
|
196 |
+
with gr.Column(scale=7):
|
197 |
+
chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
|
198 |
+
with gr.Row():
|
199 |
+
with gr.Column(scale=8):
|
200 |
+
textbox.render()
|
201 |
+
with gr.Column(scale=1, min_width=50):
|
202 |
+
submit_btn = gr.Button(
|
203 |
+
value="Send", variant="primary", interactive=True
|
204 |
+
)
|
205 |
+
with gr.Row(elem_id="buttons") as button_row:
|
206 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
|
207 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
|
208 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
|
209 |
+
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
210 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
|
211 |
+
# clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
|
212 |
+
|
213 |
+
gr.Markdown(tos_markdown)
|
214 |
+
gr.Markdown(learn_more_markdown)
|
215 |
+
|
216 |
+
submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
|
217 |
+
[state, state_, chatbot, first_run, textbox, images_tensor, video])
|
218 |
+
# submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
|
219 |
+
# [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
220 |
+
|
221 |
+
regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
|
222 |
+
generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
223 |
+
# generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
|
224 |
+
|
225 |
+
# clear_btn.click(clear_history, [state, state_],
|
226 |
+
# [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
|
227 |
+
# [video, textbox, first_run, state, state_, chatbot, images_tensor])
|
228 |
+
|
229 |
+
# app = gr.mount_gradio_app(app, demo, path="/")
|
230 |
+
demo.launch(share=True)
|
231 |
+
# demo.launch()
|
232 |
+
|
233 |
+
# uvicorn videollava.serve.gradio_web_server:app
|
234 |
+
# python -m videollava.serve.gradio_web_server
|
model_worker.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker executes the model.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import asyncio
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
import threading
|
9 |
+
import uuid
|
10 |
+
|
11 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
12 |
+
from fastapi.responses import StreamingResponse
|
13 |
+
import requests
|
14 |
+
import torch
|
15 |
+
import uvicorn
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
from videollava.constants import WORKER_HEART_BEAT_INTERVAL
|
19 |
+
from videollava.utils import (build_logger, server_error_msg,
|
20 |
+
pretty_print_semaphore)
|
21 |
+
from videollava.model.builder import load_pretrained_model
|
22 |
+
from videollava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
|
23 |
+
from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
+
from transformers import TextIteratorStreamer
|
25 |
+
from threading import Thread
|
26 |
+
|
27 |
+
|
28 |
+
GB = 1 << 30
|
29 |
+
|
30 |
+
worker_id = str(uuid.uuid4())[:6]
|
31 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
32 |
+
global_counter = 0
|
33 |
+
|
34 |
+
model_semaphore = None
|
35 |
+
|
36 |
+
|
37 |
+
def heart_beat_worker(controller):
|
38 |
+
|
39 |
+
while True:
|
40 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
41 |
+
controller.send_heart_beat()
|
42 |
+
|
43 |
+
|
44 |
+
class ModelWorker:
|
45 |
+
def __init__(self, controller_addr, worker_addr,
|
46 |
+
worker_id, no_register,
|
47 |
+
model_path, model_base, model_name,
|
48 |
+
load_8bit, load_4bit, device):
|
49 |
+
self.controller_addr = controller_addr
|
50 |
+
self.worker_addr = worker_addr
|
51 |
+
self.worker_id = worker_id
|
52 |
+
if model_path.endswith("/"):
|
53 |
+
model_path = model_path[:-1]
|
54 |
+
if model_name is None:
|
55 |
+
model_paths = model_path.split("/")
|
56 |
+
if model_paths[-1].startswith('checkpoint-'):
|
57 |
+
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
58 |
+
else:
|
59 |
+
self.model_name = model_paths[-1]
|
60 |
+
else:
|
61 |
+
self.model_name = model_name
|
62 |
+
|
63 |
+
self.device = device
|
64 |
+
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
65 |
+
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
66 |
+
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
67 |
+
self.is_multimodal = 'llava' in self.model_name.lower()
|
68 |
+
|
69 |
+
if not no_register:
|
70 |
+
self.register_to_controller()
|
71 |
+
self.heart_beat_thread = threading.Thread(
|
72 |
+
target=heart_beat_worker, args=(self,))
|
73 |
+
self.heart_beat_thread.start()
|
74 |
+
|
75 |
+
def register_to_controller(self):
|
76 |
+
logger.info("Register to controller")
|
77 |
+
|
78 |
+
url = self.controller_addr + "/register_worker"
|
79 |
+
data = {
|
80 |
+
"worker_name": self.worker_addr,
|
81 |
+
"check_heart_beat": True,
|
82 |
+
"worker_status": self.get_status()
|
83 |
+
}
|
84 |
+
r = requests.post(url, json=data)
|
85 |
+
assert r.status_code == 200
|
86 |
+
|
87 |
+
def send_heart_beat(self):
|
88 |
+
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
89 |
+
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
90 |
+
f"global_counter: {global_counter}")
|
91 |
+
|
92 |
+
url = self.controller_addr + "/receive_heart_beat"
|
93 |
+
|
94 |
+
while True:
|
95 |
+
try:
|
96 |
+
ret = requests.post(url, json={
|
97 |
+
"worker_name": self.worker_addr,
|
98 |
+
"queue_length": self.get_queue_length()}, timeout=5)
|
99 |
+
exist = ret.json()["exist"]
|
100 |
+
break
|
101 |
+
except requests.exceptions.RequestException as e:
|
102 |
+
logger.error(f"heart beat error: {e}")
|
103 |
+
time.sleep(5)
|
104 |
+
|
105 |
+
if not exist:
|
106 |
+
self.register_to_controller()
|
107 |
+
|
108 |
+
def get_queue_length(self):
|
109 |
+
if model_semaphore is None:
|
110 |
+
return 0
|
111 |
+
else:
|
112 |
+
return args.limit_model_concurrency - model_semaphore._value + (len(
|
113 |
+
model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
|
114 |
+
|
115 |
+
def get_status(self):
|
116 |
+
return {
|
117 |
+
"model_names": [self.model_name],
|
118 |
+
"speed": 1,
|
119 |
+
"queue_length": self.get_queue_length(),
|
120 |
+
}
|
121 |
+
|
122 |
+
@torch.inference_mode()
|
123 |
+
def generate_stream(self, params):
|
124 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
125 |
+
|
126 |
+
prompt = params["prompt"]
|
127 |
+
ori_prompt = prompt
|
128 |
+
images = params.get("images", None)
|
129 |
+
num_image_tokens = 0
|
130 |
+
if images is not None and len(images) > 0 and self.is_multimodal:
|
131 |
+
if len(images) > 0:
|
132 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
133 |
+
raise ValueError("Number of images does not match number of <image> tokens in prompt")
|
134 |
+
|
135 |
+
images = [load_image_from_base64(image) for image in images]
|
136 |
+
images = process_images(images, image_processor, model.config)
|
137 |
+
|
138 |
+
if type(images) is list:
|
139 |
+
images = [image.to(self.model.device, dtype=torch.float16) for image in images]
|
140 |
+
else:
|
141 |
+
images = images.to(self.model.device, dtype=torch.float16)
|
142 |
+
|
143 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
144 |
+
if getattr(self.model.config, 'mm_use_im_start_end', False):
|
145 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
146 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
147 |
+
|
148 |
+
num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
|
149 |
+
else:
|
150 |
+
images = None
|
151 |
+
image_args = {"images": images}
|
152 |
+
else:
|
153 |
+
images = None
|
154 |
+
image_args = {}
|
155 |
+
|
156 |
+
temperature = float(params.get("temperature", 1.0))
|
157 |
+
top_p = float(params.get("top_p", 1.0))
|
158 |
+
max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
|
159 |
+
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
|
160 |
+
stop_str = params.get("stop", None)
|
161 |
+
do_sample = True if temperature > 0.001 else False
|
162 |
+
|
163 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
164 |
+
keywords = [stop_str]
|
165 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
166 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
|
167 |
+
|
168 |
+
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
|
169 |
+
|
170 |
+
if max_new_tokens < 1:
|
171 |
+
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
|
172 |
+
return
|
173 |
+
|
174 |
+
thread = Thread(target=model.generate, kwargs=dict(
|
175 |
+
inputs=input_ids,
|
176 |
+
do_sample=do_sample,
|
177 |
+
temperature=temperature,
|
178 |
+
top_p=top_p,
|
179 |
+
max_new_tokens=max_new_tokens,
|
180 |
+
streamer=streamer,
|
181 |
+
stopping_criteria=[stopping_criteria],
|
182 |
+
use_cache=True,
|
183 |
+
**image_args
|
184 |
+
))
|
185 |
+
thread.start()
|
186 |
+
|
187 |
+
generated_text = ori_prompt
|
188 |
+
for new_text in streamer:
|
189 |
+
generated_text += new_text
|
190 |
+
if generated_text.endswith(stop_str):
|
191 |
+
generated_text = generated_text[:-len(stop_str)]
|
192 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
|
193 |
+
|
194 |
+
def generate_stream_gate(self, params):
|
195 |
+
try:
|
196 |
+
for x in self.generate_stream(params):
|
197 |
+
yield x
|
198 |
+
except ValueError as e:
|
199 |
+
print("Caught ValueError:", e)
|
200 |
+
ret = {
|
201 |
+
"text": server_error_msg,
|
202 |
+
"error_code": 1,
|
203 |
+
}
|
204 |
+
yield json.dumps(ret).encode() + b"\0"
|
205 |
+
except torch.cuda.CudaError as e:
|
206 |
+
print("Caught torch.cuda.CudaError:", e)
|
207 |
+
ret = {
|
208 |
+
"text": server_error_msg,
|
209 |
+
"error_code": 1,
|
210 |
+
}
|
211 |
+
yield json.dumps(ret).encode() + b"\0"
|
212 |
+
except Exception as e:
|
213 |
+
print("Caught Unknown Error", e)
|
214 |
+
ret = {
|
215 |
+
"text": server_error_msg,
|
216 |
+
"error_code": 1,
|
217 |
+
}
|
218 |
+
yield json.dumps(ret).encode() + b"\0"
|
219 |
+
|
220 |
+
|
221 |
+
app = FastAPI()
|
222 |
+
|
223 |
+
|
224 |
+
def release_model_semaphore(fn=None):
|
225 |
+
model_semaphore.release()
|
226 |
+
if fn is not None:
|
227 |
+
fn()
|
228 |
+
|
229 |
+
|
230 |
+
@app.post("/worker_generate_stream")
|
231 |
+
async def generate_stream(request: Request):
|
232 |
+
global model_semaphore, global_counter
|
233 |
+
global_counter += 1
|
234 |
+
params = await request.json()
|
235 |
+
|
236 |
+
if model_semaphore is None:
|
237 |
+
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
238 |
+
await model_semaphore.acquire()
|
239 |
+
worker.send_heart_beat()
|
240 |
+
generator = worker.generate_stream_gate(params)
|
241 |
+
background_tasks = BackgroundTasks()
|
242 |
+
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
243 |
+
return StreamingResponse(generator, background=background_tasks)
|
244 |
+
|
245 |
+
|
246 |
+
@app.post("/worker_get_status")
|
247 |
+
async def get_status(request: Request):
|
248 |
+
return worker.get_status()
|
249 |
+
|
250 |
+
|
251 |
+
if __name__ == "__main__":
|
252 |
+
parser = argparse.ArgumentParser()
|
253 |
+
parser.add_argument("--host", type=str, default="localhost")
|
254 |
+
parser.add_argument("--port", type=int, default=21002)
|
255 |
+
parser.add_argument("--worker-address", type=str,
|
256 |
+
default="http://localhost:21002")
|
257 |
+
parser.add_argument("--controller-address", type=str,
|
258 |
+
default="http://localhost:21001")
|
259 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
260 |
+
parser.add_argument("--model-base", type=str, default=None)
|
261 |
+
parser.add_argument("--model-name", type=str)
|
262 |
+
parser.add_argument("--device", type=str, default="cuda")
|
263 |
+
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
|
264 |
+
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
265 |
+
parser.add_argument("--stream-interval", type=int, default=1)
|
266 |
+
parser.add_argument("--no-register", action="store_true")
|
267 |
+
parser.add_argument("--load-8bit", action="store_true")
|
268 |
+
parser.add_argument("--load-4bit", action="store_true")
|
269 |
+
args = parser.parse_args()
|
270 |
+
logger.info(f"args: {args}")
|
271 |
+
|
272 |
+
if args.multi_modal:
|
273 |
+
logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
|
274 |
+
|
275 |
+
worker = ModelWorker(args.controller_address,
|
276 |
+
args.worker_address,
|
277 |
+
worker_id,
|
278 |
+
args.no_register,
|
279 |
+
args.model_path,
|
280 |
+
args.model_base,
|
281 |
+
args.model_name,
|
282 |
+
args.load_8bit,
|
283 |
+
args.load_4bit,
|
284 |
+
args.device)
|
285 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
processing_utils.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import TextStreamer
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import base64
|
10 |
+
from PIL import Image
|
11 |
+
from io import BytesIO
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from torchvision.transforms import Compose, Lambda, ToTensor
|
14 |
+
from torchvision import transforms
|
15 |
+
from transformers import ProcessorMixin, BatchEncoding
|
16 |
+
from transformers.image_processing_utils import BatchFeature
|
17 |
+
from pytorchvideo.data.encoded_video import EncodedVideo
|
18 |
+
from torchvision.transforms import Compose, Lambda, ToTensor
|
19 |
+
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo
|
20 |
+
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
|
21 |
+
|
22 |
+
|
23 |
+
def load_frames(frames_dir):
|
24 |
+
results = []
|
25 |
+
frame_names = os.listdir(frames_dir)
|
26 |
+
frame_names.sort()
|
27 |
+
for frame_name in frame_names:
|
28 |
+
image_path = f"{frames_dir}/{frame_name}"
|
29 |
+
results.append(image_path)
|
30 |
+
return results
|
31 |
+
|
32 |
+
def sample_frames(frames, num_segments):
|
33 |
+
duration = len(frames)
|
34 |
+
frame_id_array = np.linspace(0, duration-1, num_segments, dtype=int)
|
35 |
+
frame_id_list = frame_id_array.tolist()
|
36 |
+
|
37 |
+
sampled_frames = []
|
38 |
+
for frame_idx in frame_id_list:
|
39 |
+
single_frame_path = frames[frame_idx]
|
40 |
+
sampled_frames.append(single_frame_path)
|
41 |
+
return sampled_frames
|
42 |
+
|
43 |
+
|
44 |
+
class VideoProcessor:
|
45 |
+
def __init__(self, image_transform):
|
46 |
+
self.image_transform = image_transform
|
47 |
+
|
48 |
+
def __call__(self, video_path, transform=None,
|
49 |
+
video_decode_backend='opencv',
|
50 |
+
clip_start_sec=0.0, clip_end_sec=None,
|
51 |
+
num_frames=50, **kwargs):
|
52 |
+
if transform is None: transform = self.image_transform
|
53 |
+
if video_decode_backend == 'pytorchvideo':
|
54 |
+
# decord pyav
|
55 |
+
video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
|
56 |
+
duration = video.duration
|
57 |
+
start_sec = clip_start_sec # secs
|
58 |
+
end_sec = clip_end_sec if clip_end_sec is not None else duration # secs
|
59 |
+
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
|
60 |
+
video_outputs = transform(video_data)
|
61 |
+
|
62 |
+
elif video_decode_backend == 'decord':
|
63 |
+
import decord
|
64 |
+
from decord import VideoReader, cpu
|
65 |
+
decord.bridge.set_bridge('torch')
|
66 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
67 |
+
ori_duration = len(decord_vr)
|
68 |
+
# frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
|
69 |
+
fps_vid = decord_vr.get_avg_fps()
|
70 |
+
valid_duration = min(int(fps_vid * 10), ori_duration)
|
71 |
+
frame_id_list = np.linspace(0, valid_duration-1, num_frames, dtype=int)
|
72 |
+
video_data = decord_vr.get_batch(frame_id_list)
|
73 |
+
video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
74 |
+
video_outputs = transform(video_data)
|
75 |
+
|
76 |
+
elif video_decode_backend == 'opencv':
|
77 |
+
import cv2
|
78 |
+
cv2_vr = cv2.VideoCapture(video_path)
|
79 |
+
duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
|
80 |
+
frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int)
|
81 |
+
|
82 |
+
video_data = []
|
83 |
+
for frame_idx in frame_id_list:
|
84 |
+
cv2_vr.set(1, frame_idx)
|
85 |
+
_, frame = cv2_vr.read()
|
86 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
87 |
+
video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
|
88 |
+
cv2_vr.release()
|
89 |
+
video_data = torch.stack(video_data, dim=1)
|
90 |
+
video_outputs = transform(video_data)
|
91 |
+
|
92 |
+
elif video_decode_backend == 'frames':
|
93 |
+
# FIXME does not input start and end clip timestamps. Require duration info to deal with.
|
94 |
+
frames = load_frames(video_path)
|
95 |
+
frames = sample_frames(frames, num_frames)
|
96 |
+
to_tensor = ToTensor()
|
97 |
+
video_data = torch.stack([to_tensor(_) for _ in frames]).permute(1, 0, 2, 3) # (T, C, H, W) -> (C, T, H, W)
|
98 |
+
else:
|
99 |
+
raise NameError('video_decode_backend should specify in (pytorchvideo, decord, opencv, frames)')
|
register_worker.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Manually register workers.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import requests
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--controller-address", type=str)
|
15 |
+
parser.add_argument("--worker-name", type=str)
|
16 |
+
parser.add_argument("--check-heart-beat", action="store_true")
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
url = args.controller_address + "/register_worker"
|
20 |
+
data = {
|
21 |
+
"worker_name": args.worker_name,
|
22 |
+
"check_heart_beat": args.check_heart_beat,
|
23 |
+
"worker_status": None,
|
24 |
+
}
|
25 |
+
r = requests.post(url, json=data)
|
26 |
+
assert r.status_code == 200
|
test_message.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
from videollava.conversation import default_conversation
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
if args.worker_address:
|
11 |
+
worker_addr = args.worker_address
|
12 |
+
else:
|
13 |
+
controller_addr = args.controller_address
|
14 |
+
ret = requests.post(controller_addr + "/refresh_all_workers")
|
15 |
+
ret = requests.post(controller_addr + "/list_models")
|
16 |
+
models = ret.json()["models"]
|
17 |
+
models.sort()
|
18 |
+
print(f"Models: {models}")
|
19 |
+
|
20 |
+
ret = requests.post(controller_addr + "/get_worker_address",
|
21 |
+
json={"model": args.model_name})
|
22 |
+
worker_addr = ret.json()["address"]
|
23 |
+
print(f"worker_addr: {worker_addr}")
|
24 |
+
|
25 |
+
if worker_addr == "":
|
26 |
+
return
|
27 |
+
|
28 |
+
conv = default_conversation.copy()
|
29 |
+
conv.append_message(conv.roles[0], args.message)
|
30 |
+
prompt = conv.get_prompt()
|
31 |
+
|
32 |
+
headers = {"User-Agent": "LLaVA Client"}
|
33 |
+
pload = {
|
34 |
+
"model": args.model_name,
|
35 |
+
"prompt": prompt,
|
36 |
+
"max_new_tokens": args.max_new_tokens,
|
37 |
+
"temperature": 0.7,
|
38 |
+
"stop": conv.sep,
|
39 |
+
}
|
40 |
+
response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
|
41 |
+
json=pload, stream=True)
|
42 |
+
|
43 |
+
print(prompt.replace(conv.sep, "\n"), end="")
|
44 |
+
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
|
45 |
+
if chunk:
|
46 |
+
data = json.loads(chunk.decode("utf-8"))
|
47 |
+
output = data["text"].split(conv.sep)[-1]
|
48 |
+
print(output, end="\r")
|
49 |
+
print("")
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
parser = argparse.ArgumentParser()
|
54 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
55 |
+
parser.add_argument("--worker-address", type=str)
|
56 |
+
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
|
57 |
+
parser.add_argument("--max-new-tokens", type=int, default=32)
|
58 |
+
parser.add_argument("--message", type=str, default=
|
59 |
+
"Tell me a story with more than 1000 words.")
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
main()
|
utils.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
|
3 |
+
import requests
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def load_image(image_file):
|
8 |
+
if image_file.startswith('http://') or image_file.startswith('https://'):
|
9 |
+
response = requests.get(image_file)
|
10 |
+
image = Image.open(BytesIO(response.content)).convert('RGB')
|
11 |
+
else:
|
12 |
+
image = Image.open(image_file).convert('RGB')
|
13 |
+
return image
|
14 |
+
|
15 |
+
video_ext = ['.mp4', '.mov', '.mkv', '.avi']
|
16 |
+
image_ext = ['.jpg', '.png', '.bmp', '.jpeg']
|