amstrongzyf committed on
Commit
d5bf1ae
1 Parent(s): 5693cbb

Update app.py

Files changed (1)
  1. app.py +60 -94
app.py CHANGED
@@ -1,146 +1,112 @@
 import time
 from threading import Thread
-import copy
 
 import gradio as gr
 import torch
-from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
+from PIL import Image
+from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer, TextStreamer
+
+import spaces
+import argparse
 
 from llava_llama3.model.builder import load_pretrained_model
+from llava_llama3.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava_llama3.conversation import conv_templates, SeparatorStyle
+from llava_llama3.utils import disable_torch_init
+from llava_llama3.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
 from llava_llama3.serve.cli import chat_llava
 
+import requests
+from io import BytesIO
+import base64
 import os
-import argparse
+import glob
+import pandas as pd
+from tqdm import tqdm
+import json
 
-# Set environment variables
 root_path = os.path.dirname(os.path.abspath(__file__))
 print(f'\033[92m{root_path}\033[0m')
 os.environ['GRADIO_TEMP_DIR'] = root_path
 
-# Create a default arguments object
-default_args = argparse.Namespace(
-    model_path="TheFinAI/FinLLaVA",
-    device="cuda",
-    conv_mode="llama_3",
-    temperature=0.7,
-    max_new_tokens=512,
-    load_8bit=False,
-    load_4bit=False
-)
+parser = argparse.ArgumentParser()
+parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
+parser.add_argument("--device", type=str, default="cuda")
+parser.add_argument("--conv-mode", type=str, default="llama_3")
+parser.add_argument("--temperature", type=float, default=0.7)
+parser.add_argument("--max-new-tokens", type=int, default=512)
+parser.add_argument("--load-8bit", action="store_true")
+parser.add_argument("--load-4bit", action="store_true")
+args = parser.parse_args()
 
-# Load the model
+# Load model
 tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
-    default_args.model_path,
+    args.model_path,
     None,
     'llava_llama3',
-    default_args.load_8bit,
-    default_args.load_4bit,
-    device=default_args.device
-)
+    args.load_8bit,
+    args.load_4bit,
+    device=args.device)
 
-def bot_streaming(message, history, temperature, max_new_tokens):
+@spaces.GPU
+def bot_streaming(message, history):
+    print(message)
     image_file = None
     if message["files"]:
-        if isinstance(message["files"][-1], dict):
+        if type(message["files"][-1]) == dict:
             image_file = message["files"][-1]["path"]
         else:
             image_file = message["files"][-1]
     else:
         for hist in history:
-            if isinstance(hist[0], tuple):
+            if type(hist[0]) == tuple:
                 image_file = hist[0][0]
 
     if image_file is None:
         gr.Error("You need to upload an image for LLaVA to work.")
         return
 
-    args = copy.deepcopy(default_args)
-    args.temperature = temperature
-    args.max_new_tokens = max_new_tokens
-
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
     def generate():
         print('\033[92mRunning chat\033[0m')
-        return chat_llava(
-            args=args,
-            image_file=image_file,
-            text=message['text'],
-            tokenizer=tokenizer,
-            model=llava_model,
-            image_processor=image_processor,
-            context_len=context_len,
-            streamer=streamer
-        )
+        output = chat_llava(
+            args=args,
+            image_file=image_file,
+            text=message['text'],
+            tokenizer=tokenizer,
+            model=llava_model,
+            image_processor=image_processor,
+            context_len=context_len,
+            streamer=streamer)
+        return output
 
     thread = Thread(target=generate)
     thread.start()
+    # thread.join()
 
     buffer = ""
+    # output = generate()
     for new_text in streamer:
         buffer += new_text
+        generated_text_without_prompt = buffer
         time.sleep(0.06)
-        yield buffer
-
-# Define CSS styles
-css = """
-body {
-    font-family: Arial, sans-serif;
-}
-.gradio-container {
-    max-width: 800px;
-    margin: auto;
-}
-.chatbot {
-    height: 400px;
-    overflow-y: auto;
-}
-"""
+        yield generated_text_without_prompt
 
-# Create interface using gr.Blocks
-with gr.Blocks(css=css) as demo:
-    gr.Markdown("# FinLLaVA Demo")
-
-    chatbot = gr.Chatbot(scale=1)
-    chat_input = gr.MultimodalTextbox(
-        interactive=True,
-        file_types=["image"],
-        placeholder="Enter message or upload file...",
-        show_label=False
-    )
-
-    with gr.Accordion("Advanced Settings", open=False):
-        temperature = gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=2.0,
-            step=0.1,
-            value=default_args.temperature
-        )
-        max_new_tokens = gr.Slider(
-            label="Max New Tokens",
-            minimum=1,
-            maximum=1024,
-            step=1,
-            value=default_args.max_new_tokens
-        )
-
-    chat_interface = gr.ChatInterface(
+chatbot = gr.Chatbot(scale=1)
+chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
+with gr.Blocks(fill_height=True) as demo:
+    gr.ChatInterface(
         fn=bot_streaming,
-        chatbot=chatbot,
-        textbox=chat_input,
-        additional_inputs=[temperature, max_new_tokens],
+        title="FinLLaVA Demo",
         examples=[
            {"text": "What is in this picture?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
         ],
-        title="",
         description="",
-        theme="soft",
-        retry_btn="Retry",
-        undo_btn="Undo",
-        clear_btn="Clear",
+        stop_btn="Stop Generation",
+        multimodal=True,
+        textbox=chat_input,
+        chatbot=chatbot,
    )
 
-
-if __name__ == "__main__":
-    demo.queue(api_open=False).launch(share=False, debug=True)
+demo.queue(api_open=False)
+demo.launch(show_api=False, share=False)
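Both sides of this diff stream tokens the same way: chat_llava runs on a background Thread and writes into a TextIteratorStreamer, while the Gradio handler drains the streamer and yields partial text. Below is a minimal, self-contained sketch of that pattern; the model name and prompt are illustrative stand-ins for the FinLLaVA setup, not part of this commit.

# Minimal sketch of the thread-plus-streamer pattern used by bot_streaming.
# "gpt2" and the prompt are placeholders, not taken from this commit.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs on a worker thread
# while the main thread consumes tokens from the streamer as they arrive.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()

buffer = ""
for new_text in streamer:  # iteration ends when generation completes
    buffer += new_text
    print(buffer, flush=True)
thread.join()

The new argparse block also moves generation settings to the command line, so a local run can override the defaults, e.g. python app.py --temperature 0.5 --load-4bit; launched with no arguments (as on a Space), the defaults apply.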