chendl committed on
Commit fdffde6 • 1 Parent(s): 468d028

update chat

Files changed (3)
  1. app.py +189 -92
  2. multimodal/open_flamingo/chat/conversation.py +486 -0
  3. temp.py +168 -0
app.py CHANGED
@@ -17,7 +17,7 @@ from PIL import Image
17
  from huggingface_hub import hf_hub_download, login
18
 
19
  from open_flamingo.src.factory import create_model_and_transforms
20
-
21
 
22
  flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
23
  "ViT-L-14",
@@ -49,6 +49,7 @@ if "vision_encoder.logit_scale"in model_state_dict:
49
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
50
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
51
  flamingo.load_state_dict(model_state_dict, strict=True)
 
52
 
53
  def get_outputs(
54
  model,
@@ -176,106 +177,202 @@ def generate(
176
  return (f"Output:{gen_text}")
177
 
178
 
179
- with gr.Blocks() as demo:
180
- gr.Markdown(
181
- """
182
- 🍜 Object Centric Pretraining Demo
183
- In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
184
- The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
185
- """
186
- )
187
-
188
- with gr.Accordion("See terms and conditions"):
189
- gr.Markdown(
190
- """**Please read the following information carefully before proceeding.**This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
191
-
192
- with gr.Tab("πŸ“· Image Captioning"):
193
- with gr.Row():
194
-
195
-
196
- query_image = gr.Image(type="pil")
197
- with gr.Row():
198
- chat_input = gr.Textbox(lines=1, label="Chat Input")
199
- text_output = gr.Textbox(value="Output:", label="Model output")
200
-
201
- run_btn = gr.Button("Run model")
202
-
203
-
204
-
205
- def on_click_fn(img,text): return generate(0, img, text)
206
-
207
- run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
208
-
209
- with gr.Tab("πŸ¦“ Grounding"):
210
- with gr.Row():
211
- with gr.Column(scale=1):
212
- query_image = gr.Image(type="pil")
213
- with gr.Column(scale=1):
214
- out_image = gr.Image(type="pil")
215
- with gr.Row():
216
- chat_input = gr.Textbox(lines=1, label="Chat Input")
217
- text_output = gr.Textbox(value="Output:", label="Model output")
218
-
219
- run_btn = gr.Button("Run model")
220
-
221
 
222
- def on_click_fn(img, text): return generate(1, img, text)
223
 
 
 
 
224
 
225
- run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
226
 
227
- with gr.Tab("πŸ”’ Counting objects"):
228
- with gr.Row():
229
- query_image = gr.Image(type="pil")
230
- with gr.Row():
231
- chat_input = gr.Textbox(lines=1, label="Chat Input")
232
- text_output = gr.Textbox(value="Output:", label="Model output")
233
 
234
- run_btn = gr.Button("Run model")
235
 
236
 
237
- def on_click_fn(img,text): return generate(0, img, text)
 
 
238
239
 
240
- run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
241
 
242
- with gr.Tab("πŸ•΅οΈ Visual Question Answering"):
243
- with gr.Row():
244
- query_image = gr.Image(type="pil")
245
- with gr.Row():
246
- question = gr.Textbox(lines=1, label="Question")
247
- text_output = gr.Textbox(value="Output:", label="Model output")
 
248
 
249
- run_btn = gr.Button("Run model")
250
 
251
-
252
- def on_click_fn(img, txt): return generate(2, img, txt)
253
-
254
-
255
- run_btn.click(
256
- on_click_fn, inputs=[query_image, question], outputs=[text_output]
257
- )
258
-
259
- with gr.Tab("🌎 Custom"):
260
- gr.Markdown(
261
- """### Customize the demonstration by uploading your own images and text samples.
262
- ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
263
- )
264
- with gr.Row():
265
- query_image = gr.Image(type="pil")
266
- with gr.Row():
267
- question = gr.Textbox(lines=1, label="Question")
268
- text_output = gr.Textbox(value="Output:", label="Model output")
269
-
270
- run_btn = gr.Button("Run model")
271
-
272
-
273
- def on_click_fn(img, txt): return generate(2, img, txt)
274
-
275
-
276
- run_btn.click(
277
- on_click_fn, inputs=[query_image, question], outputs=[text_output]
278
- )
279
-
280
- demo.queue(concurrency_count=1)
281
- demo.launch()
17
  from huggingface_hub import hf_hub_download, login
18
 
19
  from open_flamingo.src.factory import create_model_and_transforms
20
+ from open_flamingo.chat.conversation import Chat, CONV_VISION
21
 
22
  flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
23
  "ViT-L-14",
 
49
  del model_state_dict["vision_encoder.visual.ln_post.weight"]
50
  del model_state_dict["vision_encoder.visual.ln_post.bias"]
51
  flamingo.load_state_dict(model_state_dict, strict=True)
52
+ chat = Chat(flamingo, image_processor, tokenizer, vis_embed_size )
53
 
54
  def get_outputs(
55
  model,
 
177
  return (f"Output:{gen_text}")
178
 
179
 
180
+ title = """<h1 align="center">Demo of Compositional-VLM</h1>"""
181
+ description = """<h3>This is the demo of Compositional-VLM. Upload your images and start chatting!</h3>"""
182
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://compositionalvlm.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
183
+ """
184
 
185
+ # TODO show examples below
186
 
187
+ # ========================================
188
+ # Gradio Setting
189
+ # ========================================
190
 
191
+ def gradio_reset(chat_state, img_list):
192
+ if chat_state is not None:
193
+ chat_state = []
194
+ if img_list is not None:
195
+ img_list = []
196
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
197
+ interactive=False), gr.update(
198
+ value="Upload & Start Chat", interactive=True), chat_state, img_list
199
200
 
201
+ def upload_img(gr_img, text_input, chat_state):
202
+ if gr_img is None:
203
+ return None, None, gr.update(interactive=True), chat_state, None
204
+ chat_state = []
205
+ img_list = []
206
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
207
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
208
+ value="Start Chatting", interactive=False), chat_state, img_list
209
 
210
 
211
+ def gradio_ask(user_message, chatbot, chat_state):
212
+ if len(user_message) == 0:
213
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
214
 
215
+ chat.ask(user_message, chat_state)
216
+ chatbot = chatbot + [[user_message, None]]
217
+ return '', chatbot, chat_state
218
 
 
219
 
220
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
221
+ llm_message = \
222
+ chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
223
+ max_length=2000)[0]
224
+
225
+ chatbot[-1][1] = llm_message
226
+ return chatbot, chat_state, img_list
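The num_beams slider value reaches gradio_answer but is not forwarded: chat.answer() is called with num_beams=1, and Chat.answer() in turn pins num_beams=1 (and max_new_tokens=100) in its own generate call. If beam search is meant to be user-controlled, the outer call would presumably look like the sketch below, with Chat.answer() passing the value through as well:

    # Sketch (assumption: the slider should control beam search end to end).
    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300,
                              num_beams=num_beams, temperature=temperature,
                              max_length=2000)[0]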
227
 
 
228
 
229
+ with gr.Blocks() as demo:
230
+ gr.Markdown(title)
231
+ gr.Markdown(SHARED_UI_WARNING)
232
+ gr.Markdown(description)
233
+ gr.Markdown(article)
234
+
235
+ with gr.Row():
236
+ with gr.Column(scale=0.5):
237
+ image = gr.Image(type="pil")
238
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
239
+ clear = gr.Button("Restart")
240
+
241
+ num_beams = gr.Slider(
242
+ minimum=1,
243
+ maximum=5,
244
+ value=1,
245
+ step=1,
246
+ interactive=True,
247
+ label="beam search numbers)",
248
+ )
249
+
250
+ temperature = gr.Slider(
251
+ minimum=0.1,
252
+ maximum=2.0,
253
+ value=1.0,
254
+ step=0.1,
255
+ interactive=True,
256
+ label="Temperature",
257
+ )
258
+
259
+ with gr.Column():
260
+ chat_state = gr.State()
261
+ img_list = gr.State()
262
+ chatbot = gr.Chatbot(label='Compositional-VLM')
263
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
264
+
265
+ upload_button.click(upload_img, [image, text_input, chat_state],
266
+ [image, text_input, upload_button, chat_state, img_list])
267
+
268
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
269
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
270
+ )
271
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
272
+ queue=False)
273
+
274
+ demo.launch(enable_queue=True)
275
+ #
276
+ # with gr.Blocks() as demo:
277
+ # gr.Markdown(
278
+ # """
279
+ # 🍜 Object Centric Pretraining Demo
280
+ # In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
281
+ # The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
282
+ # """
283
+ # )
284
+ #
285
+ # with gr.Accordion("See terms and conditions"):
286
+ # gr.Markdown(
287
+ # """**Please read the following information carefully before proceeding.**This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
288
+ #
289
+ # with gr.Tab("πŸ“· Image Captioning"):
290
+ # with gr.Row():
291
+ #
292
+ #
293
+ # query_image = gr.Image(type="pil")
294
+ # with gr.Row():
295
+ # chat_input = gr.Textbox(lines=1, label="Chat Input")
296
+ # text_output = gr.Textbox(value="Output:", label="Model output")
297
+ #
298
+ # run_btn = gr.Button("Run model")
299
+ #
300
+ #
301
+ #
302
+ # def on_click_fn(img,text): return generate(0, img, text)
303
+ #
304
+ # run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
305
+ #
306
+ # with gr.Tab("πŸ¦“ Grounding"):
307
+ # with gr.Row():
308
+ # with gr.Column(scale=1):
309
+ # query_image = gr.Image(type="pil")
310
+ # with gr.Column(scale=1):
311
+ # out_image = gr.Image(type="pil")
312
+ # with gr.Row():
313
+ # chat_input = gr.Textbox(lines=1, label="Chat Input")
314
+ # text_output = gr.Textbox(value="Output:", label="Model output")
315
+ #
316
+ # run_btn = gr.Button("Run model")
317
+ #
318
+ #
319
+ # def on_click_fn(img, text): return generate(1, img, text)
320
+ #
321
+ #
322
+ # run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
323
+ #
324
+ # with gr.Tab("πŸ”’ Counting objects"):
325
+ # with gr.Row():
326
+ # query_image = gr.Image(type="pil")
327
+ # with gr.Row():
328
+ # chat_input = gr.Textbox(lines=1, label="Chat Input")
329
+ # text_output = gr.Textbox(value="Output:", label="Model output")
330
+ #
331
+ # run_btn = gr.Button("Run model")
332
+ #
333
+ #
334
+ # def on_click_fn(img,text): return generate(0, img, text)
335
+ #
336
+ #
337
+ # run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
338
+ #
339
+ # with gr.Tab("πŸ•΅οΈ Visual Question Answering"):
340
+ # with gr.Row():
341
+ # query_image = gr.Image(type="pil")
342
+ # with gr.Row():
343
+ # question = gr.Textbox(lines=1, label="Question")
344
+ # text_output = gr.Textbox(value="Output:", label="Model output")
345
+ #
346
+ # run_btn = gr.Button("Run model")
347
+ #
348
+ #
349
+ # def on_click_fn(img, txt): return generate(2, img, txt)
350
+ #
351
+ #
352
+ # run_btn.click(
353
+ # on_click_fn, inputs=[query_image, question], outputs=[text_output]
354
+ # )
355
+ #
356
+ # with gr.Tab("🌎 Custom"):
357
+ # gr.Markdown(
358
+ # """### Customize the demonstration by uploading your own images and text samples.
359
+ # ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
360
+ # )
361
+ # with gr.Row():
362
+ # query_image = gr.Image(type="pil")
363
+ # with gr.Row():
364
+ # question = gr.Textbox(lines=1, label="Question")
365
+ # text_output = gr.Textbox(value="Output:", label="Model output")
366
+ #
367
+ # run_btn = gr.Button("Run model")
368
+ #
369
+ #
370
+ # def on_click_fn(img, txt): return generate(2, img, txt)
371
+ #
372
+ #
373
+ # run_btn.click(
374
+ # on_click_fn, inputs=[query_image, question], outputs=[text_output]
375
+ # )
376
+ #
377
+ # demo.queue(concurrency_count=1)
378
+ # demo.launch()
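Note that the new app.py renders gr.Markdown(SHARED_UI_WARNING), but SHARED_UI_WARNING is only defined in temp.py below, so app.py would need its own copy (or an import) before launching. A minimal sketch, with the text taken from temp.py:

    # Sketch: SHARED_UI_WARNING as defined in temp.py; app.py references it but
    # does not define or import it.
    SHARED_UI_WARNING = '''### [NOTE] It is possible that you are waiting in a lengthy queue.
    You can duplicate and use it with a paid private GPU.
    <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
    Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
    '''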
multimodal/open_flamingo/chat/conversation.py ADDED
@@ -0,0 +1,486 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import numpy as np
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
8
+ from transformers import StoppingCriteria, StoppingCriteriaList
9
+
10
+ import dataclasses
11
+ from enum import auto, Enum
12
+ from typing import List, Tuple, Any
13
+
14
+ import string
15
+ import cv2
16
+ import gradio as gr
17
+
18
+ from huggingface_hub import hf_hub_download, login
19
+
20
+ from open_flamingo.src.factory import create_model_and_transforms
21
+
22
+ class SeparatorStyle(Enum):
23
+ """Different separator style."""
24
+ SINGLE = auto()
25
+ TWO = auto()
26
+
27
+
28
+ @dataclasses.dataclass
29
+ class Conversation:
30
+ """A class that keeps all conversation history."""
31
+ system: str
32
+ roles: List[str]
33
+ messages: List[List[str]]
34
+ offset: int
35
+ # system_img: List[Image.Image] = []
36
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
37
+ sep: str = "###"
38
+ sep2: str = None
39
+
40
+ skip_next: bool = False
41
+ conv_id: Any = None
42
+
43
+ def get_prompt(self):
44
+ if self.sep_style == SeparatorStyle.SINGLE:
45
+ ret = self.system + self.sep
46
+ for role, message in self.messages:
47
+ if message:
48
+ ret += role + ": " + message + self.sep
49
+ else:
50
+ ret += role + ":"
51
+ return ret
52
+ elif self.sep_style == SeparatorStyle.TWO:
53
+ seps = [self.sep, self.sep2]
54
+ ret = self.system + seps[0]
55
+ for i, (role, message) in enumerate(self.messages):
56
+ if message:
57
+ ret += role + ": " + message + seps[i % 2]
58
+ else:
59
+ ret += role + ":"
60
+ return ret
61
+ else:
62
+ raise ValueError(f"Invalid style: {self.sep_style}")
63
+
64
+ def append_message(self, role, message):
65
+ self.messages.append([role, message])
66
+
67
+ def to_gradio_chatbot(self):
68
+ ret = []
69
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
70
+ if i % 2 == 0:
71
+ ret.append([msg, None])
72
+ else:
73
+ ret[-1][-1] = msg
74
+ return ret
75
+
76
+ def copy(self):
77
+ return Conversation(
78
+ system=self.system,
79
+ # system_img=self.system_img,
80
+ roles=self.roles,
81
+ messages=[[x, y] for x, y in self.messages],
82
+ offset=self.offset,
83
+ sep_style=self.sep_style,
84
+ sep=self.sep,
85
+ sep2=self.sep2,
86
+ conv_id=self.conv_id)
87
+
88
+ def dict(self):
89
+ return {
90
+ "system": self.system,
91
+ # "system_img": self.system_img,
92
+ "roles": self.roles,
93
+ "messages": self.messages,
94
+ "offset": self.offset,
95
+ "sep": self.sep,
96
+ "sep2": self.sep2,
97
+ "conv_id": self.conv_id,
98
+ }
99
+
100
+
101
+ class StoppingCriteriaSub(StoppingCriteria):
102
+
103
+ def __init__(self, stops=[], encounters=1):
104
+ super().__init__()
105
+ self.stops = stops
106
+
107
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
108
+ for stop in self.stops:
109
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
110
+ return True
111
+
112
+ return False
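StoppingCriteriaSub is defined here but only wired up in the commented-out MiniGPT-4-style path inside Chat.__init__ below. A minimal sketch of how it would plug into a standard transformers generate() call, using the tokenizer and a language model from the surrounding setup (move the stop ids to the model's device when generating on GPU):

    # Sketch: stop generation once the "###" separator is emitted.
    stop_ids = [torch.tensor(tokenizer("###", add_special_tokens=False)["input_ids"])]
    stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])
    # outputs = lm.generate(input_ids=input_ids, max_new_tokens=100,
    #                       stopping_criteria=stopping)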
113
+
114
+
115
+ CONV_VISION = Conversation(
116
+ system="Give the following image: <Img>ImageContent</Img>. "
117
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
118
+ roles=("Human", "Assistant"),
119
+ messages=[],
120
+ offset=2,
121
+ sep_style=SeparatorStyle.SINGLE,
122
+ sep="###",
123
+ )
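For reference, with SeparatorStyle.SINGLE a CONV_VISION prompt renders as the system text followed by "###"-separated turns (the live chat path below uses the plain-list history and preprocess_conv() instead):

    # Worked example of Conversation.get_prompt() on a copy of CONV_VISION.
    conv = CONV_VISION.copy()
    conv.append_message(conv.roles[0], "<Img><ImageHere></Img> What is in the image?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
    # Give the following image: <Img>ImageContent</Img>. You will be able to see
    # the image once I provide it to you. Please answer my questions.###Human:
    # <Img><ImageHere></Img> What is in the image?###Assistant:
    # (output wrapped here; get_prompt() returns a single string)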
124
+
125
+ def get_outputs(
126
+ model,
127
+ batch_images,
128
+ attention_mask,
129
+ max_generation_length,
130
+ min_generation_length,
131
+ num_beams,
132
+ length_penalty,
133
+ input_ids,
134
+ image_start_index_list=None,
135
+ image_nums=None,
136
+ bad_words_ids=None,
137
+ ):
138
+ # and torch.cuda.amp.autocast(dtype=torch.float16)
139
+ with torch.inference_mode():
140
+ outputs = model(
141
+ vision_x=batch_images,
142
+ lang_x=input_ids,
143
+ attention_mask=attention_mask,
144
+ labels=None,
145
+ image_nums=image_nums,
146
+ image_start_index_list=image_start_index_list,
147
+ added_bbox_list=None,
148
+ add_box=False,
149
+ )
150
+ # outputs = model.generate(
151
+ # batch_images,
152
+ # input_ids,
153
+ # attention_mask=attention_mask,
154
+ # max_new_tokens=max_generation_length,
155
+ # min_length=min_generation_length,
156
+ # num_beams=num_beams,
157
+ # length_penalty=length_penalty,
158
+ # image_start_index_list=image_start_index_list,
159
+ # image_nums=image_nums,
160
+ # bad_words_ids=bad_words_ids,
161
+ # )
162
+
163
+ return outputs
164
+
165
+ def generate(
166
+ idx,
167
+ image,
168
+ text,
169
+ image_processor,
170
+ tokenizer,
171
+ flamingo,
172
+ vis_embed_size=256,
173
+ rank=0,
174
+ world_size=1,
175
+ ):
176
+ if image is None:
177
+ raise gr.Error("Please upload an image.")
178
+ flamingo.eval()
179
+ loc_token_ids = []
180
+ for i in range(1000):
181
+ loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
182
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
183
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
184
+ pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
185
+ bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
186
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
187
+
188
+ image_ori = image
189
+ image = image.convert("RGB")
190
+ width = image.width
191
+ height = image.height
192
+ image = image.resize((224, 224))
193
+ batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
194
+ if idx == 1:
195
+ prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|> {text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
196
+ bad_words_ids = None
197
+ max_generation_length = 5
198
+ else:
199
+ prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
200
+ bad_words_ids = loc_token_ids
201
+ max_generation_length = 300
202
+ encodings = tokenizer(
203
+ prompt,
204
+ padding="longest",
205
+ truncation=True,
206
+ return_tensors="pt",
207
+ max_length=2000,
208
+ )
209
+ input_ids = encodings["input_ids"]
210
+ attention_mask = encodings["attention_mask"]
211
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
212
+ image_start_index_list = [[x] for x in image_start_index_list]
213
+ image_nums = [1] * len(input_ids)
214
+ outputs = get_outputs(
215
+ model=flamingo,
216
+ batch_images=batch_images,
217
+ attention_mask=attention_mask,
218
+ max_generation_length=max_generation_length,
219
+ min_generation_length=4,
220
+ num_beams=1,
221
+ length_penalty=1.0,
222
+ input_ids=input_ids,
223
+ bad_words_ids=bad_words_ids,
224
+ image_start_index_list=image_start_index_list,
225
+ image_nums=image_nums,
226
+ )
227
+
228
+ boxes = outputs["boxes"]
229
+ scores = outputs["scores"]
230
+ if len(scores) > 0:
231
+ box = boxes[scores.argmax()]/224
232
+ print(f"{box}")
233
+
234
+
235
+ if len(boxes)>0:
236
+ open_cv_image = np.array(image_ori)
237
+ # Convert RGB to BGR
238
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
239
+ box = box*[width,height,width,height]
240
+ # for box in boxes:
241
+ open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
242
+ out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
243
+ return f"Output:{box}", out_image
244
+ else:
245
+ gen_text = tokenizer.batch_decode(outputs)
246
+ return (f"{gen_text}")
247
+
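generate() above (and answer()/evaluate_exp() below) resolve the special markers by tokenizing the literal string and taking the last id. If markers such as <|#image#|> are registered as single added tokens, the equivalent lookup is the sketch below; the last-id indexing only matters when a marker splits into several sub-tokens:

    # Equivalent lookup, assuming "<|#image#|>" is a single token in the vocabulary.
    media_token_id = tokenizer.convert_tokens_to_ids("<|#image#|>")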
248
+ def preprocess_conv(data):
249
+ conversation = ""
250
+ BEGIN_SIGNAL = "### "
251
+ END_SIGNAL = "\n"
252
+ for idx, d in enumerate(data):
253
+ from_str = d["from"]
254
+ if from_str.lower() == "human":
255
+ from_str = "Human"
256
+ elif from_str.lower() == "gpt":
257
+ from_str = "Assistant"
258
+ else:
259
+ from_str = 'unknown'
260
+ conversation += (BEGIN_SIGNAL + from_str + ": " + d["value"] + END_SIGNAL)
261
+ return conversation
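preprocess_conv() flattens the dict-based history into the "### Role: ..." transcript that answer() appends after the image tokens ("<|#image#|>" plus vis_embed_size pad tokens and "<|#endofimage#|>"). For example:

    # Worked example of preprocess_conv():
    turns = [
        {"from": "human", "value": "What is in the image?"},
        {"from": "gpt", "value": ""},
    ]
    print(preprocess_conv(turns))
    # ### Human: What is in the image?
    # ### Assistant: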
262
+
263
+ class Chat:
264
+ def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
265
+ self.device = device
266
+ self.model = model
267
+ self.vis_processor = vis_processor
268
+ self.tokenizer = tokenizer
269
+ self.vis_embed_size = vis_embed_size
270
+ self.conv = []
271
+ # stop_words_ids = [torch.tensor([835]).to(self.device),
272
+ # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
273
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
274
+
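Chat.__init__ assigns self.device = device, but no device argument is accepted and none is defined at module level, while app.py constructs Chat(flamingo, image_processor, tokenizer, vis_embed_size). A minimal sketch of one way to thread it through, as an assumption about the intent:

    # Sketch (assumption): take the device explicitly instead of relying on a global.
    def __init__(self, model, vis_processor, tokenizer, vis_embed_size, device="cuda"):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        self.tokenizer = tokenizer
        self.vis_embed_size = vis_embed_size
        self.conv = []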
275
+ def ask(self, text, conv):
276
+ conv.append(({
277
+ "from": "human",
278
+ "value": text,
279
+ }))
280
+ # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
281
+ # and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
282
+ # conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
283
+ # else:
284
+ # conv.append_message(conv.roles[0], text)
285
+
286
+ def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
287
+ repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
288
+ # conv.append_message(conv.roles[1], None)
289
+ # embs = self.get_context_emb(conv, img_list)
290
+ #
291
+ # # current_max_len = embs.shape[1] + max_new_tokens + 100
292
+ # # begin_idx = max(0, current_max_len - max_length)
293
+ # # embs = embs[:, begin_idx:]
294
+ # outputs = self.model.llama_model.generate(
295
+ # inputs_embeds=embs,
296
+ # max_new_tokens=max_new_tokens,
297
+ # stopping_criteria=self.stopping_criteria,
298
+ # num_beams=num_beams,
299
+ # min_length=min_length,
300
+ # top_p=top_p,
301
+ # repetition_penalty=repetition_penalty,
302
+ # length_penalty=length_penalty,
303
+ # temperature=temperature,
304
+ # )
305
+ # output_token = outputs[0]
306
+ # if output_token[0] == 0:
307
+ # output_token = output_token[1:]
308
+ # output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
309
+ # output_text = output_text.split('###')[0] # remove the stop sign '###'
310
+ # output_text = output_text.split('Assistant:')[-1].strip()
311
+ # conv.messages[-1][1] = output_text
312
+
313
+ media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
314
+ box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
315
+ endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
316
+ endofattr_token_id = self.tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
317
+ endofmedia_token_id = self.tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
318
+ visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
319
+ previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
320
+ prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
321
+ size = self.vis_processor.size["shortest_edge"]
322
+ model.eval()
323
+ # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
324
+ image_path = input("Please enter the image path: ")
325
+ image = Image.open(image_path).convert("RGB")
326
+ image = image.resize((size, size))
327
+ print(f"image size: {image.size}")
328
+ batch_images = preprocess_image(img_list[0], self.vis_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0).to("cuda")
329
+ # conversation = []
330
+ human_sentence = None
331
+ conv.append({
332
+ "from": "gpt",
333
+ "value": "",
334
+ })
335
+ # while True:
336
+ # human_sentence = input("### Human: ")
337
+ # if human_sentence == "#end#":
338
+ # break
339
+ # conversation.append({
340
+ # "from": "human",
341
+ # "value": human_sentence,
342
+ # })
343
+ # conversation.append({
344
+ # "from": "gpt",
345
+ # "value": "",
346
+ # })
347
+ text = preprocess_conv(conv).strip()
348
+ caption = f"<|#image#|>{tokenizer.pad_token * self.vis_embed_size}<|#endofimage#|>{text}"
349
+ encodings = tokenizer(
350
+ caption,
351
+ padding="longest",
352
+ truncation=True,
353
+ return_tensors="pt",
354
+ max_length=2000,
355
+ )
356
+ input_ids = encodings["input_ids"].to("cuda")
357
+ attention_mask = encodings["attention_mask"].to("cuda")
358
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
359
+ image_start_index_list = [[x] for x in image_start_index_list]
360
+ image_nums = [1] * len(input_ids)
361
+ with torch.no_grad() and torch.cuda.amp.autocast(dtype=torch.float16):
362
+ outputs = model.generate(
363
+ batch_images,
364
+ input_ids,
365
+ attention_mask=attention_mask,
366
+ max_new_tokens=100,
367
+ # min_new_tokens=8,
368
+ num_beams=1,
369
+ image_start_index_list=image_start_index_list,
370
+ image_nums=image_nums,
371
+ )
372
+ output_token = outputs[0, input_ids.shape[1]:]
373
+ output_text = tokenizer.decode(output_token, skip_special_tokens=True).strip()
374
+ conv[-1]["value"] = output_text
375
+ # conv.messages[-1][1] = output_text
376
+ print(
377
+ f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
378
+
379
+ return output_text, output_token.cpu().numpy()
380
+
381
+ def upload_img(self, image, conv, img_list):
382
+ img_list.append(image)
383
+ # if isinstance(image, str): # is a image path
384
+ # raw_image = Image.open(image).convert('RGB')
385
+ # image = image.resize((224, 224))
386
+ # image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
387
+ # elif isinstance(image, Image.Image):
388
+ # raw_image = image
389
+ # image = image.resize((224, 224))
390
+ # image = self.vis_processor(raw_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
391
+ # elif isinstance(image, torch.Tensor):
392
+ # if len(image.shape) == 3:
393
+ # image = image.unsqueeze(0)
394
+ # # image = image.to(self.device)
395
+ #
396
+ # # image_emb, _ = self.model.encode_img(image)
397
+ # img_list.append(image_emb)
398
+ conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
399
+ msg = "Received."
400
+ # self.conv.append_message(self.conv.roles[1], msg)
401
+ return msg
402
+
403
+ def get_context_emb(self, conv, img_list):
404
+ prompt = conv.get_prompt()
405
+ prompt_segs = prompt.split('<ImageHere>')
406
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
407
+ seg_tokens = [
408
+ self.model.llama_tokenizer(
409
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
410
+ # only add bos to the first seg
411
+ for i, seg in enumerate(prompt_segs)
412
+ ]
413
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
414
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
415
+ mixed_embs = torch.cat(mixed_embs, dim=1)
416
+ return mixed_embs
417
+
418
+ def evaluate_exp(
419
+ model,
420
+ tokenizer,
421
+ image_processor,
422
+ vis_embed_size=None,
423
+ rank=0,
424
+ world_size=1,
425
+ id=0,
426
+ add_visual=True,
427
+ ):
428
+ media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
429
+ box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
430
+ endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
431
+ endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
432
+ endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
433
+ visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
434
+ previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
435
+ prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
436
+ size = image_processor.size["shortest_edge"]
437
+ model.eval()
438
+ # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
439
+ image_path = input("Please enter the image path: ")
440
+ image = Image.open(image_path).convert("RGB")
441
+ image = image.resize((size, size))
442
+ print(f"image size: {image.size}")
443
+ batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0).to("cuda")
444
+ conversation = []
445
+ human_sentence = None
446
+ while True:
447
+ human_sentence = input("### Human: ")
448
+ if human_sentence == "#end#":
449
+ break
450
+ conversation.append({
451
+ "from": "human",
452
+ "value": human_sentence,
453
+ })
454
+ conversation.append({
455
+ "from": "gpt",
456
+ "value": "",
457
+ })
458
+ text = preprocess_conv(conversation).strip()
459
+ caption = f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}"
460
+ encodings = tokenizer(
461
+ caption,
462
+ padding="longest",
463
+ truncation=True,
464
+ return_tensors="pt",
465
+ max_length=2000,
466
+ )
467
+ input_ids = encodings["input_ids"].to("cuda")
468
+ attention_mask = encodings["attention_mask"].to("cuda")
469
+ image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
470
+ image_start_index_list = [[x] for x in image_start_index_list]
471
+ image_nums = [1] * len(input_ids)
472
+ with torch.no_grad() and torch.cuda.amp.autocast(dtype=torch.float16):
473
+ outputs = model.generate(
474
+ batch_images,
475
+ input_ids,
476
+ attention_mask=attention_mask,
477
+ max_new_tokens=100,
478
+ # min_new_tokens=8,
479
+ num_beams=1,
480
+ image_start_index_list=image_start_index_list,
481
+ image_nums=image_nums,
482
+ )
483
+ print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
484
+
485
+
486
+
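evaluate_exp() is a console REPL over the same prompt format. As committed it relies on a module-level preprocess_image helper that conversation.py never defines or imports, and it expects image_processor.size["shortest_edge"]; under those assumptions, a rough sketch of driving it with the objects built in app.py:

    # Rough sketch only: injects a stand-in for the missing preprocess_image
    # helper (assumed to be plain image_processor(image) preprocessing).
    import open_flamingo.chat.conversation as conversation

    conversation.preprocess_image = lambda image, processor: processor(image)
    conversation.evaluate_exp(flamingo, tokenizer, image_processor,
                              vis_embed_size=vis_embed_size)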
temp.py ADDED
@@ -0,0 +1,168 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import gradio as gr
9
+
10
+ from minigpt4.common.config import Config
11
+ from minigpt4.common.dist_utils import get_rank
12
+ from minigpt4.common.registry import registry
13
+ from minigpt4.conversation.conversation import Chat, CONV_VISION
14
+
15
+ # imports modules for registration
16
+ from minigpt4.datasets.builders import *
17
+ from minigpt4.models import *
18
+ from minigpt4.processors import *
19
+ from minigpt4.runners import *
20
+ from minigpt4.tasks import *
21
+
22
+
23
+ def parse_args():
24
+ parser = argparse.ArgumentParser(description="Demo")
25
+ parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml',
26
+ help="path to configuration file.")
27
+ parser.add_argument(
28
+ "--options",
29
+ nargs="+",
30
+ help="override some settings in the used config, the key-value pair "
31
+ "in xxx=yyy format will be merged into config file (deprecate), "
32
+ "change to --cfg-options instead.",
33
+ )
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+
38
+ def setup_seeds(config):
39
+ seed = config.run_cfg.seed + get_rank()
40
+
41
+ random.seed(seed)
42
+ np.random.seed(seed)
43
+ torch.manual_seed(seed)
44
+
45
+ cudnn.benchmark = False
46
+ cudnn.deterministic = True
47
+
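setup_seeds() is defined but, as committed, never called in this script; if deterministic runs are wanted it would presumably be invoked once the config exists, e.g.:

    # Sketch (assumption): seed right after cfg = Config(parse_args()), before model init.
    setup_seeds(cfg)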
48
+
49
+ # ========================================
50
+ # Model Initialization
51
+ # ========================================
52
+
53
+ SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
54
+ You can duplicate and use it with a paid private GPU.
55
+ <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
56
+ Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
57
+ '''
58
+
59
+ print('Initializing Chat')
60
+ cfg = Config(parse_args())
61
+
62
+ model_config = cfg.model_cfg
63
+ model_cls = registry.get_model_class(model_config.arch)
64
+ model = model_cls.from_config(model_config).to('cuda:0')
65
+
66
+ vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
67
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
68
+ chat = Chat(model, vis_processor)
69
+ print('Initialization Finished')
70
+
71
+
72
+ # ========================================
73
+ # Gradio Setting
74
+ # ========================================
75
+
76
+ def gradio_reset(chat_state, img_list):
77
+ if chat_state is not None:
78
+ chat_state.messages = []
79
+ if img_list is not None:
80
+ img_list = []
81
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
82
+ interactive=False), gr.update(
83
+ value="Upload & Start Chat", interactive=True), chat_state, img_list
84
+
85
+
86
+ def upload_img(gr_img, text_input, chat_state):
87
+ if gr_img is None:
88
+ return None, None, gr.update(interactive=True), chat_state, None
89
+ chat_state = CONV_VISION.copy()
90
+ img_list = []
91
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
92
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
93
+ value="Start Chatting", interactive=False), chat_state, img_list
94
+
95
+ def ask(text, conv):
96
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
97
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
98
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
99
+ else:
100
+ conv.append_message(conv.roles[0], text)
101
+
102
+ def gradio_ask(user_message, chatbot, chat_state):
103
+ if len(user_message) == 0:
104
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
105
+ chat.ask(user_message, chat_state)
106
+ chatbot = chatbot + [[user_message, None]]
107
+ return '', chatbot, chat_state
108
+
109
+
110
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
111
+ llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
112
+ chatbot[-1][1] = llm_message
113
+ return chatbot, chat_state, img_list
114
+
115
+
116
+ title = """<h1 align="center">Demo of Compositional-VLM</h1>"""
117
+ description = """<h3>This is the demo of Compositional-VLM. Upload your images and start chatting!</h3>"""
118
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://compositionalvlm.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
119
+ """
120
+
121
+ # TODO show examples below
122
+
123
+ with gr.Blocks() as demo:
124
+ gr.Markdown(title)
125
+ gr.Markdown(SHARED_UI_WARNING)
126
+ gr.Markdown(description)
127
+ gr.Markdown(article)
128
+
129
+ with gr.Row():
130
+ with gr.Column(scale=0.5):
131
+ image = gr.Image(type="pil")
132
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
133
+ clear = gr.Button("Restart")
134
+
135
+ num_beams = gr.Slider(
136
+ minimum=1,
137
+ maximum=5,
138
+ value=1,
139
+ step=1,
140
+ interactive=True,
141
+ label="beam search numbers)",
142
+ )
143
+
144
+ temperature = gr.Slider(
145
+ minimum=0.1,
146
+ maximum=2.0,
147
+ value=1.0,
148
+ step=0.1,
149
+ interactive=True,
150
+ label="Temperature",
151
+ )
152
+
153
+ with gr.Column():
154
+ chat_state = gr.State()
155
+ img_list = gr.State()
156
+ chatbot = gr.Chatbot(label='Compositional-VLM')
157
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
158
+
159
+ upload_button.click(upload_img, [image, text_input, chat_state],
160
+ [image, text_input, upload_button, chat_state, img_list])
161
+
162
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
163
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
164
+ )
165
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
166
+ queue=False)
167
+
168
+ demo.launch(enable_queue=True)
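Both demos launch with enable_queue=True. Newer Gradio releases moved queueing onto the Blocks object itself and deprecated the launch() flag, so depending on the installed version the equivalent is roughly:

    # Sketch for newer Gradio versions (exact queue kwargs depend on the release):
    demo.queue()
    demo.launch()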