# vlm-rlaif-demo / gradio_web_server copy.py
import argparse
import os
import shutil
import subprocess
import sys
import tempfile

# Restrict CUDA device visibility before torch initializes its GPU context.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import gradio as gr
from fastapi import FastAPI
from PIL import Image
from decord import VideoReader, cpu
from transformers import TextStreamer

sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.mm_utils import process_images
from Evaluation.infer_utils import load_video_into_frames
from serve.utils import load_image, image_ext, video_ext
from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css
def save_image_to_local(image):
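    """Save an uploaded image into ./temp as a JPEG and return the new path
    (tempfile._get_candidate_names() is a private stdlib helper, used here
    only to produce a unique basename)."""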
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
image = Image.open(image)
image.save(filename)
return filename
def save_video_to_local(video_path):
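    """Copy an uploaded video into ./temp and return the new path."""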
filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
shutil.copyfile(video_path, filename)
return filename
def generate(image1, video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
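    """Run one chat turn: load the video into frames, build the multimodal
    prompt, query the model, and append the exchange to the Gradio states."""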
# ======= manually clear the conversation
# state = conv_templates[conv_mode].copy()
# state_ = conv_templates[conv_mode].copy()
# # =======
flag = 1
if not textbox_in:
if len(state_.messages) > 0:
textbox_in = state_.messages[-1][1]
state_.messages.pop(-1)
flag = 0
else:
return "Please enter instruction"
print("Video", video) # 잘 듀어감
print("Images_tensor", images_tensor) # None
print("Textbox_IN", textbox_in) # 잘 듀어감
print("State", state) # None
print("State_", state_) # None
video = video if video else "none"
if type(state) is not Conversation:
state = conv_templates[conv_mode].copy()
state_ = conv_templates[conv_mode].copy()
images_tensor = []
    first_run = len(state.messages) == 0
text_en_in = textbox_in.replace("picture", "image")
image_processor = handler.image_processor
    assert os.path.exists(video), f"Video path not found: {video}"
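    # Pick a decode backend: a video file goes through OpenCV, while a
    # directory of pre-extracted frames is read directly.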
    if os.path.splitext(video)[-1].lower() in video_ext:  # video file
        video_decode_backend = 'opencv'
    elif os.path.isdir(video) and os.path.splitext(os.listdir(video)[0])[-1].lower() in image_ext:  # frames folder
        video_decode_backend = 'frames'
    else:
        raise ValueError(f'Support video of {video_ext} and frames of {image_ext}, but found {os.path.splitext(video)[-1].lower()}')
frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
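    # Pad each frame to a square aspect ratio and stack them into one tensor.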
tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
# tensor = video_processor(video, return_tensors='pt')['pixel_values'][0]
tensor = tensor.to(handler.model.device, dtype=dtype)
# images_tensor.append(tensor)
images_tensor = tensor
if handler.model.config.mm_use_im_start_end:
text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
else:
text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
state_.messages[-1] = (state_.roles[1], text_en_out)
text_en_out = text_en_out.split('#')[0]
textbox_out = text_en_out
show_images = ""
if os.path.exists(video):
filename = save_video_to_local(video)
show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
if flag:
state.append_message(state.roles[0], textbox_in + "\n" + show_images)
state.append_message(state.roles[1], textbox_out)
    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=image1 if os.path.exists(video) else None, interactive=True),
            gr.update(value=video if os.path.exists(video) else None, interactive=True))
def regenerate(state, state_):
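    """Drop the last reply from both conversation states so the chained
    generate() call (see the .then() wiring below) can produce a new one."""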
state.messages.pop(-1)
state_.messages.pop(-1)
if len(state.messages) > 0:
return state, state_, state.to_gradio_chatbot(), False
return (state, state_, state.to_gradio_chatbot(), True)
def clear_history(state, state_):
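    """Reset both conversation states and clear all UI inputs."""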
state = conv_templates[conv_mode].copy()
state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, state_, state.to_gradio_chatbot(), [])
# ==== CHANGE HERE ====
# conv_mode = "llava_v1"
# model_path = 'LanguageBind/Video-LLaVA-7B'
# FIXME!!!
conv_mode = "llava_v0"
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
# model_path = '/dataset/yura/vlm-rlaif/pretrained/final_models/Video_LLaVA_VLM_RLAIF_merged'
cache_dir = './cache_dir'
device = 'cuda'
# device = 'cpu'
load_8bit = True
load_4bit = False
dtype = torch.float16
# =============
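# Chat (from serve.gradio_utils) wraps model loading, preprocessing, and generation.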
handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit, device=device, cache_dir=cache_dir)
# handler.model.to(dtype=dtype)
os.makedirs("temp", exist_ok=True)
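# The FastAPI app is currently unused; see the commented-out gr.mount_gradio_app call below.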
app = FastAPI()
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", container=False
)
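# ==== Gradio UI ====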
with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
gr.Markdown(title_markdown)
state = gr.State()
state_ = gr.State()
first_run = gr.State()
images_tensor = gr.State()
image1 = gr.Image(label="Input Image", type="filepath")
with gr.Row():
with gr.Column(scale=3):
video = gr.Video(label="Input Video")
cur_dir = os.path.dirname(os.path.abspath(__file__))
gr.Examples(
examples=[
[
f"{cur_dir}/examples/sample_demo_1.mp4",
"Why is this video funny?",
],
[
f"{cur_dir}/examples/sample_demo_3.mp4",
"Can you identify any safety hazards in this video?"
],
[
f"{cur_dir}/examples/sample_demo_9.mp4",
"Describe the video.",
],
[
f"{cur_dir}/examples/sample_demo_22.mp4",
"Describe the activity in the video.",
],
],
inputs=[video, textbox],
)
with gr.Column(scale=7):
chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(
value="Send", variant="primary", interactive=True
)
with gr.Row(elem_id="buttons") as button_row:
upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=True)
downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=True)
flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
# stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=True)
# clear_btn = gr.Button(value="πŸ—‘οΈ Clear history", interactive=True)
gr.Markdown(tos_markdown)
gr.Markdown(learn_more_markdown)
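    # Wire the Send and Regenerate buttons to the handlers defined above.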
submit_btn.click(generate, [image1, video, textbox, first_run, state, state_, images_tensor],
[state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
# submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor],
# [state, state_, chatbot, first_run, textbox, images_tensor, video])
regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
generate, [image1, video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, image1, video])
# generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
# clear_btn.click(clear_history, [state, state_],
# [image1, video, textbox, first_run, state, state_, chatbot, images_tensor])
# [video, textbox, first_run, state, state_, chatbot, images_tensor])
# app = gr.mount_gradio_app(app, demo, path="/")
# demo.launch(share=True)
demo.launch()
# uvicorn videollava.serve.gradio_web_server:app
# python -m videollava.serve.gradio_web_server