import shutil
import subprocess
import torch
import gradio as gr
from fastapi import FastAPI
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from PIL import Image
import tempfile
from decord import VideoReader, cpu
from transformers import TextStreamer
import argparse
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))

from llava.constants import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle, Conversation
from llava.mm_utils import process_images
from Evaluation.infer_utils import load_video_into_frames
from serve.utils import load_image, image_ext, video_ext
from serve.gradio_utils import Chat, tos_markdown, learn_more_markdown, title_markdown, block_css


def save_image_to_local(image):
    # Save the uploaded image under ./temp with a random name and return the path.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
    image = Image.open(image)
    image.save(filename)
    return filename


def save_video_to_local(video_path):
    # Copy the uploaded video under ./temp with a random name and return the path.
    filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename


def generate(video, textbox_in, first_run, state, state_, images_tensor, num_frames=50):
    # `state` holds the conversation rendered in the UI; `state_` holds the raw
    # conversation that is actually fed to the model.
    flag = 1
    if not textbox_in:
        if len(state_.messages) > 0:
            # Regenerate: reuse the previous user message instead of appending a new one.
            textbox_in = state_.messages[-1][1]
            state_.messages.pop(-1)
            flag = 0
        else:
            return "Please enter instruction"

    print("Video", video)                  # video path comes through correctly
    print("Images_tensor", images_tensor)  # None on the first turn
    print("Textbox_IN", textbox_in)        # user text comes through correctly
    print("State", state)                  # None on the first turn
    print("State_", state_)                # None on the first turn

    video = video if video else "none"

    if type(state) is not Conversation:
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    first_run = len(state.messages) == 0

    text_en_in = textbox_in.replace("picture", "image")

    image_processor = handler.image_processor
    assert os.path.exists(video)
    if os.path.splitext(video)[-1].lower() in video_ext:
        # input is a video file
        video_decode_backend = 'opencv'
    elif os.path.splitext(os.listdir(video)[0])[-1].lower() in image_ext:
        # input is a folder of pre-extracted frames
        video_decode_backend = 'frames'
    else:
        raise ValueError(f'Supported video extensions are {video_ext} and frame extensions are {image_ext}, '
                         f'but found {os.path.splitext(video)[-1].lower()}')
    frames = load_video_into_frames(video, video_decode_backend=video_decode_backend, num_frames=num_frames)
    tensor = process_images(frames, image_processor, argparse.Namespace(image_aspect_ratio='pad'))
    tensor = tensor.to(handler.model.device, dtype=dtype)
    images_tensor = tensor

    # Prepend the image token(s) the model expects before the user text.
    if handler.model.config.mm_use_im_start_end:
        text_en_in = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text_en_in
    else:
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in
    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    state_.messages[-1] = (state_.roles[1], text_en_out)

    text_en_out = text_en_out.split('#')[0]
    textbox_out = text_en_out

    show_images = ""
    if os.path.exists(video):
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'

    if flag:
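        # Record the user turn (text plus the rendered video tag) in the visible conversation.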
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False,
            gr.update(value=None, interactive=True), images_tensor,
            gr.update(value=video if os.path.exists(video) else None, interactive=True))


def regenerate(state, state_):
    # Drop the last assistant reply from both conversations so `generate`
    # can produce a fresh one.
    state.messages.pop(-1)
    state_.messages.pop(-1)
    if len(state.messages) > 0:
        return state, state_, state.to_gradio_chatbot(), False
    return (state, state_, state.to_gradio_chatbot(), True)


def clear_history(state, state_):
    # Reset both conversations and the UI components to their initial state.
    state = conv_templates[conv_mode].copy()
    state_ = conv_templates[conv_mode].copy()
    return (gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=True),
            True, state, state_, state.to_gradio_chatbot(), [])


# ==== CHANGE HERE ====
conv_mode = "llava_v0"
model_path = 'SNUMPR/vlm_rlaif_video_llava_7b'
# conv_mode = "llava_v1"
# model_path = 'LanguageBind/Video-LLaVA-7B'
cache_dir = './cache_dir'
device = 'cuda'
# device = 'cpu'
load_8bit = True
load_4bit = False
dtype = torch.float16
# =====================

handler = Chat(model_path, conv_mode=conv_mode, load_8bit=load_8bit, load_4bit=load_4bit,
               device=device, cache_dir=cache_dir)
# handler.model.to(dtype=dtype)

if not os.path.exists("temp"):
    os.makedirs("temp")

app = FastAPI()

textbox = gr.Textbox(
    show_label=False, placeholder="Enter text and press ENTER", container=False
)

with gr.Blocks(title='VLM-RLAIF', theme=gr.themes.Default(), css=block_css) as demo:
    gr.Markdown(title_markdown)
    # Two parallel conversation states: `state` for display, `state_` for the model.
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()
    # image1 = gr.Image(label="Input Image", type="filepath")

    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

            cur_dir = os.path.dirname(os.path.abspath(__file__))
            gr.Examples(
                examples=[
                    [
                        f"{cur_dir}/examples/sample_demo_1.mp4",
                        "Why is this video funny?",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_3.mp4",
                        "Can you identify any safety hazards in this video?",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_9.mp4",
                        "Describe the video.",
                    ],
                    [
                        f"{cur_dir}/examples/sample_demo_22.mp4",
                        "Describe the activity in the video.",
                    ],
                ],
                inputs=[video, textbox],
            )

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="VLM_RLAIF", bubble_full_width=True).style(height=750)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
                # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
                # clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)

    submit_btn.click(
        generate,
        [video, textbox, first_run, state, state_, images_tensor],
        [state, state_, chatbot, first_run, textbox, images_tensor, video])

    # Regenerate: pop the last reply, then rerun `generate`; with an empty
    # textbox, `generate` reuses the previous user message.
    regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
        generate,
        [video, textbox, first_run, state, state_, images_tensor],
        [state, state_, chatbot, first_run, textbox, images_tensor, video])

    # clear_btn.click(clear_history, [state, state_],
    #                 [video, textbox, first_run, state, state_, chatbot, images_tensor])

# app = gr.mount_gradio_app(app, demo, path="/")
demo.launch(share=True)
# demo.launch()

# uvicorn videollava.serve.gradio_web_server:app
# python -m videollava.serve.gradio_web_server
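
# Note: share=True requests a public *.gradio.live tunnel. To serve locally
# instead, drop it, e.g. demo.launch(server_name="0.0.0.0", server_port=7860).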