amstrongzyf committed on
Commit
986b2b2
1 Parent(s): 38c55e2

Update app.py

Files changed (1)
  1. app.py +93 -60
app.py CHANGED
@@ -1,112 +1,145 @@
 import time
 from threading import Thread
+import copy
 
 import gradio as gr
 import torch
-from PIL import Image
-from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
-
-import spaces
-import argparse
+from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
 
 from llava_llama3.model.builder import load_pretrained_model
-from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
-from llava_llama3.conversation import conv_templates, SeparatorStyle
-from llava_llama3.utils import disable_torch_init
-from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
 from llava_llama3.serve.cli import chat_llava
 
-import requests
-from io import BytesIO
-import base64
 import os
-import glob
-import pandas as pd
-from tqdm import tqdm
-import json
+import argparse
 
+# Set environment variables
 root_path = os.path.dirname(os.path.abspath(__file__))
 print(f'\033[92m{root_path}\033[0m')
 os.environ['GRADIO_TEMP_DIR'] = root_path
 
-parser = argparse.ArgumentParser()
-parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
-parser.add_argument("--device", type=str, default="cuda")
-parser.add_argument("--conv-mode", type=str, default="llama_3")
-parser.add_argument("--temperature", type=float, default=0.7)
-parser.add_argument("--max-new-tokens", type=int, default=512)
-parser.add_argument("--load-8bit", action="store_true")
-parser.add_argument("--load-4bit", action="store_true")
-args = parser.parse_args()
+# Create a default arguments object
+default_args = argparse.Namespace(
+    model_path="TheFinAI/FinLLaVA",
+    device="cuda",
+    conv_mode="llama_3",
+    temperature=0.7,
+    max_new_tokens=512,
+    load_8bit=False,
+    load_4bit=False
+)
 
-# Load model
+# Load the model
 tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
-    args.model_path,
+    default_args.model_path,
     None,
     'llava_llama3',
-    args.load_8bit,
-    args.load_4bit,
-    device=args.device)
+    default_args.load_8bit,
+    default_args.load_4bit,
+    device=default_args.device
+)
 
-@spaces.GPU
-def bot_streaming(message, history):
-    print(message)
+def bot_streaming(message, history, temperature, max_new_tokens):
     image_file = None
     if message["files"]:
-        if type(message["files"][-1]) == dict:
+        if isinstance(message["files"][-1], dict):
             image_file = message["files"][-1]["path"]
         else:
             image_file = message["files"][-1]
     else:
         for hist in history:
-            if type(hist[0]) == tuple:
+            if isinstance(hist[0], tuple):
                 image_file = hist[0][0]
 
     if image_file is None:
         gr.Error("You need to upload an image for LLaVA to work.")
         return
 
+    args = copy.deepcopy(default_args)
+    args.temperature = temperature
+    args.max_new_tokens = max_new_tokens
+
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
     def generate():
         print('\033[92mRunning chat\033[0m')
-        output = chat_llava(
-            args=args,
-            image_file=image_file,
-            text=message['text'],
-            tokenizer=tokenizer,
-            model=llava_model,
-            image_processor=image_processor,
-            context_len=context_len,
-            streamer=streamer)
-        return output
+        return chat_llava(
+            args=args,
+            image_file=image_file,
+            text=message['text'],
+            tokenizer=tokenizer,
+            model=llava_model,
+            image_processor=image_processor,
+            context_len=context_len,
+            streamer=streamer
+        )
 
     thread = Thread(target=generate)
     thread.start()
-    # thread.join()
 
     buffer = ""
-    # output = generate()
     for new_text in streamer:
         buffer += new_text
-        generated_text_without_prompt = buffer
         time.sleep(0.06)
-        yield generated_text_without_prompt
+        yield buffer
+
+# Define CSS styles
+css = """
+body {
+    font-family: Arial, sans-serif;
+}
+.gradio-container {
+    max-width: 800px;
+    margin: auto;
+}
+.chatbot {
+    height: 400px;
+    overflow-y: auto;
+}
+"""
 
-chatbot = gr.Chatbot(scale=1)
-chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
-with gr.Blocks(fill_height=True) as demo:
+# Create interface using gr.Blocks
+with gr.Blocks(css=css) as demo:
+    gr.Markdown("# FinLLaVA Demo")
+
+    chatbot = gr.Chatbot(scale=1)
+    chat_input = gr.MultimodalTextbox(
+        interactive=True,
+        file_types=["image"],
+        placeholder="Enter message or upload file...",
+        show_label=False
+    )
+
+    with gr.Accordion("Advanced Settings", open=False):
+        temperature = gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=2.0,
+            step=0.1,
+            value=default_args.temperature
+        )
+        max_new_tokens = gr.Slider(
+            label="Max New Tokens",
+            minimum=1,
+            maximum=1024,
+            step=1,
+            value=default_args.max_new_tokens
+        )
+
     gr.ChatInterface(
         fn=bot_streaming,
-        title="FinLLaVA Demo",
+        chatbot=chatbot,
+        textbox=chat_input,
+        additional_inputs=[temperature, max_new_tokens],
         examples=[
-            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
+            {"text": "What's in this image?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
         ],
+        title="",
         description="",
-        stop_btn="Stop Generation",
-        multimodal=True,
-        textbox=chat_input,
-        chatbot=chatbot,
+        theme="soft",
+        retry_btn="Retry",
+        undo_btn="Undo",
+        clear_btn="Clear",
    )
 
-demo.queue(api_open=False)
-demo.launch(show_api=False, share=False)
+if __name__ == "__main__":
+    demo.queue(api_open=False).launch(share=False, debug=True)
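Both the old and new app.py drive streaming the same way: generation runs on a background Thread, decoded text lands in a TextIteratorStreamer, and the Gradio handler drains the streamer while yielding a growing buffer so the chat bubble updates in place. A minimal sketch of that pattern, swapping in the small public gpt2 checkpoint for FinLLaVA so it runs without the llava_llama3 package:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# gpt2 is a stand-in for FinLLaVA so this sketch runs anywhere.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt):
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    # generate() blocks, so it runs on a worker thread; the streamer hands
    # decoded text fragments back to this thread as they are produced.
    thread = Thread(target=model.generate,
                    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
    thread.start()
    buffer = ""
    for new_text in streamer:  # iteration ends when generation finishes
        buffer += new_text
        yield buffer           # each yield re-renders the chat message
    thread.join()

for partial in stream_reply("Streaming generation works by"):
    print(partial)

The time.sleep(0.06) in the app only throttles UI repaints; the loop works without it, since iterating over the streamer already blocks until new text arrives.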
 
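The other structural change is dropping argparse.ArgumentParser, whose parse_args() can choke on the unexpected sys.argv a hosted runtime passes, in favor of a module-level argparse.Namespace of defaults that each request copies and then overrides with the slider values, leaving the chat_llava(args=...) signature untouched. A stripped-down sketch of that idiom:

import argparse
import copy

# Module-level defaults; a Namespace mimics what parse_args() would return,
# so code expecting an `args` object keeps working unchanged.
default_args = argparse.Namespace(temperature=0.7, max_new_tokens=512)

def per_request_args(temperature, max_new_tokens):
    # Copy first so concurrent requests never mutate the shared defaults.
    args = copy.deepcopy(default_args)
    args.temperature = temperature
    args.max_new_tokens = max_new_tokens
    return args

args = per_request_args(0.2, 128)
print(args.temperature, default_args.temperature)  # 0.2 0.7

Since every field here is a scalar, argparse.Namespace(**vars(default_args)) would be an equally safe and cheaper copy than deepcopy.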
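The UI follows the Gradio 4.x idiom of embedding gr.ChatInterface inside a larger gr.Blocks layout with pre-built components passed in (retry_btn, undo_btn, and clear_btn were removed in Gradio 5, so the version needs pinning accordingly). Below is a reduced sketch with a stub handler standing in for bot_streaming; note that multimodal=True, present in the old version but absent from the new gr.ChatInterface call, is what makes the handler receive the {"text": ..., "files": [...]} dict it indexes:

import gradio as gr

def echo_stream(message, history, temperature, max_new_tokens):
    # With multimodal=True, `message` is the dict bot_streaming expects.
    reply = f"[t={temperature}, n={int(max_new_tokens)}] {message['text']}"
    partial = ""
    for ch in reply:
        partial += ch
        yield partial

with gr.Blocks() as demo:
    gr.Markdown("# FinLLaVA Demo")
    chatbot = gr.Chatbot(scale=1)
    chat_input = gr.MultimodalTextbox(file_types=["image"], show_label=False)
    with gr.Accordion("Advanced Settings", open=False):
        temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max New Tokens")
    gr.ChatInterface(
        fn=echo_stream,
        chatbot=chatbot,
        textbox=chat_input,
        multimodal=True,  # restores the dict-style message the handler indexes
        additional_inputs=[temperature, max_new_tokens],
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(share=False)

One quirk the diff leaves untouched as a context line: gr.Error(...) is instantiated but never raised, so the "upload an image" warning is silently discarded; the working idiom is raise gr.Error("You need to upload an image for LLaVA to work.").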