amstrongzyf committed on
Commit 32db94f
1 Parent(s): d5bf1ae

Update app.py

Files changed (1): app.py (+71, -78)
app.py CHANGED
@@ -4,108 +4,101 @@ from threading import Thread
 import gradio as gr
 import torch
 from PIL import Image
-from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from transformers import TextIteratorStreamer
 
 import spaces
-import argparse
-
-from llava_llama3.model.builder import load_pretrained_model
-from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-from llava_llama3.conversation import conv_templates, SeparatorStyle
-from llava_llama3.utils import disable_torch_init
-from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
-from llava_llama3.serve.cli import chat_llava
-
-import requests
-from io import BytesIO
-import base64
-import os
-import glob
-import pandas as pd
-from tqdm import tqdm
-import json
-
-root_path = os.path.dirname(os.path.abspath(__file__))
-print(f'\033[92m{root_path}\033[0m')
-os.environ['GRADIO_TEMP_DIR'] = root_path
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
-parser.add_argument("--device", type=str, default="cuda")
-parser.add_argument("--conv-mode", type=str, default="llama_3")
-parser.add_argument("--temperature", type=float, default=0.7)
-parser.add_argument("--max-new-tokens", type=int, default=512)
-parser.add_argument("--load-8bit", action="store_true")
-parser.add_argument("--load-4bit", action="store_true")
-args = parser.parse_args()
-
-# Load model
-tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
-    args.model_path,
-    None,
-    'llava_llama3',
-    args.load_8bit,
-    args.load_4bit,
-    device=args.device)
+
+
+PLACEHOLDER = """
+<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+   <img src="https://cdn-uploads.huggingface.co/production/uploads/64ccdc322e592905f922a06e/DDIW0kbWmdOQWwy4XMhwX.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
+   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama-3-8B</h1>
+   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Llava-Llama-3-8b is a LLaVA model fine-tuned from Meta-Llama-3-8B-Instruct and CLIP-ViT-Large-patch14-336 with ShareGPT4V-PT and InternVL-SFT by XTuner</p>
+</div>
+"""
+
+
+model_id = "TheFinAI/FinLLaVA"
+
+processor = AutoProcessor.from_pretrained(model_id)
+
+model = LlavaForConditionalGeneration.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+)
+
+model.to("cuda:0")
+model.generation_config.eos_token_id = 128009
+
 
 @spaces.GPU
 def bot_streaming(message, history):
     print(message)
-    image_file = None
     if message["files"]:
+        # message["files"][-1] is a Dict or just a string
         if type(message["files"][-1]) == dict:
-            image_file = message["files"][-1]["path"]
+            image = message["files"][-1]["path"]
         else:
-            image_file = message["files"][-1]
+            image = message["files"][-1]
     else:
+        # if there's no image uploaded for this turn, look for images in the past turns
+        # kept inside tuples, take the last one
         for hist in history:
             if type(hist[0]) == tuple:
-                image_file = hist[0][0]
-
-    if image_file is None:
+                image = hist[0][0]
+    try:
+        if image is None:
+            # Handle the case where image is None
+            gr.Error("You need to upload an image for LLaVA to work.")
+    except NameError:
+        # Handle the case where 'image' is not defined at all
         gr.Error("You need to upload an image for LLaVA to work.")
-        return
-
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    def generate():
-        print('\033[92mRunning chat\033[0m')
-        output = chat_llava(
-            args=args,
-            image_file=image_file,
-            text=message['text'],
-            tokenizer=tokenizer,
-            model=llava_model,
-            image_processor=image_processor,
-            context_len=context_len,
-            streamer=streamer)
-        return output
-
-    thread = Thread(target=generate)
+
+    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"prompt: {prompt}")
+    image = Image.open(image)
+    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
+
+    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
+    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-    # thread.join()
+
+    text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    # print(f"text_prompt: {text_prompt}")
 
     buffer = ""
-    # output = generate()
+    time.sleep(0.5)
     for new_text in streamer:
+        # find <|eot_id|> and remove it from the new_text
+        if "<|eot_id|>" in new_text:
+            new_text = new_text.split("<|eot_id|>")[0]
         buffer += new_text
+
+        # generated_text_without_prompt = buffer[len(text_prompt):]
         generated_text_without_prompt = buffer
+        # print(generated_text_without_prompt)
         time.sleep(0.06)
+        # print(f"new_text: {generated_text_without_prompt}")
         yield generated_text_without_prompt
 
-chatbot = gr.Chatbot(scale=1)
+
+chatbot=gr.Chatbot(placeholder=PLACEHOLDER,scale=1)
 chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
-with gr.Blocks(fill_height=True) as demo:
+with gr.Blocks(fill_height=True, ) as demo:
     gr.ChatInterface(
-    fn=bot_streaming,
-    title="FinLLaVA Demo",
-    examples=[
-        {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
-    ],
-    description="",
-    stop_btn="Stop Generation",
-    multimodal=True,
-    textbox=chat_input,
-    chatbot=chatbot,
+    fn=bot_streaming,
+    title="LLaVA Llama-3-8B",
+    examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
+              ],
+    description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+    stop_btn="Stop Generation",
+    multimodal=True,
+    textbox=chat_input,
+    chatbot=chatbot,
    )
 
 demo.queue(api_open=False)
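
The core of this rewrite is swapping the custom llava_llama3 loader and chat_llava helper for the stock transformers streaming path: model.generate blocks until completion, so the new code runs it on a worker thread while the Gradio callback drains a TextIteratorStreamer. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the question text and "example.jpg" path are illustrative only, keyword arguments are passed to the processor to stay robust across transformers versions, and special tokens are skipped at decode time instead of trimming <|eot_id|> by hand as the diff does.

    import torch
    from threading import Thread
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer

    model_id = "TheFinAI/FinLLaVA"
    processor = AutoProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
    ).to("cuda:0")

    # Llama-3 chat markup with the <image> placeholder, as in the diff above.
    prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\n"
              "What is in this picture?"  # illustrative question
              "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
    inputs = processor(text=prompt, images=Image.open("example.jpg"),  # hypothetical local file
                       return_tensors="pt").to("cuda:0", torch.float16)

    # skip_prompt=True drops the echoed prompt; the Llava processor forwards
    # decode() to its tokenizer, so it can stand in for one here.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done, so run it on a worker thread and iterate
    # the streamer on this one; iteration ends when generation finishes.
    thread = Thread(target=model.generate,
                    kwargs=dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False))
    thread.start()
    for chunk in streamer:
        print(chunk, end="", flush=True)
    thread.join()

Streaming through the queue-backed iterator rather than waiting for the full completion is what lets the chat update token by token; the time.sleep calls in the diff only pace the UI refresh and are not required by the API.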