"""Gradio chat demo that streams image-grounded answers from a LLaVA-style model.

The script parses command-line options, loads the model once at start-up via
``load_pretrained_model``, and serves a multimodal ``gr.ChatInterface`` whose
handler (``bot_streaming``) streams the text produced by ``chat_llava``.
"""
import argparse
import base64
import glob
import json
import os
import time
from io import BytesIO
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
import spaces
import requests
import pandas as pd
from tqdm import tqdm

from llava_llama3.model.builder import load_pretrained_model
from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_llama3.conversation import conv_templates, SeparatorStyle
from llava_llama3.utils import disable_torch_init
from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava_llama3.serve.cli import chat_llava

# Point Gradio's temporary directory at the directory containing this script.
root_path = os.path.dirname(os.path.abspath(__file__))
print(f'\033[92m{root_path}\033[0m')
os.environ['GRADIO_TEMP_DIR'] = root_path

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load model
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    None,
    'llava_llama3',
    args.load_8bit,
    args.load_4bit,
    device=args.device)


def _find_image_file(message, history):
    """Return the image to use for this turn, or ``None`` if there is none.

    The most recent file attached to ``message`` wins (either a dict with a
    ``"path"`` key or the entry itself). Otherwise the history is scanned and
    the last entry whose user part is a tuple supplies the image.
    """
    files = message["files"]
    if files:
        latest = files[-1]
        return latest["path"] if isinstance(latest, dict) else latest

    image_file = None
    for hist in history:
        if isinstance(hist[0], tuple):
            # Keep overwriting so the last image in the history is used.
            image_file = hist[0][0]
    return image_file


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: Multimodal textbox payload; read via its ``"text"`` and
            ``"files"`` keys.
        history: Previous chat turns, searched for an earlier image when the
            current message has no attachment.

    Yields:
        The text generated so far, extended each time the streamer produces
        a new chunk.

    Raises:
        gr.Error: If no image is found in the message or the history.
    """
    print(message)
    image_file = _find_image_file(message, history)

    if image_file is None:
        # gr.Error is an exception: it must be raised for Gradio to show it.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        print('\033[92mRunning chat\033[0m')
        try:
            chat_llava(
                args=args,
                image_file=image_file,
                text=message['text'],
                tokenizer=tokenizer,
                model=llava_model,
                image_processor=image_processor,
                context_len=context_len,
                streamer=streamer)
        except Exception:
            # Without a stop signal the consumer loop below would block
            # forever on the streamer's queue if generation fails.
            streamer.end()
            raise

    # Generation runs in a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Short pause before each update (presumably to pace UI refreshes).
        time.sleep(0.06)
        yield buffer


chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA Demo",
        examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)